package cacheproxy

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// TestTryUpstream contains unit tests for tryUpstream.
// tryUpstream talks to a single upstream server and is the building block of
// the download logic.
func TestTryUpstream(t *testing.T) {
	// 1.1. Normal fetch.
	// Goal: verify the function streams data from the upstream successfully.
	// Steps:
	//   - Start a mock server returning 200 OK with a fixed body.
	//   - Call tryUpstream against that server.
	//   - Check the status code, the chunks channel contents and the error.
	t.Run("NormalFetch", func(t *testing.T) {
		expectedBody := "hello world"
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(expectedBody))
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Misc: MiscConfig{
				FirstChunkBytes: 1024,
				ChunkBytes:      1024,
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

		if err != nil {
			t.Fatalf("tryUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("tryUpstream() response is nil")
		}
		if resp.StatusCode != http.StatusOK {
			t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
		}
		if chunks == nil {
			t.Fatal("tryUpstream() chunks channel is nil")
		}

		var receivedBody bytes.Buffer
		var finalErr error
	LOOP:
		for {
			select {
			case chunk, ok := <-chunks:
				if !ok {
					break LOOP
				}
				if chunk.error != nil {
					finalErr = chunk.error
					break LOOP
				}
				receivedBody.Write(chunk.buffer)
			case <-time.After(1 * time.Second):
				t.Fatal("timeout waiting for chunk")
			}
		}

		if finalErr != nil && !errors.Is(finalErr, io.EOF) {
			t.Errorf("Expected final error to be io.EOF, got %v", finalErr)
		}
		if receivedBody.String() != expectedBody {
			t.Errorf("Expected body '%s', got '%s'", expectedBody, receivedBody.String())
		}
	})

	// 1.2. Upstream returns 304 Not Modified.
	// Goal: verify a 304 from upstream is handled correctly.
	// Steps:
	//   - Mock an upstream that returns 304 when If-Modified-Since is present.
	//   - Call tryUpstream with a lastModified time.
	//   - Check the status code, the chunks channel and the error.
	t.Run("UpstreamReturns304", func(t *testing.T) {
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Header.Get("If-Modified-Since") != "" {
				w.WriteHeader(http.StatusNotModified)
				return
			}
			w.WriteHeader(http.StatusOK)
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		lastModified := time.Now().UTC()
		resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, lastModified)

		if err != nil {
			t.Fatalf("tryUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("tryUpstream() response is nil")
		}
		if resp.StatusCode != http.StatusNotModified {
			t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode)
		}
		if chunks != nil {
			t.Error("Expected chunks channel to be nil for 304 response")
		}
	})

	// 1.3. Upstream returns an error status (e.g. 404).
	// Goal: verify non-success upstream status codes are handled.
	// Steps:
	//   - Mock an upstream returning 404 Not Found.
	//   - Call tryUpstream.
	//   - Check that the returned values are all nil / indicate failure.
	t.Run("UpstreamReturnsErrorStatus", func(t *testing.T) {
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

		if err != nil {
			t.Fatalf("tryUpstream() unexpected error: %v", err)
		}
		// For an explicit failure such as 404 we expect (nil, nil, nil),
		// because the default response checker only accepts 200.
		if resp != nil || chunks != nil {
			t.Errorf("Expected response and chunks to be nil, got resp=%v, chunks=%v", resp, chunks)
		}
	})

	// 1.4. Request timeout.
	// Goal: verify context timeout/cancellation is honoured.
	// Steps:
	//   - Mock a slow upstream.
	//   - Call tryUpstream with a context carrying a short timeout.
	//   - Check the returned error is a timeout (e.g. context.DeadlineExceeded).
	t.Run("RequestTimeout", func(t *testing.T) {
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(200 * time.Millisecond) // simulate network latency
			w.WriteHeader(http.StatusOK)
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
		})

		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer cancel()

		req, _ := http.NewRequest("GET", "/testfile", nil)
		_, _, err := server.tryUpstream(ctx, 0, 0, req, time.Time{})

		if err == nil {
			t.Fatal("Expected an error, but got nil")
		}
		// errors.Is also matches a DeadlineExceeded wrapped by *url.Error or fmt.Errorf("%w").
		if !errors.Is(err, context.DeadlineExceeded) && !isTimeoutError(err) {
			t.Errorf("Expected error to be context.DeadlineExceeded or a timeout error, got %v", err)
		}
	})

	// 1.5. Upstream response checker.
	// Goal: verify Checker configuration (e.g. header checks) works.
	// Steps:
	//   - Case A: Checker requires a header, mock server sends it.
	//   - Case B: Checker requires a header, mock server does not send it.
	//   - Case A must succeed, Case B must fail.
	t.Run("ResponseChecker", func(t *testing.T) {
		// Case A: checker matches, should succeed.
		t.Run("CheckerSucceeds", func(t *testing.T) {
			mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/octet-stream")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("data"))
			}))
			defer mockUpstream.Close()

			matchPattern := "application/octet-stream"
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server: mockUpstream.URL,
						Checkers: []Checker{
							{
								Headers: []HeaderChecker{
									{Name: "Content-Type", Match: &matchPattern},
								},
							},
						},
					},
				},
				Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			if err != nil {
				t.Fatalf("Expected no error, got %v", err)
			}
			if resp == nil || chunks == nil {
				t.Fatal("Expected successful response and chunks, but got nil")
			}
			// Drain the channel so the producer goroutine can finish.
			for range chunks {
			}
		})

		// Case B: checker fails, should return nil.
		t.Run("CheckerFails", func(t *testing.T) {
			mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "text/html") // non-matching header
				w.WriteHeader(http.StatusOK)
			}))
			defer mockUpstream.Close()

			matchPattern := "application/octet-stream"
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server: mockUpstream.URL,
						Checkers: []Checker{
							{
								Headers: []HeaderChecker{
									{Name: "Content-Type", Match: &matchPattern},
								},
							},
						},
					},
				},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			if err != nil {
				t.Fatalf("Expected no error, got %v", err)
			}
			if resp != nil || chunks != nil {
				t.Fatal("Expected response and chunks to be nil due to checker failure")
			}
		})
	})

	// 1.6. Upstream redirect.
	// Goal: verify the AllowedRedirect rule is enforced.
	// Steps:
	//   - Case A: upstream A 302-redirects to B and AllowedRedirect matches B.
	//   - Case B: upstream A 302-redirects to B but AllowedRedirect does not match B.
	//   - Case A must fetch data from B, Case B must fail.
	t.Run("UpstreamRedirect", func(t *testing.T) {
		// Redirect target.
		targetUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("redirected"))
		}))
		defer targetUpstream.Close()

		// Initial upstream, which redirects to the target.
		initialUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Redirect(w, r, targetUpstream.URL, http.StatusFound)
		}))
		defer initialUpstream.Close()

		// Case A: redirect allowed.
		t.Run("RedirectAllowed", func(t *testing.T) {
			allowedPattern := targetUpstream.URL // allow an exact match
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server:          initialUpstream.URL,
						AllowedRedirect: &allowedPattern,
					},
				},
				Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			if err != nil {
				t.Fatalf("Expected no error, got %v", err)
			}
			if resp == nil || chunks == nil {
				t.Fatal("Expected successful response and chunks, but got nil")
			}
			if resp.StatusCode != http.StatusOK {
				t.Errorf("Expected status OK, got %d", resp.StatusCode)
			}
			// Drain the channel so the producer goroutine can finish.
			for range chunks {
			}
		})

		// Case B: redirect disallowed.
		t.Run("RedirectDisallowed", func(t *testing.T) {
			disallowedPattern := "http://disallowed.example.com"
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server:          initialUpstream.URL,
						AllowedRedirect: &disallowedPattern,
					},
				},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			// When the redirect rule does not match, the custom CheckRedirect
			// returns http.ErrUseLastResponse, so http.Client returns the 302
			// itself instead of following it. The status checker in
			// tryUpstream (default: accept only 200) then rejects the 302, so
			// tryUpstream returns (nil, nil, nil): this upstream attempt failed.
			if err != nil {
				t.Fatalf("Expected no error for a checker failure, got %v", err)
			}
			if resp != nil || chunks != nil {
				t.Fatalf("Expected response and chunks to be nil for a disallowed redirect, got resp=%v, chunks=%v", resp, chunks)
			}
		})
	})
}

