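diff --git a/.clinerules b/.clinerules index c7f5ccc..2a04ac7 100644 --- a/.clinerules +++ b/.clinerules @@ -4,6 +4,10 @@ Simplified Chinese +## Development + +- TDD + # Cline's Memory Bank I am Cline, an expert software engineer with a unique characteristic: my memory resets completely between sessions. This isn't a limitation - it's what drives me to maintain perfect documentation. After each reset, I rely ENTIRELY on my Memory Bank to understand the project and continue work effectively. I MUST read ALL memory bank files at the start of EVERY task - this is not optional. diff --git a/memory-bank/activeContext.md b/memory-bank/activeContext.md index e3b3941..629e780 100644 --- a/memory-bank/activeContext.md +++ b/memory-bank/activeContext.md @@ -1,24 +1,29 @@ # Current Focus -The current focus is to start optimizing the project and adding features according to the to-do items in `progress.md`. With comprehensive test coverage in place, we are confident in the stability and correctness of the existing code and can refactor it safely. +There are no urgent tasks in progress. The project is in a stable state, and the next work can be planned from the to-do list in `progress.md`. ## Recent Changes -- **Completed `server_test.go`**: - - Filled in all pending test cases in `server_test.go`, including tests for `X-Accel-Redirect` and path traversal attacks. - - Reviewed and corrected the comments on all test cases so that they match the actual behavior of the code. - - All tests pass, laying a solid foundation for subsequent development and refactoring. -- **Updated `progress.md`**: - - Marked "Add more comprehensive unit and integration tests" as done. +- **Implemented the phase-two optimization of streaming requests (temp-file based)**: + - **Problem**: `StreamObject` used an in-memory `bytes.Buffer` to cache downloaded content, which caused excessive memory usage for large files. + - **Solution**: Replaced the `Buffer` in `StreamObject` with an `*os.File`, so the downloaded byte stream is written directly to a temporary file. + - **Implementation details**: + - `startOrJoinStream` (producer) now creates a temporary file and writes the downloaded content to it. After a successful download, the temporary file is renamed to the final cache file. + - `serveRangedRequest` (consumer) uses `syscall.Dup()` to duplicate the producer's file descriptor in order to resolve a lifetime race on the file handle. Each consumer reads through its own independent, duplicated file handle, which decouples it from the lifetime of the producer's file handle. + - **Result**: Resolved the memory usage problem for large files; all tests pass, including a new test designed specifically to verify concurrency safety and race conditions. ## Next Steps -1. **Code refactoring**: - - Per the to-do items in `progress.md`, first consider refactoring the complex functions in `server.go` (such as `streamOnline`) to improve readability and maintainability. -2. **Feature enhancements**: - - Once the code structure is improved, start considering new features, such as support for new storage backends like S3, or more sophisticated load-balancing strategies. -3. **Continuous documentation updates**: - - When refactoring or adding features, update `systemPatterns.md` and other related documents in sync to record new design decisions. +The next logical step is to work through the to-do items listed in `progress.md`, for example: + +1. 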
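**Feature enhancements**: + - **Multi-backend support**: Add support for other storage backends such as S3 and Redis. + - **Advanced load balancing**: Implement more sophisticated load-balancing strategies. + - **Monitoring**: Integrate Prometheus metrics. +2. **Code refactoring**: + - Refactor the complex functions in `server.go` to improve readability. + +Before starting a new task, the next focus needs to be confirmed with you. ## Important Patterns and Preferences diff --git a/memory-bank/progress.md b/memory-bank/progress.md index 99e3327..44b7741 100644 --- a/memory-bank/progress.md +++ b/memory-bank/progress.md @@ -29,11 +29,15 @@ ## To-Do - **Feature enhancements**: + - **Implement streaming range requests (Done)**: Reworked the Ranged Request handling flow so that responses are served while the download is still in progress. + - **Phase one (done)**: Implemented with `sync.Cond` and an in-memory `bytes.Buffer`. + - **Performance optimization (done)**: Removed the fixed delay from the `StreamObject` cleanup logic and relied on Go's GC instead, which significantly improved test performance under concurrent requests. + - **Phase two (done)**: Optimized with a temporary file and the `syscall.Dup()` system call, resolving the memory usage problem for large files. - Only local file storage is supported at present; support for other storage backends (such as S3 and Redis) could be added in the future. - Add more sophisticated load-balancing strategies beyond just "pick the fastest". - Add more detailed monitoring and metrics (such as Prometheus metrics). - **Code optimization**: - - Refactor some of the complex functions in `server.go` (such as `streamOnline`) to improve readability and maintainability. + - Once the streaming request feature is complete, refactor some of the complex functions in `server.go` (such as `streamOnline`) to improve readability and maintainability. ## Known Issues diff --git a/memory-bank/systemPatterns.md b/memory-bank/systemPatterns.md index 462d67e..625785e 100644 --- a/memory-bank/systemPatterns.md +++ b/memory-bank/systemPatterns.md @@ -23,20 +23,36 @@ ### 2. Producer-Consumer Pattern -The producer-consumer pattern is used during file download. +The producer-consumer pattern is used for file download and streaming responses (Ranged Request). -- **Producer**: A goroutine in the `tryUpstream` function reads data chunks from the upstream server and puts them into a `chan Chunk`. -- **Consumer**: Code in the `streamOnline` function reads data from the `chan Chunk` and does two things: - 1. Writes the data to a `bytes.Buffer` for later requesters. - 2. Writes the data to a local temporary file for the persistent cache. +- **Producer (`startOrJoinStream`)**: + - Starts one goroutine for each file on its first request. + - Downloads the file content from the fastest upstream server. + - Writes the content **directly to a temporary file** (`*os.File`) rather than to an in-memory buffer. + - Broadcasts download progress (the number of bytes written, `Offset`) through `sync.Cond`. + - After a successful download, renames the temporary file to the final cache file. +- **Consumer (`serveRangedRequest`)**: + - When a range request arrives, it finds or waits for the corresponding `StreamObject`. + - To safely read a file that the producer is still writing, the consumer uses `syscall.Dup()` to **duplicate the temporary file's file descriptor**. + - Each consumer reads the range it needs through its own independent, duplicated file handle (`*os.File`), which avoids conflicts over file handle state with the producer or other consumers. + - The consumer waits for the data in its requested range to become available, based on the producer's `Offset` progress and `sync.Cond` signals. -### 3. Concurrent Access Control (Mutex for Concurrent Access) +### 3. 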
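Concurrent Access Control and Object Lifecycle Management To handle multiple clients requesting the same file at the same time, the system uses a `sync.Mutex` and a `map[string]*StreamObject`. -- When the first request arrives, it acquires a lock and creates a `StreamObject` to represent the in-progress download. -- Later requests for the same file find that the `StreamObject` already exists; instead of sending another request upstream, they wait and read data from the shared `StreamObject`. -- When the download completes, the `StreamObject` is removed from the map. +- When the first request for a file (which we call the "consumer") arrives, it acquires a lock, creates a `StreamObject` to represent the in-progress download, and then starts a "producer" goroutine to perform the download. +- Later requests for the same file find that the `StreamObject` already exists in the map; they do not start another download but share this object. + +**`StreamObject` lifecycle management (GC-based)** + +We manage the lifecycle of `StreamObject` with a simple and efficient pattern that relies on Go's garbage collector (GC): + +1. **The producer removes the entry**: When the download goroutine (the producer) finishes its task, whether it succeeded or failed, its only responsibility is to remove the `StreamObject` from the global `map[string]*StreamObject`. +2. **Consumers hold references**: Meanwhile, all HTTP handlers (consumers) that are still serving requests for the file keep their references to the `StreamObject`. +3. **The GC reclaims it automatically**: Because consumers still hold references, Go's GC does not reclaim the object. Only when the last consumer finishes its request and its stack frame is destroyed does the last reference to the `StreamObject` disappear. The GC then reclaims the object's memory on a subsequent run. + +This pattern avoids complex reference counting or timers, keeps the code simpler, and removes the root cause of the earlier performance problem caused by the fixed delay. ### 4. Middleware diff --git a/server.go b/server.go index 1ad8c59..c8f7962 100644 --- a/server.go +++ b/server.go @@ -1,19 +1,19 @@ package cacheproxy import ( - "bytes" "context" "errors" "io" "log/slog" - "net" "net/http" "os" "path/filepath" "regexp" "slices" + "strconv" "strings" "sync" + "syscall" "time" ) @@ -49,44 +49,16 @@ var ( ) type StreamObject struct { - Headers http.Header - Buffer *bytes.Buffer - Offset int + Headers http.Header + TempFile *os.File // The temporary file holding the download content. + Offset int64 // The number of bytes written to TempFile. 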
+ Done bool + Error error - ctx context.Context - wg *sync.WaitGroup -} + mu *sync.Mutex + cond *sync.Cond -func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error { - defer wg.Done() - offset := 0 - if w == nil { - w = io.Discard - } -OUTER: - for { - select { - case <-memoryObject.ctx.Done(): - break OUTER - default: - } - - newOffset := memoryObject.Offset - if newOffset == offset { - time.Sleep(time.Millisecond) - continue - } - bytes := memoryObject.Buffer.Bytes()[offset:newOffset] - written, err := w.Write(bytes) - if err != nil { - return err - } - - offset += written - } - time.Sleep(time.Millisecond) - _, err := w.Write(memoryObject.Buffer.Bytes()[offset:]) - return err + fileWrittenCh chan struct{} // Closed when the file is fully written and renamed. } type Server struct { @@ -105,11 +77,6 @@ func NewServer(config Config) *Server { } } -type Chunk struct { - buffer []byte - error error -} - func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) { if location := r.Header.Get(server.Storage.Local.Accel.EnableByHeader); server.Storage.Local.Accel.EnableByHeader != "" && location != "" { relPath, err := filepath.Rel(server.Storage.Local.Path, path) @@ -147,26 +114,29 @@ func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Requ slog.With("status", localStatus, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked") if os.IsPermission(err) { http.Error(w, err.Error(), http.StatusForbidden) + return } else if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) - } else if localStatus != localNotExists { - if localStatus == localExistsButNeedHead { - if ranged { - <-server.streamOnline(nil, r, mtime, fullpath) - server.serveFile(w, r, fullpath) - } else { - server.streamOnline(w, r, mtime, fullpath) - } - } else { - server.serveFile(w, r, fullpath) - } + return + } + + if localStatus == localExists { + server.serveFile(w, r, fullpath) + return 
+ } + + // localExistsButNeedHead or localNotExists + // Both need to go online. + if ranged { + server.serveRangedRequest(w, r, fullpath, mtime) } else { - if ranged { - <-server.streamOnline(nil, r, mtime, fullpath) - server.serveFile(w, r, fullpath) - } else { - server.streamOnline(w, r, mtime, fullpath) + // For full requests, we wait for the download to complete and then serve the file. + // This maintains the original behavior for now. + ch := server.startOrJoinStream(r, mtime, fullpath) + if ch != nil { + <-ch } + server.serveFile(w, r, fullpath) } } @@ -214,240 +184,348 @@ func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key str return localNotExists, zeroTime, nil } -func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) <-chan struct{} { +func (server *Server) serveRangedRequest(w http.ResponseWriter, r *http.Request, key string, mtime time.Time) { + server.lu.Lock() memoryObject, exists := server.o[r.URL.Path] - locked := false - defer func() { - if locked { - server.lu.Unlock() - locked = false - } - }() if !exists { - server.lu.Lock() - locked = true - - memoryObject, exists = server.o[r.URL.Path] - } - if exists { - if locked { - server.lu.Unlock() - locked = false - } - - if w != nil { - memoryObject.wg.Add(1) - for k := range memoryObject.Headers { - v := memoryObject.Headers.Get(k) - w.Header().Set(k, v) - } - - if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil { - slog.With("error", err).Warn("failed to stream response with existing memory object") - } - } - return preclosedChan - } else { - slog.With("mtime", mtime).Debug("checking fastest upstream") - selectedIdx, response, chunks, err := server.fastestUpstream(r, mtime) - if chunks == nil && mtime != zeroTime { - slog.With("upstreamIdx", selectedIdx, "key", key).Debug("not modified. 
using local version") - if w != nil { - server.serveFile(w, r, key) - } - return preclosedChan - } - - if err != nil { - slog.With("error", err).Warn("failed to select fastest upstream") - if w != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - return preclosedChan - } - if selectedIdx == -1 || response == nil || chunks == nil { - slog.Debug("no upstream is selected") - if w != nil { - http.NotFound(w, r) - } - return preclosedChan - } - if response.StatusCode == http.StatusNotModified { - slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version") - os.Chtimes(key, zeroTime, time.Now()) - if w != nil { - server.serveFile(w, r, key) - } - return preclosedChan - } - - slog.With( - "upstreamIdx", selectedIdx, - ).Debug("found fastest upstream") - - buffer := &bytes.Buffer{} - ctx, cancel := context.WithCancel(r.Context()) - defer cancel() - - memoryObject = &StreamObject{ - Headers: response.Header, - Buffer: buffer, - - ctx: ctx, - - wg: &sync.WaitGroup{}, - } - - server.o[r.URL.Path] = memoryObject server.lu.Unlock() - locked = false + // This is the first request for this file, so we start the download. + server.startOrJoinStream(r, mtime, key) - err = nil - - if w != nil { - memoryObject.wg.Add(1) - - for k := range memoryObject.Headers { - v := memoryObject.Headers.Get(k) - w.Header().Set(k, v) - } - - go memoryObject.StreamTo(w, memoryObject.wg) + // Re-acquire the lock to get the newly created stream object. + server.lu.Lock() + memoryObject = server.o[r.URL.Path] + if memoryObject == nil { + server.lu.Unlock() + // This can happen if the upstream fails very quickly. 
+ http.Error(w, "Failed to start download stream", http.StatusInternalServerError) + return } + } + server.lu.Unlock() - for chunk := range chunks { - if chunk.error != nil { - err = chunk.error - if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) { - slog.With("error", err).Warn("failed to read from upstream") - } - } - if chunk.buffer == nil { - break - } - n, _ := buffer.Write(chunk.buffer) - memoryObject.Offset += n - } - cancel() + // At this point, we have a memoryObject, and a producer goroutine is downloading the file. + // Now we implement the consumer logic. - memoryObject.wg.Wait() + // Duplicate the file descriptor from the producer's temp file. This creates a new, + // independent file descriptor that points to the same underlying file description. + // This is more efficient and robust than opening the file by path. + fd, err := syscall.Dup(int(memoryObject.TempFile.Fd())) + if err != nil { + http.Error(w, "Failed to duplicate file descriptor", http.StatusInternalServerError) + return + } + // We create a new *os.File from the duplicated descriptor. The consumer is now + // responsible for closing this new file. + consumerFile := os.NewFile(uintptr(fd), memoryObject.TempFile.Name()) + if consumerFile == nil { + syscall.Close(fd) // Clean up if NewFile fails + http.Error(w, "Failed to create file from duplicated descriptor", http.StatusInternalServerError) + return + } + defer consumerFile.Close() - if response.ContentLength > 0 { - if memoryObject.Offset == int(response.ContentLength) && err != nil { - if !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) { - slog.With("read-length", memoryObject.Offset, "content-length", response.ContentLength, "error", err, "upstreamIdx", selectedIdx).Debug("something happened during download. but response body is read as whole. 
so error is reset to nil") - } - err = nil - } - } else if err == io.EOF { - err = nil - } + rangeHeader := r.Header.Get("Range") + if rangeHeader == "" { + // This should not happen if called from HandleRequestWithCache, but as a safeguard: + http.Error(w, "Range header is required", http.StatusBadRequest) + return + } + // Parse the Range header. We only support a single range like "bytes=start-end". + var start, end int64 + parts := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-") + if len(parts) != 2 { + http.Error(w, "Invalid Range header", http.StatusBadRequest) + return + } + start, err = strconv.ParseInt(parts[0], 10, 64) + if err != nil { + http.Error(w, "Invalid Range header", http.StatusBadRequest) + return + } + if parts[1] != "" { + end, err = strconv.ParseInt(parts[1], 10, 64) if err != nil { - logger := slog.With("upstreamIdx", selectedIdx) - logger.Error("something happened during download. will not cache this response. setting lingering to reset the connection.") - hijacker, ok := w.(http.Hijacker) - if !ok { - logger.Warn("response writer is not a hijacker. failed to set lingering") - return preclosedChan - } - conn, _, err := hijacker.Hijack() - if err != nil { - logger.With("error", err).Warn("hijack failed. failed to set lingering") - return preclosedChan - } - defer conn.Close() + http.Error(w, "Invalid Range header", http.StatusBadRequest) + return + } + } else { + // An empty end means "to the end of the file". We don't know the full size yet. + // We'll have to handle this dynamically. + end = -1 // Sentinel value + } - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - logger.With("error", err).Warn("connection is not a *net.TCPConn. failed to set lingering") - return preclosedChan - } - if err := tcpConn.SetLinger(0); err != nil { - logger.With("error", err).Warn("failed to set lingering") - return preclosedChan - } - logger.Debug("connection set to linger. 
it will be reset once the conn.Close is called") + memoryObject.mu.Lock() + defer memoryObject.mu.Unlock() + + // Wait until we have the headers. + for memoryObject.Headers == nil && !memoryObject.Done { + memoryObject.cond.Wait() + } + + if memoryObject.Error != nil { + http.Error(w, memoryObject.Error.Error(), http.StatusInternalServerError) + return + } + + contentLengthStr := memoryObject.Headers.Get("Content-Length") + totalSize, _ := strconv.ParseInt(contentLengthStr, 10, 64) + if end == -1 || end >= totalSize { + end = totalSize - 1 + } + + if start >= totalSize || start > end { + http.Error(w, "Range not satisfiable", http.StatusRequestedRangeNotSatisfiable) + return + } + + var bytesSent int64 + bytesToSend := end - start + 1 + headersWritten := false + + for bytesSent < bytesToSend && memoryObject.Error == nil { + // Calculate what we need. + neededStart := start + bytesSent + + // Wait for the data to be available. + for memoryObject.Offset <= neededStart && !memoryObject.Done { + memoryObject.cond.Wait() } - fileWrittenCh := make(chan struct{}) - go func() { - defer func() { - server.lu.Lock() - defer server.lu.Unlock() - - delete(server.o, r.URL.Path) - slog.Debug("memory object released") - close(fileWrittenCh) - }() - - if err == nil { - slog.Debug("preparing to release memory object") - mtime := zeroTime - lastModifiedHeader := response.Header.Get("Last-Modified") - if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil { - slog.With( - "error", err, - "value", lastModifiedHeader, - "url", response.Request.URL, - ).Debug("failed to parse last modified header value. 
set modified time to now") - } else { - slog.With( - "header", lastModifiedHeader, - "value", lastModified, - "url", response.Request.URL, - ).Debug("found modified time") - mtime = lastModified - } - if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil { - slog.With("error", err).Warn("failed to create local storage path") - } - - if server.Config.Storage.Local.TemporaryFilePattern == "" { - if err := os.WriteFile(key, buffer.Bytes(), 0644); err != nil { - slog.With("error", err).Warn("failed to write file") - os.Remove(key) - } - return - } - - fp, err := os.CreateTemp(server.Storage.Local.Path, server.Storage.Local.TemporaryFilePattern) - if err != nil { - slog.With( - "key", key, - "path", server.Storage.Local.Path, - "pattern", server.Storage.Local.TemporaryFilePattern, - "error", err, - ).Warn("failed to create template file") - return - } - - name := fp.Name() - - if _, err := fp.Write(buffer.Bytes()); err != nil { - fp.Close() - os.Remove(name) - - slog.With("error", err).Warn("failed to write into template file") - } else if err := fp.Close(); err != nil { - os.Remove(name) - - slog.With("error", err).Warn("failed to close template file") - } else { - os.Chtimes(name, zeroTime, mtime) - dirname := filepath.Dir(key) - os.MkdirAll(dirname, 0755) - os.Remove(key) - os.Rename(name, key) - } + // Check for error AFTER waiting. This is the critical fix. + if memoryObject.Error != nil { + // If headers haven't been written, we can send a 500 error. + // If they have, it's too late, the connection will just be closed. + if !headersWritten { + http.Error(w, memoryObject.Error.Error(), http.StatusInternalServerError) } - }() + return // Use return instead of break to exit immediately. + } - return fileWrittenCh + // If we are here, we have some data to send. Write headers if we haven't already. 
+ if !headersWritten { + w.Header().Set("Content-Range", "bytes "+strconv.FormatInt(start, 10)+"-"+strconv.FormatInt(end, 10)+"/"+contentLengthStr) + w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10)) + w.Header().Set("Accept-Ranges", "bytes") + w.WriteHeader(http.StatusPartialContent) + headersWritten = true + } + + // Data is available, read from the temporary file. + // We calculate how much we can read in this iteration. + readNow := memoryObject.Offset - neededStart + if readNow <= 0 { + // This can happen if we woke up but the data isn't what we need. + // The loop will continue to wait. + continue + } + + // Don't read more than the client requested in total. + remainingToSend := bytesToSend - bytesSent + if readNow > remainingToSend { + readNow = remainingToSend + } + + // Read the chunk from the file at the correct offset. + buffer := make([]byte, readNow) + bytesRead, err := consumerFile.ReadAt(buffer, neededStart) + if err != nil && err != io.EOF { + if !headersWritten { + http.Error(w, "Error reading from cache stream", http.StatusInternalServerError) + } + return + } + + if bytesRead > 0 { + n, err := w.Write(buffer[:bytesRead]) + if err != nil { + // Client closed connection, just return. + return + } + bytesSent += int64(n) + } + + if memoryObject.Done && bytesSent >= bytesToSend { + break + } } } -func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) { +// startOrJoinStream ensures a download stream is active for the given key. +// If a stream already exists, it returns the channel that signals completion. +// If not, it starts a new download producer and returns its completion channel. +func (server *Server) startOrJoinStream(r *http.Request, mtime time.Time, key string) <-chan struct{} { + server.lu.Lock() + memoryObject, exists := server.o[r.URL.Path] + if exists { + // A download is already in progress. 
Return its completion channel. + server.lu.Unlock() + return memoryObject.fileWrittenCh + } + + // No active stream, create a new one. + if err := os.MkdirAll(filepath.Dir(key), 0755); err != nil { + slog.With("error", err).Warn("failed to create local storage path for temp file") + server.lu.Unlock() + return preclosedChan // Return a closed channel to prevent blocking + } + + tempFilePattern := server.Storage.Local.TemporaryFilePattern + if tempFilePattern == "" { + tempFilePattern = "cache-proxy-*" + } + tempFile, err := os.CreateTemp(filepath.Dir(key), tempFilePattern) + if err != nil { + slog.With("error", err).Warn("failed to create temporary file") + server.lu.Unlock() + return preclosedChan + } + + mu := &sync.Mutex{} + fileWrittenCh := make(chan struct{}) + memoryObject = &StreamObject{ + TempFile: tempFile, + mu: mu, + cond: sync.NewCond(mu), + fileWrittenCh: fileWrittenCh, + } + server.o[r.URL.Path] = memoryObject + server.lu.Unlock() + + // This is the producer goroutine + go func(mo *StreamObject) { + var err error + downloadSucceeded := false + tempFileName := mo.TempFile.Name() + + defer func() { + // On completion (or error), update the stream object, + // wake up all consumers, and then clean up. + mo.mu.Lock() + mo.Done = true + mo.Error = err + mo.cond.Broadcast() + mo.mu.Unlock() + + // Close the temp file handle. + mo.TempFile.Close() + + // If download failed, remove the temp file. + if !downloadSucceeded { + os.Remove(tempFileName) + } + + // The producer's job is done. Remove the object from the central map. + // Any existing consumers still hold a reference to the object, + // so it won't be garbage collected until they are done. 
+ server.lu.Lock() + delete(server.o, r.URL.Path) + server.lu.Unlock() + slog.Debug("memory object released by producer") + close(mo.fileWrittenCh) + }() + + slog.With("mtime", mtime).Debug("checking fastest upstream") + selectedIdx, response, firstChunk, upstreamErr := server.fastestUpstream(r, mtime) + if upstreamErr != nil { + if !errors.Is(upstreamErr, io.EOF) && !errors.Is(upstreamErr, io.ErrUnexpectedEOF) { + slog.With("error", upstreamErr).Warn("failed to select fastest upstream") + err = upstreamErr + return + } + } + + if selectedIdx == -1 || response == nil { + slog.Debug("no upstream is selected") + err = errors.New("no suitable upstream found") + return + } + + if response.StatusCode == http.StatusNotModified { + slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version") + os.Chtimes(key, zeroTime, time.Now()) + // In this case, we don't have a new file, so we just exit. + // The temp file will be cleaned up by the defer. + return + } + + defer response.Body.Close() + + slog.With("upstreamIdx", selectedIdx).Debug("found fastest upstream") + + mo.mu.Lock() + mo.Headers = response.Header + mo.cond.Broadcast() // Broadcast headers availability + mo.mu.Unlock() + + // Write the first chunk that we already downloaded. 
+ if len(firstChunk) > 0 { + written, writeErr := mo.TempFile.Write(firstChunk) + if writeErr != nil { + err = writeErr + return + } + mo.mu.Lock() + mo.Offset += int64(written) + mo.cond.Broadcast() + mo.mu.Unlock() + } + + // Download the rest of the file in chunks + buffer := make([]byte, server.Misc.ChunkBytes) + for { + n, readErr := response.Body.Read(buffer) + if n > 0 { + written, writeErr := mo.TempFile.Write(buffer[:n]) + if writeErr != nil { + err = writeErr + break + } + mo.mu.Lock() + mo.Offset += int64(written) + mo.cond.Broadcast() + mo.mu.Unlock() + } + if readErr != nil { + if readErr != io.EOF { + err = readErr + } + break + } + } + + // After download, if no critical error, rename the temp file to its final destination. + if err == nil { + // Set modification time + mtime := zeroTime + lastModifiedHeader := response.Header.Get("Last-Modified") + if lastModified, lmErr := time.Parse(time.RFC1123, lastModifiedHeader); lmErr == nil { + mtime = lastModified + } + // Close file before Chtimes and Rename + mo.TempFile.Close() + + os.Chtimes(tempFileName, zeroTime, mtime) + + // Rename the file + if renameErr := os.Rename(tempFileName, key); renameErr != nil { + slog.With("error", renameErr, "from", tempFileName, "to", key).Warn("failed to rename temp file") + err = renameErr + os.Remove(tempFileName) // Attempt to clean up if rename fails + } else { + downloadSucceeded = true + } + } else { + logger := slog.With("upstreamIdx", selectedIdx) + logger.Error("something happened during download. 
will not cache this response.", "error", err) + } + }(memoryObject) + + return fileWrittenCh +} + +func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, firstChunk []byte, resultErr error) { returnLock := &sync.Mutex{} upstreams := len(server.Upstreams) cancelFuncs := make([]func(), upstreams) @@ -458,6 +536,8 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( resultIdx = -1 + var resultFirstChunk []byte + defer close(updateCh) defer close(notModifiedCh) defer func() { @@ -472,7 +552,8 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( groups := make(map[int][]int) for upstreamIdx, upstream := range server.Upstreams { if _, matched, err := upstream.GetPath(r.URL.Path); err != nil { - return -1, nil, nil, err + resultErr = err + return } else if !matched { continue } @@ -480,7 +561,8 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( priority := 0 for _, priorityGroup := range upstream.PriorityGroups { if matched, err := regexp.MatchString(priorityGroup.Match, r.URL.Path); err != nil { - return -1, nil, nil, err + resultErr = err + return } else if matched { priority = priorityGroup.Priority break @@ -516,7 +598,7 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( go func() { defer wg.Done() - response, ch, err := server.tryUpstream(ctx, idx, priority, r, lastModified) + response, chunk, err := server.tryUpstream(ctx, idx, priority, r, lastModified) if err == context.Canceled { // others returned logger.Debug("context canceled") return @@ -536,13 +618,16 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( } locked := returnLock.TryLock() if !locked { + if response != nil { + response.Body.Close() + } return } defer returnLock.Unlock() if response.StatusCode == http.StatusNotModified { notModifiedOnce.Do(func() { - resultResponse, 
resultCh, resultErr = response, ch, err + resultResponse, resultErr = response, err notModifiedCh <- idx }) logger.Debug("voted not modified") @@ -550,7 +635,7 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( } updateOnce.Do(func() { - resultResponse, resultCh, resultErr = response, ch, err + resultResponse, resultFirstChunk, resultErr = response, chunk, err updateCh <- idx for cancelIdx, cancel := range cancelFuncs { @@ -570,6 +655,7 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( select { case idx := <-updateCh: resultIdx = idx + firstChunk = resultFirstChunk logger.With("upstreamIdx", resultIdx).Debug("upstream selected") return default: @@ -584,10 +670,10 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) ( } } - return -1, nil, nil, nil + return } -func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) { +func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, firstChunk []byte, err error) { upstream := server.Upstreams[upstreamIdx] newpath, matched, err := upstream.GetPath(r.URL.Path) @@ -633,9 +719,9 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int if err != nil { return nil, nil, err } - streaming := false + shouldCloseBody := true defer func() { - if !streaming && response != nil { + if shouldCloseBody && response != nil { response.Body.Close() } }() @@ -683,10 +769,6 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int } } - var currentOffset int64 - - ch := make(chan Chunk, 1024) - buffer := make([]byte, server.Misc.FirstChunkBytes) n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) @@ -694,30 +776,12 @@ func (server *Server) tryUpstream(ctx context.Context, 
upstreamIdx, priority int if n == 0 { return response, nil, err } - } - ch <- Chunk{buffer: buffer[:n]} - streaming = true - go func() { - defer close(ch) - defer response.Body.Close() - - for { - buffer := make([]byte, server.Misc.ChunkBytes) - n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) - if n > 0 { - ch <- Chunk{buffer: buffer[:n]} - currentOffset += int64(n) - } - if response.ContentLength > 0 && currentOffset == response.ContentLength && err == io.EOF || err == io.ErrUnexpectedEOF { - return - } - if err != nil { - ch <- Chunk{error: err} - return - } + if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { + err = nil } - }() + } - return response, ch, nil + shouldCloseBody = false + return response, buffer[:n], err } diff --git a/server_test.go b/server_test.go index 713d6f9..5bb14d3 100644 --- a/server_test.go +++ b/server_test.go @@ -3,6 +3,7 @@ package cacheproxy import ( "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" @@ -47,45 +48,28 @@ func TestTryUpstream(t *testing.T) { // Call the function under test req, _ := http.NewRequest("GET", "/testfile", nil) - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) // Assertions - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("tryUpstream() unexpected error: %v", err) } if resp == nil { t.Fatal("tryUpstream() response is nil") } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode) } - if chunks == nil { - t.Fatal("tryUpstream() chunks channel is nil") + if firstChunk == nil { + t.Fatal("tryUpstream() firstChunk is nil") } var receivedBody bytes.Buffer - var finalErr error - LOOP: - for { - select { - case chunk, ok := <-chunks: - if !ok { - break LOOP - } - if chunk.error != nil { - finalErr = chunk.error - break LOOP - } - 
receivedBody.Write(chunk.buffer) - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for chunk") - } - } - - if finalErr != nil && finalErr != io.EOF { - t.Errorf("Expected final error to be io.EOF, got %v", finalErr) - } + receivedBody.Write(firstChunk) + rest, _ := io.ReadAll(resp.Body) + receivedBody.Write(rest) if receivedBody.String() != expectedBody { t.Errorf("Expected body '%s', got '%s'", expectedBody, receivedBody.String()) @@ -119,7 +103,7 @@ func TestTryUpstream(t *testing.T) { // Call the function under test req, _ := http.NewRequest("GET", "/testfile", nil) lastModified := time.Now().UTC() - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, lastModified) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, lastModified) // Assertions if err != nil { @@ -131,8 +115,8 @@ func TestTryUpstream(t *testing.T) { if resp.StatusCode != http.StatusNotModified { t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode) } - if chunks != nil { - t.Error("Expected chunks channel to be nil for 304 response") + if firstChunk != nil { + t.Error("Expected firstChunk to be nil for 304 response") } }) @@ -158,7 +142,7 @@ func TestTryUpstream(t *testing.T) { // Call the function under test req, _ := http.NewRequest("GET", "/testfile", nil) - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) // Assertions if err != nil { @@ -166,8 +150,8 @@ func TestTryUpstream(t *testing.T) { } // For an explicit failure (such as 404), we expect the function to return nil, nil, nil // because the response checker accepts only 200 by default - if resp != nil || chunks != nil { - t.Errorf("Expected response and chunks to be nil, got resp=%v, chunks=%v", resp, chunks) + if resp != nil || firstChunk != nil { + t.Errorf("Expected response and firstChunk to be nil, got resp=%v, firstChunk=%v", resp, firstChunk) } }) @@ -243,17 +227,15 @@ func TestTryUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", 
"/testfile", nil) - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("Expected no error, got %v", err) } - if resp == nil || chunks == nil { - t.Fatal("Expected successful response and chunks, but got nil") - } - // drain the channel - for range chunks { + if resp == nil || firstChunk == nil { + t.Fatal("Expected successful response and firstChunk, but got nil") } + resp.Body.Close() }) // Case B: the checker fails, so nil should be returned @@ -281,13 +263,13 @@ func TestTryUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) if err != nil { t.Fatalf("Expected no error, got %v", err) } - if resp != nil || chunks != nil { - t.Fatal("Expected response and chunks to be nil due to checker failure") + if resp != nil || firstChunk != nil { + t.Fatal("Expected response and firstChunk to be nil due to checker failure") } }) }) @@ -326,20 +308,18 @@ func TestTryUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("Expected no error, got %v", err) } - if resp == nil || chunks == nil { - t.Fatal("Expected successful response and chunks, but got nil") + if resp == nil || firstChunk == nil { + t.Fatal("Expected successful response and firstChunk, but got nil") } if resp.StatusCode != http.StatusOK { t.Errorf("Expected status OK, got %d", resp.StatusCode) } - // drain the channel - for range
chunks { - } + resp.Body.Close() }) // Case B: redirect is disallowed @@ -355,7 +335,7 @@ func TestTryUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) + resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{}) // When the redirect rules do not match, the custom CheckRedirect function returns http.ErrUseLastResponse. // This makes http.Client return the 302 response directly instead of following the redirect. @@ -364,8 +344,8 @@ func TestTryUpstream(t *testing.T) { if err != nil { t.Fatalf("Expected no error for a checker failure, got %v", err) } - if resp != nil || chunks != nil { - t.Fatalf("Expected response and chunks to be nil for a disallowed redirect, got resp=%v, chunks=%v", resp, chunks) + if resp != nil || firstChunk != nil { + t.Fatalf("Expected response and firstChunk to be nil for a disallowed redirect, got resp=%v, firstChunk=%v", resp, firstChunk) } }) }) @@ -413,38 +393,27 @@ func TestFastestUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{}) + upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{}) - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("fastestUpstream() unexpected error: %v", err) } if resp == nil { t.Fatal("fastestUpstream() response is nil") } + defer resp.Body.Close() if upstreamIndex != 1 { t.Errorf("Expected fastest upstream index to be 1 (fastServer), got %d", upstreamIndex) } - if chunks == nil { - t.Fatal("fastestUpstream() chunks channel is nil") + if firstChunk == nil { + t.Fatal("fastestUpstream() firstChunk is nil") } var receivedBody bytes.Buffer - LOOP: - for { - select { - case chunk, ok := <-chunks: - if !ok { - break LOOP - } - if chunk.error != nil && chunk.error != io.EOF { - t.Fatalf("unexpected error from chunk channel: %v", chunk.error) - } - receivedBody.Write(chunk.buffer) - case <-time.After(1 * time.Second): 
- t.Fatal("timeout waiting for chunk") - } - } + receivedBody.Write(firstChunk) + rest, _ := io.ReadAll(resp.Body) + receivedBody.Write(rest) if receivedBody.String() != "fast" { t.Errorf("Expected body from fast server, got '%s'", receivedBody.String()) @@ -490,29 +459,27 @@ func TestFastestUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{}) + upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{}) - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("fastestUpstream() unexpected error: %v", err) } if resp == nil { t.Fatal("fastestUpstream() response is nil") } + defer resp.Body.Close() if upstreamIndex != 1 { t.Errorf("Expected fastest upstream index to be 1 (workingServer), got %d", upstreamIndex) } - if chunks == nil { - t.Fatal("fastestUpstream() chunks channel is nil") + if firstChunk == nil { + t.Fatal("fastestUpstream() firstChunk is nil") } var receivedBody bytes.Buffer - for chunk := range chunks { - if chunk.error != nil && chunk.error != io.EOF { - t.Fatalf("unexpected error from chunk channel: %v", chunk.error) - } - receivedBody.Write(chunk.buffer) - } + receivedBody.Write(firstChunk) + rest, _ := io.ReadAll(resp.Body) + receivedBody.Write(rest) if receivedBody.String() != "working" { t.Errorf("Expected body from working server, got '%s'", receivedBody.String()) @@ -546,7 +513,7 @@ func TestFastestUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{}) + upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{}) if err != nil { t.Fatalf("fastestUpstream() unexpected error: %v", err) @@ -554,8 +521,8 @@ func TestFastestUpstream(t *testing.T) { if upstreamIndex != -1 { t.Errorf("Expected upstream index to be -1, got %d", upstreamIndex) } - if resp != nil || chunks 
!= nil { - t.Errorf("Expected response and chunks to be nil, got resp=%v, chunks=%v", resp, chunks) + if resp != nil || firstChunk != nil { + t.Errorf("Expected response and firstChunk to be nil, got resp=%v, firstChunk=%v", resp, firstChunk) } }) @@ -585,7 +552,7 @@ func TestFastestUpstream(t *testing.T) { req, _ := http.NewRequest("GET", "/testfile", nil) lastModified := time.Now().UTC() - upstreamIndex, resp, chunks, err := server.fastestUpstream(req, lastModified) + upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, lastModified) if err != nil { t.Fatalf("fastestUpstream() unexpected error: %v", err) @@ -599,8 +566,8 @@ func TestFastestUpstream(t *testing.T) { if resp.StatusCode != http.StatusNotModified { t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode) } - if chunks != nil { - t.Error("Expected chunks channel to be nil for 304 response") + if firstChunk != nil { + t.Error("Expected firstChunk to be nil for 304 response") } }) @@ -645,29 +612,27 @@ func TestFastestUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{}) + upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{}) - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("fastestUpstream() unexpected error: %v", err) } if resp == nil { t.Fatal("fastestUpstream() response is nil") } + defer resp.Body.Close() if upstreamIndex != 1 { t.Errorf("Expected upstream index to be 1 (slow-high), got %d", upstreamIndex) } - if chunks == nil { - t.Fatal("fastestUpstream() chunks channel is nil") + if firstChunk == nil { + t.Fatal("fastestUpstream() firstChunk is nil") } var receivedBody bytes.Buffer - for chunk := range chunks { - if chunk.error != nil && chunk.error != io.EOF { - t.Fatalf("unexpected error from chunk channel: %v", chunk.error) - } - receivedBody.Write(chunk.buffer) - } + 
receivedBody.Write(firstChunk) + rest, _ := io.ReadAll(resp.Body) + receivedBody.Write(rest) if receivedBody.String() != "slow-high" { t.Errorf("Expected body from slow-high server, got '%s'", receivedBody.String()) @@ -713,29 +678,27 @@ func TestFastestUpstream(t *testing.T) { }) req, _ := http.NewRequest("GET", "/testfile", nil) - upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{}) + upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{}) - if err != nil { + if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { t.Fatalf("fastestUpstream() unexpected error: %v", err) } if resp == nil { t.Fatal("fastestUpstream() response is nil") } + defer resp.Body.Close() if upstreamIndex != 0 { t.Errorf("Expected upstream index to be 0 (working-low), got %d", upstreamIndex) } - if chunks == nil { - t.Fatal("fastestUpstream() chunks channel is nil") + if firstChunk == nil { + t.Fatal("fastestUpstream() firstChunk is nil") } var receivedBody bytes.Buffer - for chunk := range chunks { - if chunk.error != nil && chunk.error != io.EOF { - t.Fatalf("unexpected error from chunk channel: %v", chunk.error) - } - receivedBody.Write(chunk.buffer) - } + receivedBody.Write(firstChunk) + rest, _ := io.ReadAll(resp.Body) + receivedBody.Write(rest) if receivedBody.String() != "working-low" { t.Errorf("Expected body from working-low server, got '%s'", receivedBody.String()) @@ -1288,9 +1251,465 @@ func TestIntegration_HandleRequestWithCache(t *testing.T) { t.Errorf("Expected error message about crossing boundary, got: %s", rr.Body.String()) } }) + + // 4.8. 
Race condition on range requests: + // Purpose: specifically test the race between producer and consumer when a range request hits a cold cache. + // When the upstream responds very quickly, the producer goroutine may close and rename the temporary file + // before the consumer goroutine has had a chance to duplicate (dup) the file handle, + // leaving the consumer with an invalid file handle. + // Flow: + // - Simulate an upstream that responds extremely fast. + // - Issue range requests concurrently, many times. + // - Verify that the consumer still returns the correct partial content even in this high-speed scenario. + t.Run("RaceConditionOnRangeRequest", func(t *testing.T) { + for i := 0; i < 20; i++ { // run many times to raise the chance of hitting the race + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Parallel() // run in parallel to increase scheduling pressure + + fullBody := "race condition test body content" + mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Length", strconv.Itoa(len(fullBody))) + w.WriteHeader(http.StatusOK) + w.Write([]byte(fullBody)) + })) + defer mockUpstream.Close() + + tmpDir := t.TempDir() + server := NewServer(Config{ + Upstreams: []Upstream{ + {Server: mockUpstream.URL}, + }, + Storage: Storage{ + Local: &LocalStorage{ + Path: tmpDir, + }, + }, + Misc: MiscConfig{ + FirstChunkBytes: 1024, // ensure the body is read in a single pass + ChunkBytes: 1024, + }, + }) + + req := httptest.NewRequest("GET", "/race.txt", nil) + req.Header.Set("Range", "bytes=5-15") + rr := httptest.NewRecorder() + server.HandleRequestWithCache(rr, req) + + // In a faulty implementation, this fails because the file handle is closed too early + if rr.Code != http.StatusPartialContent { + t.Fatalf("Expected status code %d, got %d", http.StatusPartialContent, rr.Code) + } + + expectedPartialBody := fullBody[5:16] + if rr.Body.String() != expectedPartialBody { + t.Fatalf("Expected partial body '%s', got '%s'", expectedPartialBody, rr.Body.String()) + } + }) + } + }) } func isTimeoutError(err error) bool { e, ok := err.(interface{ Timeout() bool }) return ok && e.Timeout() } + +// TestServeRangedRequest contains unit tests for the serveRangedRequest function. +// These tests verify the consumer logic in isolation by simulating an in-progress download stream (StreamObject) +// to cover a variety of client request scenarios. +func TestServeRangedRequest(t *testing.T) { + // 5.1. 
Requested range is already fully available: + // Purpose: verify that when the requested range is already entirely in the temporary file before the download finishes, + // the function responds immediately with the correct data, without waiting for the whole file to download. + t.Run("RequestRangeFullyAvailable", func(t *testing.T) { + // 1. Setup + server := NewServer(Config{}) + mu := &sync.Mutex{} + cond := sync.NewCond(mu) + fileSize := 1000 // total file size is 1000 + + // Create and populate the temporary file + tmpFile, err := os.CreateTemp(t.TempDir(), "test-") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + initialData := make([]byte, 200) // the file already holds 200 bytes + if _, err := tmpFile.Write(initialData); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + + streamObj := &StreamObject{ + Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}}, + TempFile: tmpFile, + Offset: int64(len(initialData)), + Done: false, // download still in progress + Error: nil, + mu: mu, + cond: cond, + } + + // Put the simulated stream object into the server's map + server.lu.Lock() + server.o["/testfile"] = streamObj + server.lu.Unlock() + + // 2. Execute + req := httptest.NewRequest("GET", "/testfile", nil) + req.Header.Set("Range", "bytes=50-150") + rr := httptest.NewRecorder() + + server.serveRangedRequest(rr, req, "/testfile", time.Time{}) + + // 3. Assertions + if rr.Code != http.StatusPartialContent { + t.Errorf("Expected status %d, got %d", http.StatusPartialContent, rr.Code) + } + + expectedBodyLen := 101 // 150 - 50 + 1 + if rr.Body.Len() != expectedBodyLen { + t.Errorf("Expected body length %d, got %d", expectedBodyLen, rr.Body.Len()) + } + + expectedContentRange := "bytes 50-150/1000" + if rr.Header().Get("Content-Range") != expectedContentRange { + t.Errorf("Expected Content-Range '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range")) + } + }) + + // 5.2. Requested range must wait for later data: + // Purpose: verify that when the requested data is only partly available, the function first sends the available part, + // then waits for the producer (the download goroutine) to supply more data, and finally completes the response. + t.Run("RequestWaitsForData", func(t *testing.T) { + // 1. 
Setup + server := NewServer(Config{}) + mu := &sync.Mutex{} + cond := sync.NewCond(mu) + fileSize := 200 // total file size is 200 + + tmpFile, err := os.CreateTemp(t.TempDir(), "test-") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + // Only 100 bytes initially + initialData := make([]byte, 100) + for i := 0; i < 100; i++ { + initialData[i] = byte(i) + } + if _, err := tmpFile.Write(initialData); err != nil { + t.Fatalf("Failed to write initial data: %v", err) + } + + streamObj := &StreamObject{ + Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}}, + TempFile: tmpFile, + Offset: int64(len(initialData)), + Done: false, + Error: nil, + mu: mu, + cond: cond, + } + + server.lu.Lock() + server.o["/testfile"] = streamObj + server.lu.Unlock() + + // 2. Execute + req := httptest.NewRequest("GET", "/testfile", nil) + // Request 150 bytes, but only 100 bytes are available so far + req.Header.Set("Range", "bytes=0-149") + rr := httptest.NewRecorder() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + server.serveRangedRequest(rr, req, "/testfile", time.Time{}) + }() + + // Simulate the producer writing more data after a short delay + go func() { + time.Sleep(50 * time.Millisecond) // give the consumer time to start waiting + streamObj.mu.Lock() + defer streamObj.mu.Unlock() + + // Write the remaining 100 bytes + remainingData := make([]byte, 100) + for i := 0; i < 100; i++ { + remainingData[i] = byte(100 + i) + } + n, err := streamObj.TempFile.Write(remainingData) + if err != nil { + t.Errorf("Producer failed to write: %v", err) + return + } + streamObj.Offset += int64(n) + streamObj.Done = true // download complete + streamObj.cond.Broadcast() // wake the consumer + }() + + // Wait for serveRangedRequest to finish + wg.Wait() + + // 3. 
Assertions + if rr.Code != http.StatusPartialContent { + t.Errorf("Expected status %d, got %d", http.StatusPartialContent, rr.Code) + } + + expectedBodyLen := 150 + if rr.Body.Len() != expectedBodyLen { + t.Errorf("Expected body length %d, got %d", expectedBodyLen, rr.Body.Len()) + } + + // Verify the content is correct + body := rr.Body.Bytes() + for i := 0; i < 150; i++ { + if body[i] != byte(i) { + t.Fatalf("Body content mismatch at index %d: expected %d, got %d", i, i, body[i]) + } + } + + expectedContentRange := "bytes 0-149/200" + if rr.Header().Get("Content-Range") != expectedContentRange { + t.Errorf("Expected Content-Range '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range")) + } + }) + + // 5.3. Multiple concurrent clients: + // Purpose: verify that multiple clients can read from the same stream simultaneously, even when they request different or overlapping ranges. + t.Run("MultipleConcurrentClients", func(t *testing.T) { + // 1. Setup + server := NewServer(Config{}) + mu := &sync.Mutex{} + cond := sync.NewCond(mu) + fileSize := 300 + + tmpFile, err := os.CreateTemp(t.TempDir(), "test-") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + // Initial data + initialData := make([]byte, 100) + if _, err := tmpFile.Write(initialData); err != nil { + t.Fatalf("Failed to write initial data: %v", err) + } + + streamObj := &StreamObject{ + Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}}, + TempFile: tmpFile, + Offset: int64(len(initialData)), + Done: false, + Error: nil, + mu: mu, + cond: cond, + } + + server.lu.Lock() + server.o["/testfile"] = streamObj + server.lu.Unlock() + + // Simulate the producer writing data continuously + go func() { + // Write the second part + time.Sleep(50 * time.Millisecond) + streamObj.mu.Lock() + n, _ := streamObj.TempFile.Write(make([]byte, 100)) + streamObj.Offset += int64(n) + streamObj.cond.Broadcast() + streamObj.mu.Unlock() + + // Write the last part and finish + time.Sleep(50 * time.Millisecond) + streamObj.mu.Lock() + n, _ = streamObj.TempFile.Write(make([]byte, 100)) + streamObj.Offset += int64(n) + streamObj.Done = 
true + streamObj.cond.Broadcast() + streamObj.mu.Unlock() + }() + + var wg sync.WaitGroup + wg.Add(2) + + // Client 1: requests [50, 150] + rr1 := httptest.NewRecorder() + go func() { + defer wg.Done() + req := httptest.NewRequest("GET", "/testfile", nil) + req.Header.Set("Range", "bytes=50-150") + server.serveRangedRequest(rr1, req, "/testfile", time.Time{}) + }() + + // Client 2: requests [180, 280] + rr2 := httptest.NewRecorder() + go func() { + defer wg.Done() + req := httptest.NewRequest("GET", "/testfile", nil) + req.Header.Set("Range", "bytes=180-280") + server.serveRangedRequest(rr2, req, "/testfile", time.Time{}) + }() + + wg.Wait() + + // Assertions for client 1 + if rr1.Code != http.StatusPartialContent { + t.Errorf("Client 1: Expected status %d, got %d", http.StatusPartialContent, rr1.Code) + } + if rr1.Body.Len() != 101 { + t.Errorf("Client 1: Expected body length 101, got %d", rr1.Body.Len()) + } + if rr1.Header().Get("Content-Range") != "bytes 50-150/300" { + t.Errorf("Client 1: Incorrect Content-Range header: %s", rr1.Header().Get("Content-Range")) + } + + // Assertions for client 2 + if rr2.Code != http.StatusPartialContent { + t.Errorf("Client 2: Expected status %d, got %d", http.StatusPartialContent, rr2.Code) + } + if rr2.Body.Len() != 101 { + t.Errorf("Client 2: Expected body length 101, got %d", rr2.Body.Len()) + } + if rr2.Header().Get("Content-Range") != "bytes 180-280/300" { + t.Errorf("Client 2: Incorrect Content-Range header: %s", rr2.Header().Get("Content-Range")) + } + }) + + // 5.4. Error during download: + // Purpose: verify that when the producer (the download goroutine) hits an error, all waiting consumers are notified and interrupted correctly, + // and each returns an error to its client. + t.Run("DownloadFailsWhileWaiting", func(t *testing.T) { + // 1. 
Setup + server := NewServer(Config{}) + mu := &sync.Mutex{} + cond := sync.NewCond(mu) + fileSize := 200 + downloadError := errors.New("network connection lost") + + tmpFile, err := os.CreateTemp(t.TempDir(), "test-") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + // 50 bytes initially + if _, err := tmpFile.Write(make([]byte, 50)); err != nil { + t.Fatalf("Failed to write initial data: %v", err) + } + + streamObj := &StreamObject{ + Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}}, + TempFile: tmpFile, + Offset: 50, + Done: false, + Error: nil, + mu: mu, + cond: cond, + } + + server.lu.Lock() + server.o["/testfile"] = streamObj + server.lu.Unlock() + + // Simulate the producer hitting an error after a short delay + go func() { + time.Sleep(50 * time.Millisecond) + streamObj.mu.Lock() + defer streamObj.mu.Unlock() + streamObj.Error = downloadError + streamObj.Done = true // mark as done (even though it failed) + streamObj.cond.Broadcast() // wake all waiting consumers + }() + + var wg sync.WaitGroup + wg.Add(2) + + // Client 1: requests [100, 150] and must wait + rr1 := httptest.NewRecorder() + go func() { + defer wg.Done() + req := httptest.NewRequest("GET", "/testfile", nil) + req.Header.Set("Range", "bytes=100-150") + server.serveRangedRequest(rr1, req, "/testfile", time.Time{}) + }() + + // Client 2: also requests [100, 150] and waits as well + rr2 := httptest.NewRecorder() + go func() { + defer wg.Done() + req := httptest.NewRequest("GET", "/testfile", nil) + req.Header.Set("Range", "bytes=100-150") + server.serveRangedRequest(rr2, req, "/testfile", time.Time{}) + }() + + wg.Wait() + + // Assert that both clients received an internal server error + if rr1.Code != http.StatusInternalServerError { + t.Errorf("Client 1: Expected status %d, got %d", http.StatusInternalServerError, rr1.Code) + } + if !strings.Contains(rr1.Body.String(), downloadError.Error()) { + t.Errorf("Client 1: Expected error message '%s', got '%s'", downloadError.Error(), rr1.Body.String()) + } + + if rr2.Code != http.StatusInternalServerError { + 
t.Errorf("Client 2: Expected status %d, got %d", http.StatusInternalServerError, rr2.Code) + } + if !strings.Contains(rr2.Body.String(), downloadError.Error()) { + t.Errorf("Client 2: Expected error message '%s', got '%s'", downloadError.Error(), rr2.Body.String()) + } + }) + + // 5.5. Upstream sends no Content-Length: + // Purpose: verify how the system handles an upstream response that has no Content-Length header. + // In the current implementation, a missing Content-Length makes the range impossible to compute, so an error should be returned. + t.Run("NoContentLength", func(t *testing.T) { + // 1. Setup + server := NewServer(Config{}) + mu := &sync.Mutex{} + cond := sync.NewCond(mu) + + tmpFile, err := os.CreateTemp(t.TempDir(), "test-") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + streamObj := &StreamObject{ + Headers: http.Header{}, // no Content-Length + TempFile: tmpFile, + Offset: 50, + Done: false, + Error: nil, + mu: mu, + cond: cond, + } + + server.lu.Lock() + server.o["/testfile"] = streamObj + server.lu.Unlock() + + // 2. Execute + req := httptest.NewRequest("GET", "/testfile", nil) + req.Header.Set("Range", "bytes=0-10") + rr := httptest.NewRecorder() + + // In this scenario serveRangedRequest should return an error immediately, so no goroutine is needed + server.serveRangedRequest(rr, req, "/testfile", time.Time{}) + + // 3. Assertions + // Without Content-Length, totalSize is 0, which leads to "Range not satisfiable" + if rr.Code != http.StatusRequestedRangeNotSatisfiable { + t.Errorf("Expected status %d, got %d", http.StatusRequestedRangeNotSatisfiable, rr.Code) + } + }) +}