// TestFastestUpstream contains unit tests for fastestUpstream, which races
// requests against several upstreams and is the core of the project.
func TestFastestUpstream(t *testing.T) {
	// 2.1. Race success.
	// Goal: the fastest upstream wins and the slower requests are cancelled.
	// Steps:
	//   - Start mock servers with different latencies.
	//   - Call fastestUpstream.
	//   - Check the returned index is the fastest server and that the other
	//     server's request context was cancelled.
	t.Run("RaceSuccess", func(t *testing.T) {
		slowCancelled := make(chan struct{}, 1)
		// Slow server: records whether its request was cancelled.
		slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			select {
			case <-time.After(200 * time.Millisecond):
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("slow"))
			case <-r.Context().Done():
				// Non-blocking signal: a second cancelled request must not
				// block the handler, or slowServer.Close() would hang.
				select {
				case slowCancelled <- struct{}{}:
				default:
				}
				return
			}
		}))
		defer slowServer.Close()

		// Fast server.
		fastServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(50 * time.Millisecond)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("fast"))
		}))
		defer fastServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: slowServer.URL}, // index 0
				{Server: fastServer.URL}, // index 1
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		if upstreamIndex != 1 {
			t.Errorf("Expected fastest upstream index to be 1 (fastServer), got %d", upstreamIndex)
		}
		if chunks == nil {
			t.Fatal("fastestUpstream() chunks channel is nil")
		}

		var receivedBody bytes.Buffer
	LOOP:
		for {
			select {
			case chunk, ok := <-chunks:
				if !ok {
					break LOOP
				}
				if chunk.error != nil && !errors.Is(chunk.error, io.EOF) {
					t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
				}
				receivedBody.Write(chunk.buffer)
			case <-time.After(1 * time.Second):
				t.Fatal("timeout waiting for chunk")
			}
		}

		if receivedBody.String() != "fast" {
			t.Errorf("Expected body from fast server, got '%s'", receivedBody.String())
		}

		// Verify the slow server's request was cancelled.
		select {
		case <-slowCancelled:
		case <-time.After(500 * time.Millisecond):
			t.Error("Timeout waiting for slow server cancellation signal")
		}
	})

	// 2.2. Partial upstream failure.
	// Goal: when some upstreams are unavailable, a working one is still chosen.
	// Steps:
	//   - Start several mock servers, some of which fail.
	//   - Call fastestUpstream.
	//   - Check the returned index is the only working server.
	t.Run("PartialUpstreamFailure", func(t *testing.T) {
		failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "internal server error", http.StatusInternalServerError)
		}))
		defer failingServer.Close()

		workingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("working"))
		}))
		defer workingServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: failingServer.URL}, // index 0
				{Server: workingServer.URL}, // index 1
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		if upstreamIndex != 1 {
			t.Errorf("Expected fastest upstream index to be 1 (workingServer), got %d", upstreamIndex)
		}
		if chunks == nil {
			t.Fatal("fastestUpstream() chunks channel is nil")
		}

		var receivedBody bytes.Buffer
		for chunk := range chunks {
			if chunk.error != nil && !errors.Is(chunk.error, io.EOF) {
				t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
			}
			receivedBody.Write(chunk.buffer)
		}

		if receivedBody.String() != "working" {
			t.Errorf("Expected body from working server, got '%s'", receivedBody.String())
		}
	})

	// 2.3. All upstreams fail.
	// Goal: the function fails gracefully when no upstream is usable.
	// Steps:
	//   - Start several mock servers that all return errors.
	//   - Call fastestUpstream.
	//   - Check the returned index is -1.
	t.Run("AllUpstreamsFail", func(t *testing.T) {
		// Failing server 1 (500).
		failingServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "internal server error", http.StatusInternalServerError)
		}))
		defer failingServer1.Close()

		// Failing server 2 (404).
		failingServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.NotFound(w, r)
		}))
		defer failingServer2.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: failingServer1.URL}, // index 0
				{Server: failingServer2.URL}, // index 1
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if upstreamIndex != -1 {
			t.Errorf("Expected upstream index to be -1, got %d", upstreamIndex)
		}
		if resp != nil || chunks != nil {
			t.Errorf("Expected response and chunks to be nil, got resp=%v, chunks=%v", resp, chunks)
		}
	})

	// 2.4. All upstreams return 304.
	// Goal: behaviour when every upstream reports the content is unmodified.
	// Steps:
	//   - Start several mock servers that all return 304.
	//   - Call fastestUpstream with lastModified.
	//   - Check the returned response status is 304.
	t.Run("AllUpstreamsReturn304", func(t *testing.T) {
		server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotModified)
		}))
		defer server1.Close()

		server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotModified)
		}))
		defer server2.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: server1.URL},
				{Server: server2.URL},
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		lastModified := time.Now().UTC()
		upstreamIndex, resp, chunks, err := server.fastestUpstream(req, lastModified)

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		if upstreamIndex == -1 {
			t.Error("Expected a valid upstream index, got -1")
		}
		if resp.StatusCode != http.StatusNotModified {
			t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode)
		}
		if chunks != nil {
			t.Error("Expected chunks channel to be nil for 304 response")
		}
	})

	// 2.5. Priority groups.
	// Goal: higher-priority upstreams are tried first.
	// Steps:
	//   - Configure a slow high-priority upstream and a fast low-priority one.
	//   - Call fastestUpstream.
	//   - Check the returned index is the high-priority server.
	t.Run("PriorityGroups", func(t *testing.T) {
		// Slow, high priority server.
		slowHighPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(100 * time.Millisecond)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("slow-high"))
		}))
		defer slowHighPriorityServer.Close()

		// Fast, low priority server.
		fastLowPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("fast-low"))
		}))
		defer fastLowPriorityServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{ // index 0
					Server: fastLowPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 1},
					},
				},
				{ // index 1
					Server: slowHighPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 10},
					},
				},
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		if upstreamIndex != 1 {
			t.Errorf("Expected upstream index to be 1 (slow-high), got %d", upstreamIndex)
		}
		if chunks == nil {
			t.Fatal("fastestUpstream() chunks channel is nil")
		}

		var receivedBody bytes.Buffer
		for chunk := range chunks {
			if chunk.error != nil && !errors.Is(chunk.error, io.EOF) {
				t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
			}
			receivedBody.Write(chunk.buffer)
		}

		if receivedBody.String() != "slow-high" {
			t.Errorf("Expected body from slow-high server, got '%s'", receivedBody.String())
		}
	})

	// 2.6. Priority group failover.
	// Goal: when every upstream in the highest group fails, the next group is used.
	// Steps:
	//   - Configure a failing high-priority upstream and a working low-priority one.
	//   - Call fastestUpstream.
	//   - Check the returned index is the low-priority server.
	t.Run("PriorityGroupFailover", func(t *testing.T) {
		// Failing, high priority server.
		failingHighPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
		}))
		defer failingHighPriorityServer.Close()

		// Working, low priority server.
		workingLowPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("working-low"))
		}))
		defer workingLowPriorityServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{ // index 0
					Server: workingLowPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 1},
					},
				},
				{ // index 1
					Server: failingHighPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 10},
					},
				},
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		if upstreamIndex != 0 {
			t.Errorf("Expected upstream index to be 0 (working-low), got %d", upstreamIndex)
		}
		if chunks == nil {
			t.Fatal("fastestUpstream() chunks channel is nil")
		}

		var receivedBody bytes.Buffer
		for chunk := range chunks {
			if chunk.error != nil && !errors.Is(chunk.error, io.EOF) {
				t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
			}
			receivedBody.Write(chunk.buffer)
		}

		if receivedBody.String() != "working-low" {
			t.Errorf("Expected body from working-low server, got '%s'", receivedBody.String())
		}
	})
}

// TestCheckLocal contains unit tests for checkLocal, which inspects the state
// of the local cache.
func TestCheckLocal(t *testing.T) {
	// 3.1. Cache entry does not exist.
	// Goal: a missing file yields localNotExists.
	t.Run("CacheNotExists", func(t *testing.T) {
		server := NewServer(Config{})
		tmpDir := t.TempDir()
		nonExistentFile := filepath.Join(tmpDir, "non-existent-file")

		status, _, err := server.checkLocal(nil, nil, nonExistentFile)
		if err != nil {
			t.Fatalf("checkLocal() unexpected error: %v", err)
		}
		if status != localNotExists {
			t.Errorf("Expected status to be localNotExists, got %v", status)
		}
	})

	// 3.2. Cache entry exists and is fresh.
	// Goal: a fresh file with a long RefreshAfter yields localExists.
	t.Run("CacheExistsAndValid", func(t *testing.T) {
		server := NewServer(Config{
			Cache: Cache{
				RefreshAfter: time.Hour,
			},
		})
		tmpDir := t.TempDir()
		validFile := filepath.Join(tmpDir, "valid-file")
		if err := os.WriteFile(validFile, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}

		status, _, err := server.checkLocal(nil, nil, validFile)
		if err != nil {
			t.Fatalf("checkLocal() unexpected error: %v", err)
		}
		if status != localExists {
			t.Errorf("Expected status to be localExists, got %v", status)
		}
	})

	// 3.3. Cache entry exists but is stale.
	// Goal: a stale file yields localExistsButNeedHead ("check for updates").
	t.Run("CacheExistsButExpired", func(t *testing.T) {
		server := NewServer(Config{
			Cache: Cache{
				RefreshAfter: time.Nanosecond,
			},
		})
		tmpDir := t.TempDir()
		expiredFile := filepath.Join(tmpDir, "expired-file")
		if err := os.WriteFile(expiredFile, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}
		// Backdate the modification time explicitly instead of sleeping a few
		// nanoseconds, which is below most filesystems' timestamp granularity.
		old := time.Now().Add(-time.Hour)
		if err := os.Chtimes(expiredFile, old, old); err != nil {
			t.Fatalf("Failed to backdate test file: %v", err)
		}

		status, _, err := server.checkLocal(nil, nil, expiredFile)
		if err != nil {
			t.Fatalf("checkLocal() unexpected error: %v", err)
		}
		if status != localExistsButNeedHead {
			t.Errorf("Expected status to be localExistsButNeedHead, got %v", status)
		}
	})

	// 3.4. Cache policies "always" and "never".
	// Goal: refresh "always" / "never" policies behave as configured.
	//   - "always": a matching always policy yields localExistsButNeedHead.
	//   - "never": a matching never policy yields localExists.
	t.Run("CachePolicyAlwaysAndNever", func(t *testing.T) {
		tmpDir := t.TempDir()
		testFile := filepath.Join(tmpDir, "test-policy-file.txt")
		if err := os.WriteFile(testFile, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}

		t.Run("Always", func(t *testing.T) {
			server := NewServer(Config{
				Cache: Cache{
					Policies: []CachePolicyOnPath{
						{Match: `\.txt$`, RefreshAfter: "always"},
					},
				},
			})
			status, _, err := server.checkLocal(nil, nil, testFile)
			if err != nil {
				t.Fatalf("checkLocal() unexpected error: %v", err)
			}
			if status != localExistsButNeedHead {
				t.Errorf("Expected status to be localExistsButNeedHead for 'always' policy, got %v", status)
			}
		})

		t.Run("Never", func(t *testing.T) {
			server := NewServer(Config{
				Cache: Cache{
					RefreshAfter: time.Nanosecond, // global policy would treat the file as stale
					Policies: []CachePolicyOnPath{
						{Match: `\.txt$`, RefreshAfter: "never"},
					},
				},
			})
			// Make the file unambiguously old so only the "never" rule can
			// keep it fresh.
			old := time.Now().Add(-time.Hour)
			if err := os.Chtimes(testFile, old, old); err != nil {
				t.Fatalf("Failed to backdate test file: %v", err)
			}
			status, _, err := server.checkLocal(nil, nil, testFile)
			if err != nil {
				t.Fatalf("checkLocal() unexpected error: %v", err)
			}
			if status != localExists {
				t.Errorf("Expected status to be localExists for 'never' policy, got %v", status)
			}
		})
	})

	// 3.5. File permission errors.
	// Goal: when a file cannot be accessed due to permissions,
	// HandleRequestWithCache responds with 403 Forbidden.
	// Steps:
	//   - Create a directory and remove its execute bit so files inside
	//     cannot be stat'ed.
	//   - Request such a file through HandleRequestWithCache.
	//   - Check the HTTP status is 403 Forbidden.
	// Note: this is effectively an integration test verifying that
	// checkLocal's permission error is propagated and handled.
	t.Run("FilePermissionError", func(t *testing.T) {
		// POSIX permission bits cannot revoke directory traversal on Windows.
		if runtime.GOOS == "windows" {
			t.Skip("Skipping test: directory permission bits are not enforced on Windows.")
		}
		if os.Getuid() == 0 {
			t.Skip("Skipping test: running as root, cannot test permission errors effectively.")
		}

		tmpDir := t.TempDir()
		permDir := filepath.Join(tmpDir, "perm_dir")

		// 1. Create the directory with full owner permissions.
		if err := os.Mkdir(permDir, 0700); err != nil {
			t.Fatalf("Failed to create test directory: %v", err)
		}
		// Restore permissions so the TempDir cleanup (registered earlier,
		// hence run later) can remove the tree.
		t.Cleanup(func() {
			if err := os.Chmod(permDir, 0755); err != nil {
				t.Errorf("Failed to restore perm_dir permissions: %v", err)
			}
		})

		// 2. Create a file inside it.
		unreadableFilePath := filepath.Join(permDir, "file")
		if err := os.WriteFile(unreadableFilePath, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}

		// 3. Revoke execute permission on the directory so the file inside
		//    cannot be stat'ed.
		if err := os.Chmod(permDir, 0600); err != nil {
			t.Fatalf("Failed to chmod perm_dir: %v", err)
		}

		server := NewServer(Config{
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
		})

		req := httptest.NewRequest("GET", "/perm_dir/file", nil)
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		if rr.Code != http.StatusForbidden {
			t.Errorf("Expected status code %d, got %d", http.StatusForbidden, rr.Code)
		}
	})
}

// TestIntegration_HandleRequestWithCache contains integration tests that
// exercise the core request flow end to end.
func TestIntegration_HandleRequestWithCache(t *testing.T) {
	// 4.1. First request (cold cache).
	// Goal: the first request downloads from upstream, answers the client and
	// writes the cache.
	t.Run("FirstRequestColdCache", func(t *testing.T) {
		expectedBody := "from upstream"
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Length", "13")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(expectedBody))
		}))
		defer mockUpstream.Close()

		tmpDir := t.TempDir()
		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
			Misc: MiscConfig{
				FirstChunkBytes: 1024,
				ChunkBytes:      1024,
			},
		})

		req := httptest.NewRequest("GET", "/test.txt", nil)
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		if rr.Code != http.StatusOK {
			t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
		}
		if rr.Body.String() != expectedBody {
			t.Errorf("Expected body '%s', got '%s'", expectedBody, rr.Body.String())
		}

		// The cache write may finish in the background; poll instead of
		// relying on a fixed sleep.
		cachedFilePath := filepath.Join(tmpDir, "test.txt")
		cachedData, err := waitForCachedContent(cachedFilePath, expectedBody, 2*time.Second)
		if err != nil {
			t.Fatalf("Failed to read cached file: %v", err)
		}
		if cachedData != expectedBody {
			t.Errorf("Expected cached file content '%s', got '%s'", expectedBody, cachedData)
		}
	})

	// 4.2. Cache hit.
	// Goal: a request for a cached file is served locally without contacting
	// the upstream.
	t.Run("CacheHit", func(t *testing.T) {
		cachedBody := "from cache"
		// Upstream that flags an error if it is ever called.
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			t.Error("Upstream should not be called on cache hit")
			w.WriteHeader(http.StatusInternalServerError)
		}))
		defer mockUpstream.Close()

		tmpDir := t.TempDir()
		// Pre-populate the cache.
		cachedFilePath := filepath.Join(tmpDir, "cached.txt")
		if err := os.WriteFile(cachedFilePath, []byte(cachedBody), 0644); err != nil {
			t.Fatalf("Failed to write pre-populated cache file: %v", err)
		}

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
			Cache: Cache{
				RefreshAfter: time.Hour, // ensure the cache entry is fresh
			},
		})

		req := httptest.NewRequest("GET", "/cached.txt", nil)
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		if rr.Code != http.StatusOK {
			t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
		}
		if rr.Body.String() != cachedBody {
			t.Errorf("Expected body '%s', got '%s'", cachedBody, rr.Body.String())
		}
	})

	// 4.3. Concurrent requests for the same file.
	// Goal: concurrent requests for one file trigger a single upstream download.
	t.Run("ConcurrentRequestsSameFile", func(t *testing.T) {
		expectedBody := "concurrent download"
		var upstreamRequestCount int32

		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			atomic.AddInt32(&upstreamRequestCount, 1)
			// Simulate a slow download.
			time.Sleep(200 * time.Millisecond)
			w.Header().Set("Content-Length", "19")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(expectedBody))
		}))
		defer mockUpstream.Close()

		tmpDir := t.TempDir()
		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
			Misc: MiscConfig{
				FirstChunkBytes: 1024,
				ChunkBytes:      1024,
			},
		})

		var wg sync.WaitGroup
		concurrentRequests := 5
		wg.Add(concurrentRequests)

		for i := 0; i < concurrentRequests; i++ {
			go func() {
				defer wg.Done()
				req := httptest.NewRequest("GET", "/concurrent.txt", nil)
				rr := httptest.NewRecorder()
				server.HandleRequestWithCache(rr, req)

				if rr.Code != http.StatusOK {
					t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
				}
				if rr.Body.String() != expectedBody {
					t.Errorf("Expected body '%s', got '%s'", expectedBody, rr.Body.String())
				}
			}()
		}

		wg.Wait()

		finalCount := atomic.LoadInt32(&upstreamRequestCount)
		if finalCount != 1 {
			t.Errorf("Expected upstream to be called once, but it was called %d times", finalCount)
		}
	})

	// 4.4. Upstream download fails midway.
	// Goal: an interrupted upstream download must not leave an incomplete
	// file in the cache.
	t.Run("UpstreamDownloadFailsMidway", func(t *testing.T) {
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Content-Length", "100") // promise 100 bytes...
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("only part of the data")) // ...but send only part
			// Flush explicitly so the partial body is on the wire before the
			// connection is hijacked; otherwise it may still be buffered.
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
			// Then abruptly close the connection.
			hj, ok := w.(http.Hijacker)
			if !ok {
				t.Error("ResponseWriter does not support hijacking")
				return
			}
			conn, _, err := hj.Hijack()
			if err != nil {
				t.Errorf("Hijack failed: %v", err)
				return
			}
			conn.Close()
		}))
		defer mockUpstream.Close()

		tmpDir := t.TempDir()
		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
			Misc: MiscConfig{
				FirstChunkBytes: 10,
				ChunkBytes:      10,
			},
		})

		req := httptest.NewRequest("GET", "/incomplete.txt", nil)
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		// The client receives a body that ends unexpectedly; what matters is
		// that the incomplete file is not cached. Give any background cache
		// write time to (fail to) complete.
		time.Sleep(100 * time.Millisecond)

		cachedFilePath := filepath.Join(tmpDir, "incomplete.txt")
		if _, err := os.Stat(cachedFilePath); !os.IsNotExist(err) {
			t.Errorf("Expected incomplete file not to be cached, but it exists.")
		}
	})

	// 4.5. Range request (cold cache).
	// Goal: a Range request on a cold cache fetches the full file and answers
	// the client with 206.
	t.Run("RangeRequestColdCache", func(t *testing.T) {
		fullBody := "this is the full content for range request"
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// The proxy should fetch the full file, so no Range header upstream.
			if r.Header.Get("Range") != "" {
				t.Error("Upstream received unexpected Range header")
			}
			w.Header().Set("Content-Length", strconv.Itoa(len(fullBody)))
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(fullBody))
		}))
		defer mockUpstream.Close()

		tmpDir := t.TempDir()
		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
			Misc: MiscConfig{
				FirstChunkBytes: 1024,
				ChunkBytes:      1024,
			},
		})

		req := httptest.NewRequest("GET", "/range.txt", nil)
		req.Header.Set("Range", "bytes=5-15") // request a sub-range
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		if rr.Code != http.StatusPartialContent {
			t.Errorf("Expected status code %d, got %d", http.StatusPartialContent, rr.Code)
		}

		expectedPartialBody := "is the full" // bytes 5 through 15 inclusive
		if rr.Body.String() != expectedPartialBody {
			t.Errorf("Expected partial body '%s', got '%s'", expectedPartialBody, rr.Body.String())
		}

		expectedContentRange := "bytes 5-15/42"
		if rr.Header().Get("Content-Range") != expectedContentRange {
			t.Errorf("Expected Content-Range header '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range"))
		}

		// Verify the full file was cached; poll for the background write.
		cachedFilePath := filepath.Join(tmpDir, "range.txt")
		cachedData, err := waitForCachedContent(cachedFilePath, fullBody, 2*time.Second)
		if err != nil {
			t.Fatalf("Failed to read cached file: %v", err)
		}
		if cachedData != fullBody {
			t.Errorf("Expected full file to be cached, content was '%s'", cachedData)
		}
	})

	// 4.6. X-Accel-Redirect.
	// Goal: with Accel enabled by a request header, the response carries the
	// redirect header and no body.
	t.Run("XAccelRedirect", func(t *testing.T) {
		// 1. Setup.
		cachedBody := "this will not be sent"
		tmpDir := t.TempDir()
		// Pre-populate the cache.
		cachedFilePath := filepath.Join(tmpDir, "accel.txt")
		if err := os.WriteFile(cachedFilePath, []byte(cachedBody), 0644); err != nil {
			t.Fatalf("Failed to write pre-populated cache file: %v", err)
		}

		server := NewServer(Config{
			// No upstreams needed: the request is served from cache.
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
					Accel: Accel{
						EnableByHeader:     "X-Sendfile-Type",
						RespondWithHeaders: []string{"X-Accel-Redirect"},
					},
				},
			},
			Cache: Cache{
				RefreshAfter: time.Hour, // ensure the cache entry is fresh
			},
		})

		// 2. Make the request. The header enables accel and supplies the
		//    base path for the redirect.
		req := httptest.NewRequest("GET", "/accel.txt", nil)
		req.Header.Set("X-Sendfile-Type", "/internal_redirect/")
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		// 3. Assertions.
		if rr.Code != http.StatusOK {
			t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
		}
		if rr.Body.Len() > 0 {
			t.Errorf("Expected empty body for accel redirect, but got body: %s", rr.Body.String())
		}
		expectedRedirectPath := "/internal_redirect/accel.txt"
		redirectHeader := rr.Header().Get("X-Accel-Redirect")
		if redirectHeader != expectedRedirectPath {
			t.Errorf("Expected X-Accel-Redirect header to be '%s', got '%s'", expectedRedirectPath, redirectHeader)
		}
	})

	// 4.7. Path traversal attack.
	// Goal: a path containing ../ is rejected with 400 Bad Request.
	t.Run("PathTraversalAttack", func(t *testing.T) {
		tmpDir := t.TempDir()
		server := NewServer(Config{
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
		})

		req := httptest.NewRequest("GET", "/../../../../etc/passwd", nil)
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		if rr.Code != http.StatusBadRequest {
			t.Errorf("Expected status code %d for path traversal, got %d", http.StatusBadRequest, rr.Code)
		}
		if !strings.Contains(rr.Body.String(), "crossing local directory boundary") {
			t.Errorf("Expected error message about crossing boundary, got: %s", rr.Body.String())
		}
	})
}

// waitForCachedContent polls path until its contents equal want or timeout
// elapses. It returns the last content read and, on timeout, the last read
// error (nil if the file was readable but its content differed).
func waitForCachedContent(path, want string, timeout time.Duration) (string, error) {
	deadline := time.Now().Add(timeout)
	for {
		data, err := os.ReadFile(path)
		if err == nil && string(data) == want {
			return string(data), nil
		}
		if time.Now().After(deadline) {
			return string(data), err
		}
		time.Sleep(10 * time.Millisecond)
	}
}

// isTimeoutError reports whether err, or any error it wraps, exposes a
// Timeout() method returning true (as net.Error and *url.Error do).
func isTimeoutError(err error) bool {
	var te interface{ Timeout() bool }
	return errors.As(err, &te) && te.Timeout()
}