Files
cache-proxy/server_test.go
guochao 80560f7408
All checks were successful
build container / build-container (push) Successful in 29m33s
run go test / test (push) Successful in 26m15s
cline optimization on range request and memory usage
2025-06-10 17:44:44 +08:00

1716 lines
54 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cacheproxy
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// TestTryUpstream contains unit tests for the `tryUpstream` function.
// `tryUpstream` talks to a single upstream server and is the building block of
// the download logic. As exercised below, it returns (response, firstChunk, err),
// and an upstream that is rejected (non-accepted status, failed checker,
// disallowed redirect) is reported as (nil, nil, nil).
func TestTryUpstream(t *testing.T) {
	// isBenignEOF reports whether err only signals that a short upstream body
	// ended while the first chunk was being read. errors.Is is used so that a
	// wrapped EOF is recognised too.
	isBenignEOF := func(err error) bool {
		return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
	}

	// 1.1. Normal fetch:
	// Purpose: verify the function can successfully fetch the data stream from the upstream.
	// Flow:
	//   - Start a mock server that returns 200 OK with a preset body.
	//   - Call tryUpstream against that server.
	//   - Check the status code, the first chunk plus remaining body, and the error.
	t.Run("NormalFetch", func(t *testing.T) {
		expectedBody := "hello world"
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(expectedBody))
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
			Misc: MiscConfig{
				FirstChunkBytes: 1024,
				ChunkBytes:      1024,
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

		if err != nil && !isBenignEOF(err) {
			t.Fatalf("tryUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("tryUpstream() response is nil")
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
		}
		if firstChunk == nil {
			t.Fatal("tryUpstream() firstChunk is nil")
		}
		// The full payload is the first chunk followed by whatever is still in
		// resp.Body. The read error is ignored on purpose: the body may already
		// be exhausted, and the content comparison below is the real check.
		var receivedBody bytes.Buffer
		receivedBody.Write(firstChunk)
		rest, _ := io.ReadAll(resp.Body)
		receivedBody.Write(rest)
		if receivedBody.String() != expectedBody {
			t.Errorf("Expected body '%s', got '%s'", expectedBody, receivedBody.String())
		}
	})

	// 1.2. Upstream returns 304 Not Modified:
	// Purpose: verify the function handles a 304 from the upstream correctly.
	// Flow:
	//   - The mock server returns 304 whenever the request carries If-Modified-Since.
	//   - Call tryUpstream with a lastModified time.
	//   - Check the status code, that no first chunk is returned, and the error.
	t.Run("UpstreamReturns304", func(t *testing.T) {
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Header.Get("If-Modified-Since") != "" {
				w.WriteHeader(http.StatusNotModified)
				return
			}
			w.WriteHeader(http.StatusOK)
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		lastModified := time.Now().UTC()
		resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, lastModified)

		if err != nil {
			t.Fatalf("tryUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("tryUpstream() response is nil")
		}
		// Release the connection even though a 304 carries no payload.
		if resp.Body != nil {
			defer resp.Body.Close()
		}
		if resp.StatusCode != http.StatusNotModified {
			t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode)
		}
		if firstChunk != nil {
			t.Error("Expected firstChunk to be nil for 304 response")
		}
	})

	// 1.3. Upstream returns an error status (e.g. 404):
	// Purpose: verify the function handles a non-success status from the upstream.
	// Flow:
	//   - The mock server returns 404 Not Found.
	//   - Call tryUpstream.
	//   - Check that the results are all nil (i.e. this upstream is rejected).
	t.Run("UpstreamReturnsErrorStatus", func(t *testing.T) {
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotFound)
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

		if err != nil {
			t.Fatalf("tryUpstream() unexpected error: %v", err)
		}
		// For an explicit failure such as 404 we expect (nil, nil, nil), because
		// the default response checker only accepts 200.
		if resp != nil || firstChunk != nil {
			t.Errorf("Expected response and firstChunk to be nil, got resp=%v, chunks=%v", resp, firstChunk)
		}
	})

	// 1.4. Request timeout:
	// Purpose: verify that context timeout/cancellation is honoured.
	// Flow:
	//   - Simulate a slow upstream.
	//   - Call tryUpstream with a context carrying a short timeout.
	//   - Check that the returned error is a timeout (e.g. context.DeadlineExceeded).
	t.Run("RequestTimeout", func(t *testing.T) {
		// Slow upstream. It also returns as soon as the client gives up, so that
		// mockUpstream.Close() (which waits for in-flight requests) does not have
		// to sit out the full simulated delay.
		mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			select {
			case <-time.After(200 * time.Millisecond): // simulated network latency
				w.WriteHeader(http.StatusOK)
			case <-r.Context().Done():
			}
		}))
		defer mockUpstream.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: mockUpstream.URL},
			},
		})

		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer cancel()

		req, _ := http.NewRequest("GET", "/testfile", nil)
		_, _, err := server.tryUpstream(ctx, 0, 0, req, time.Time{})

		if err == nil {
			t.Fatal("Expected an error, but got nil")
		}
		// http.Client wraps context errors (e.g. in *url.Error), so a plain
		// equality check against context.DeadlineExceeded is not enough.
		if !errors.Is(err, context.DeadlineExceeded) && !isTimeoutError(err) {
			t.Errorf("Expected error to be context.DeadlineExceeded or a timeout error, got %v", err)
		}
	})

	// 1.5. Upstream response validation (Checker):
	// Purpose: verify that Checker configuration (e.g. header checks) works.
	// Flow:
	//   - Case A: the Checker requires a header value that the mock server returns.
	//   - Case B: the Checker requires a header value that the mock server does not return.
	//   - Case A must succeed and Case B must be rejected.
	t.Run("ResponseChecker", func(t *testing.T) {
		// Case A: the checker matches, so the fetch succeeds.
		t.Run("CheckerSucceeds", func(t *testing.T) {
			mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/octet-stream")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("data"))
			}))
			defer mockUpstream.Close()

			matchPattern := "application/octet-stream"
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server: mockUpstream.URL,
						Checkers: []Checker{
							{
								Headers: []HeaderChecker{
									{Name: "Content-Type", Match: &matchPattern},
								},
							},
						},
					},
				},
				Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			if err != nil && !isBenignEOF(err) {
				t.Fatalf("Expected no error, got %v", err)
			}
			if resp == nil || firstChunk == nil {
				t.Fatal("Expected successful response and firstChunk, but got nil")
			}
			resp.Body.Close()
		})

		// Case B: the checker does not match, so the upstream is rejected (nil results).
		t.Run("CheckerFails", func(t *testing.T) {
			mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "text/html") // header value that does not match
				w.WriteHeader(http.StatusOK)
			}))
			defer mockUpstream.Close()

			matchPattern := "application/octet-stream"
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server: mockUpstream.URL,
						Checkers: []Checker{
							{
								Headers: []HeaderChecker{
									{Name: "Content-Type", Match: &matchPattern},
								},
							},
						},
					},
				},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			if err != nil {
				t.Fatalf("Expected no error, got %v", err)
			}
			if resp != nil || firstChunk != nil {
				t.Fatal("Expected response and firstChunk to be nil due to checker failure")
			}
		})
	})

	// 1.6. Upstream redirects:
	// Purpose: verify that the AllowedRedirect rule is enforced.
	// Flow:
	//   - Case A: upstream A answers 302 to B and AllowedRedirect matches B.
	//   - Case B: upstream A answers 302 to B but AllowedRedirect does not match B.
	//   - Case A must fetch data from B; Case B must be rejected.
	t.Run("UpstreamRedirect", func(t *testing.T) {
		// Redirect target.
		targetUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("redirected"))
		}))
		defer targetUpstream.Close()

		// Initial upstream, which always redirects to the target.
		initialUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Redirect(w, r, targetUpstream.URL, http.StatusFound)
		}))
		defer initialUpstream.Close()

		// Case A: the redirect is allowed.
		t.Run("RedirectAllowed", func(t *testing.T) {
			allowedPattern := targetUpstream.URL // match the target exactly
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server:          initialUpstream.URL,
						AllowedRedirect: &allowedPattern,
					},
				},
				Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			if err != nil && !isBenignEOF(err) {
				t.Fatalf("Expected no error, got %v", err)
			}
			if resp == nil || firstChunk == nil {
				t.Fatal("Expected successful response and firstChunk, but got nil")
			}
			if resp.StatusCode != http.StatusOK {
				t.Errorf("Expected status OK, got %d", resp.StatusCode)
			}
			resp.Body.Close()
		})

		// Case B: the redirect is not allowed.
		t.Run("RedirectDisallowed", func(t *testing.T) {
			disallowedPattern := "http://disallowed.example.com"
			server := NewServer(Config{
				Upstreams: []Upstream{
					{
						Server:          initialUpstream.URL,
						AllowedRedirect: &disallowedPattern,
					},
				},
			})

			req, _ := http.NewRequest("GET", "/testfile", nil)
			resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})

			// When the redirect rule does not match, the custom CheckRedirect
			// returns http.ErrUseLastResponse, so http.Client hands back the 302
			// itself instead of following it. The status checker in tryUpstream
			// (which only accepts 200 by default) then rejects that 302, so
			// tryUpstream returns (nil, nil, nil): this upstream attempt failed.
			if err != nil {
				t.Fatalf("Expected no error for a checker failure, got %v", err)
			}
			if resp != nil || firstChunk != nil {
				t.Fatalf("Expected response and firstChunk to be nil for a disallowed redirect, got resp=%v, chunks=%v", resp, firstChunk)
			}
		})
	})
}
// TestFastestUpstream contains unit tests for the `fastestUpstream` function.
// `fastestUpstream` implements the "racing requests" pattern that is the core
// of the project: it queries upstreams concurrently and returns
// (upstreamIndex, response, firstChunk, err), with index -1 when no upstream
// produced an acceptable response.
func TestFastestUpstream(t *testing.T) {
	// isBenignEOF reports whether err only signals that a short upstream body
	// ended while the first chunk was being read. errors.Is is used so that a
	// wrapped EOF is recognised too.
	isBenignEOF := func(err error) bool {
		return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
	}

	// 2.1. Race success:
	// Purpose: verify the fastest of several upstreams is chosen and the slower
	// requests are cancelled.
	// Flow:
	//   - Start several mock servers with different response delays.
	//   - Call fastestUpstream.
	//   - Check the returned index points to the fastest server and that the
	//     slow server observed its request context being cancelled.
	t.Run("RaceSuccess", func(t *testing.T) {
		slowCancelled := make(chan bool, 1)

		// Slow server; records whether its request was cancelled.
		slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			select {
			case <-time.After(200 * time.Millisecond):
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("slow"))
			case <-r.Context().Done():
				// Non-blocking send: if the handler runs more than once, a
				// blocking send on the 1-slot channel would hang the handler and
				// therefore slowServer.Close().
				select {
				case slowCancelled <- true:
				default:
				}
				return
			}
		}))
		defer slowServer.Close()

		// Fast server.
		fastServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(50 * time.Millisecond)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("fast"))
		}))
		defer fastServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: slowServer.URL}, // index 0
				{Server: fastServer.URL}, // index 1
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})

		if err != nil && !isBenignEOF(err) {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		defer resp.Body.Close()

		if upstreamIndex != 1 {
			t.Errorf("Expected fastest upstream index to be 1 (fastServer), got %d", upstreamIndex)
		}
		if firstChunk == nil {
			t.Fatal("fastestUpstream() firstChunk is nil")
		}

		// Full payload = first chunk + remainder of the body. The read error is
		// ignored on purpose; the content comparison is the real check.
		var receivedBody bytes.Buffer
		receivedBody.Write(firstChunk)
		rest, _ := io.ReadAll(resp.Body)
		receivedBody.Write(rest)
		if receivedBody.String() != "fast" {
			t.Errorf("Expected body from fast server, got '%s'", receivedBody.String())
		}

		// The slow server's request must have been cancelled.
		select {
		case cancelled := <-slowCancelled:
			if !cancelled {
				t.Error("Expected slow server request to be cancelled, but it wasn't")
			}
		case <-time.After(500 * time.Millisecond):
			t.Error("Timeout waiting for slow server cancellation signal")
		}
	})

	// 2.2. Partial upstream failure:
	// Purpose: verify a working upstream is still chosen when some upstreams fail.
	// Flow:
	//   - Start several mock servers, some of which return errors.
	//   - Call fastestUpstream.
	//   - Check the returned index points to the only working server.
	t.Run("PartialUpstreamFailure", func(t *testing.T) {
		// Failing server.
		failingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "internal server error", http.StatusInternalServerError)
		}))
		defer failingServer.Close()

		// Working server.
		workingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("working"))
		}))
		defer workingServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: failingServer.URL}, // index 0
				{Server: workingServer.URL}, // index 1
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})

		if err != nil && !isBenignEOF(err) {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		defer resp.Body.Close()

		if upstreamIndex != 1 {
			t.Errorf("Expected fastest upstream index to be 1 (workingServer), got %d", upstreamIndex)
		}
		if firstChunk == nil {
			t.Fatal("fastestUpstream() firstChunk is nil")
		}

		var receivedBody bytes.Buffer
		receivedBody.Write(firstChunk)
		rest, _ := io.ReadAll(resp.Body)
		receivedBody.Write(rest)
		if receivedBody.String() != "working" {
			t.Errorf("Expected body from working server, got '%s'", receivedBody.String())
		}
	})

	// 2.3. All upstreams fail:
	// Purpose: verify the function fails gracefully when no upstream is usable.
	// Flow:
	//   - Start several mock servers that all return errors.
	//   - Call fastestUpstream.
	//   - Check the returned index is -1 and nothing else is returned.
	t.Run("AllUpstreamsFail", func(t *testing.T) {
		// Failing server 1 (500).
		failingServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "internal server error", http.StatusInternalServerError)
		}))
		defer failingServer1.Close()

		// Failing server 2 (404).
		failingServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.NotFound(w, r)
		}))
		defer failingServer2.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: failingServer1.URL}, // index 0
				{Server: failingServer2.URL}, // index 1
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if upstreamIndex != -1 {
			t.Errorf("Expected upstream index to be -1, got %d", upstreamIndex)
		}
		if resp != nil || firstChunk != nil {
			t.Errorf("Expected response and firstChunk to be nil, got resp=%v, chunks=%v", resp, firstChunk)
		}
	})

	// 2.4. All upstreams return 304:
	// Purpose: verify the behaviour when every upstream reports "not modified".
	// Flow:
	//   - Start several mock servers that all return 304.
	//   - Call fastestUpstream with a lastModified time.
	//   - Check the returned response has status 304.
	t.Run("AllUpstreamsReturn304", func(t *testing.T) {
		server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotModified)
		}))
		defer server1.Close()

		server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusNotModified)
		}))
		defer server2.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{Server: server1.URL},
				{Server: server2.URL},
			},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		lastModified := time.Now().UTC()
		upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, lastModified)

		if err != nil {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		// Release the connection even though a 304 carries no payload.
		if resp.Body != nil {
			defer resp.Body.Close()
		}
		if upstreamIndex == -1 {
			t.Error("Expected a valid upstream index, got -1")
		}
		if resp.StatusCode != http.StatusNotModified {
			t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode)
		}
		if firstChunk != nil {
			t.Error("Expected firstChunk to be nil for 304 response")
		}
	})

	// 2.5. Priority groups (PriorityGroups):
	// Purpose: verify higher-priority upstreams are tried first.
	// Flow:
	//   - Configure a slow high-priority upstream and a fast low-priority upstream.
	//   - Call fastestUpstream.
	//   - Check the returned index points to the high-priority server.
	t.Run("PriorityGroups", func(t *testing.T) {
		// Slow, high-priority server.
		slowHighPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			time.Sleep(100 * time.Millisecond)
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("slow-high"))
		}))
		defer slowHighPriorityServer.Close()

		// Fast, low-priority server.
		fastLowPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("fast-low"))
		}))
		defer fastLowPriorityServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{ // index 0
					Server: fastLowPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 1},
					},
				},
				{ // index 1
					Server: slowHighPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 10},
					},
				},
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})

		if err != nil && !isBenignEOF(err) {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		defer resp.Body.Close()

		if upstreamIndex != 1 {
			t.Errorf("Expected upstream index to be 1 (slow-high), got %d", upstreamIndex)
		}
		if firstChunk == nil {
			t.Fatal("fastestUpstream() firstChunk is nil")
		}

		var receivedBody bytes.Buffer
		receivedBody.Write(firstChunk)
		rest, _ := io.ReadAll(resp.Body)
		receivedBody.Write(rest)
		if receivedBody.String() != "slow-high" {
			t.Errorf("Expected body from slow-high server, got '%s'", receivedBody.String())
		}
	})

	// 2.6. Priority group failover:
	// Purpose: verify that when every upstream in the highest-priority group
	// fails, the next priority group is used.
	// Flow:
	//   - Configure a failing high-priority upstream and a working low-priority upstream.
	//   - Call fastestUpstream.
	//   - Check the returned index points to the low-priority server.
	t.Run("PriorityGroupFailover", func(t *testing.T) {
		// Failing, high-priority server.
		failingHighPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
		}))
		defer failingHighPriorityServer.Close()

		// Working, low-priority server.
		workingLowPriorityServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte("working-low"))
		}))
		defer workingLowPriorityServer.Close()

		server := NewServer(Config{
			Upstreams: []Upstream{
				{ // index 0
					Server: workingLowPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 1},
					},
				},
				{ // index 1
					Server: failingHighPriorityServer.URL,
					PriorityGroups: []UpstreamPriorityGroup{
						{Match: ".*", Priority: 10},
					},
				},
			},
			Misc: MiscConfig{FirstChunkBytes: 10, ChunkBytes: 10},
		})

		req, _ := http.NewRequest("GET", "/testfile", nil)
		upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})

		if err != nil && !isBenignEOF(err) {
			t.Fatalf("fastestUpstream() unexpected error: %v", err)
		}
		if resp == nil {
			t.Fatal("fastestUpstream() response is nil")
		}
		defer resp.Body.Close()

		if upstreamIndex != 0 {
			t.Errorf("Expected upstream index to be 0 (working-low), got %d", upstreamIndex)
		}
		if firstChunk == nil {
			t.Fatal("fastestUpstream() firstChunk is nil")
		}

		var receivedBody bytes.Buffer
		receivedBody.Write(firstChunk)
		rest, _ := io.ReadAll(resp.Body)
		receivedBody.Write(rest)
		if receivedBody.String() != "working-low" {
			t.Errorf("Expected body from working-low server, got '%s'", receivedBody.String())
		}
	})
}
// TestCheckLocal contains unit tests for the `checkLocal` function.
// `checkLocal` inspects the state of the local cache for a given file path
// (not present / present and fresh / present but needs an upstream check).
func TestCheckLocal(t *testing.T) {
	// 3.1. Cache entry does not exist:
	// Purpose: verify the correct status is returned when the file is absent locally.
	// Flow:
	//   - Call checkLocal with a path that does not exist.
	//   - Check the status is localNotExists.
	t.Run("CacheNotExists", func(t *testing.T) {
		server := NewServer(Config{})
		tmpDir := t.TempDir()
		nonExistentFile := filepath.Join(tmpDir, "non-existent-file")

		status, _, err := server.checkLocal(nil, nil, nonExistentFile)
		if err != nil {
			t.Fatalf("checkLocal() unexpected error: %v", err)
		}
		if status != localNotExists {
			t.Errorf("Expected status to be localNotExists, got %v", status)
		}
	})

	// 3.2. Cache entry exists and is fresh:
	// Purpose: verify the correct status is returned for an existing, unexpired file.
	// Flow:
	//   - Create a new file and configure a long RefreshAfter.
	//   - Call checkLocal.
	//   - Check the status is localExists.
	t.Run("CacheExistsAndValid", func(t *testing.T) {
		server := NewServer(Config{
			Cache: Cache{
				RefreshAfter: time.Hour,
			},
		})
		tmpDir := t.TempDir()
		validFile := filepath.Join(tmpDir, "valid-file")
		if err := os.WriteFile(validFile, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}

		status, _, err := server.checkLocal(nil, nil, validFile)
		if err != nil {
			t.Fatalf("checkLocal() unexpected error: %v", err)
		}
		if status != localExists {
			t.Errorf("Expected status to be localExists, got %v", status)
		}
	})

	// 3.3. Cache entry exists but has expired:
	// Purpose: verify an existing but stale file yields the "needs upstream check" status.
	// Flow:
	//   - Create a file whose modification time is well in the past, with a tiny RefreshAfter.
	//   - Call checkLocal.
	//   - Check the status is localExistsButNeedHead.
	t.Run("CacheExistsButExpired", func(t *testing.T) {
		server := NewServer(Config{
			Cache: Cache{
				RefreshAfter: time.Nanosecond,
			},
		})
		tmpDir := t.TempDir()
		expiredFile := filepath.Join(tmpDir, "expired-file")
		if err := os.WriteFile(expiredFile, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}
		// Backdate the modification time explicitly instead of sleeping: the
		// result then no longer depends on filesystem timestamp granularity or
		// on how quickly checkLocal runs after the write.
		past := time.Now().Add(-time.Hour)
		if err := os.Chtimes(expiredFile, past, past); err != nil {
			t.Fatalf("Failed to backdate test file: %v", err)
		}

		status, _, err := server.checkLocal(nil, nil, expiredFile)
		if err != nil {
			t.Fatalf("checkLocal() unexpected error: %v", err)
		}
		if status != localExistsButNeedHead {
			t.Errorf("Expected status to be localExistsButNeedHead, got %v", status)
		}
	})

	// 3.4. Cache policies "always" and "never":
	// Purpose: verify the refresh policies "always" and "never" behave as intended.
	// Flow:
	//   - Case "always": a matching always policy yields localExistsButNeedHead.
	//   - Case "never": a matching never policy yields localExists even though
	//     the global RefreshAfter has expired.
	t.Run("CachePolicyAlwaysAndNever", func(t *testing.T) {
		tmpDir := t.TempDir()
		testFile := filepath.Join(tmpDir, "test-policy-file.txt")
		if err := os.WriteFile(testFile, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}

		// Case "always".
		t.Run("Always", func(t *testing.T) {
			server := NewServer(Config{
				Cache: Cache{
					Policies: []CachePolicyOnPath{
						{Match: `\.txt$`, RefreshAfter: "always"},
					},
				},
			})

			status, _, err := server.checkLocal(nil, nil, testFile)
			if err != nil {
				t.Fatalf("checkLocal() unexpected error: %v", err)
			}
			if status != localExistsButNeedHead {
				t.Errorf("Expected status to be localExistsButNeedHead for 'always' policy, got %v", status)
			}
		})

		// Case "never".
		t.Run("Never", func(t *testing.T) {
			server := NewServer(Config{
				Cache: Cache{
					RefreshAfter: time.Nanosecond, // the global policy considers the file expired
					Policies: []CachePolicyOnPath{
						{Match: `\.txt$`, RefreshAfter: "never"},
					},
				},
			})
			// Make the file old enough that the global policy alone would
			// require a refresh; only the "never" rule can keep it fresh.
			past := time.Now().Add(-time.Hour)
			if err := os.Chtimes(testFile, past, past); err != nil {
				t.Fatalf("Failed to backdate test file: %v", err)
			}

			status, _, err := server.checkLocal(nil, nil, testFile)
			if err != nil {
				t.Fatalf("checkLocal() unexpected error: %v", err)
			}
			if status != localExists {
				t.Errorf("Expected status to be localExists for 'never' policy, got %v", status)
			}
		})
	})

	// 3.5. File permission errors:
	// Purpose: verify that when a file cannot be accessed because of
	// permissions, HandleRequestWithCache answers 403 Forbidden.
	// Flow:
	//   - Create a directory and remove its execute (search) permission so the
	//     file inside it cannot be stat'ed.
	//   - Request that file through HandleRequestWithCache.
	//   - Check the HTTP status is 403 Forbidden.
	// Note: this is effectively an integration test; it checks that a
	// permission error from checkLocal is propagated and handled.
	t.Run("FilePermissionError", func(t *testing.T) {
		if os.Getuid() == 0 {
			t.Skip("Skipping test: running as root, cannot test permission errors effectively.")
		}

		tmpDir := t.TempDir()
		permDir := filepath.Join(tmpDir, "perm_dir")

		// 1. Create the directory with full permissions for the owner.
		if err := os.Mkdir(permDir, 0700); err != nil {
			t.Fatalf("Failed to create test directory: %v", err)
		}
		// Restore permissions so the temp-dir cleanup can remove the tree.
		defer os.Chmod(permDir, 0755)

		// 2. Create a file inside it.
		unreadableFilePath := filepath.Join(permDir, "file")
		if err := os.WriteFile(unreadableFilePath, []byte("data"), 0644); err != nil {
			t.Fatalf("Failed to create test file: %v", err)
		}

		// 3. Revoke execute permission on the directory, making the file inside un-stat-able.
		if err := os.Chmod(permDir, 0600); err != nil {
			t.Fatalf("Failed to chmod perm_dir: %v", err)
		}
		// Some platforms/filesystems (e.g. Windows, where Chmod only toggles the
		// read-only bit) do not enforce directory search permission. Skip there
		// instead of reporting a false failure.
		if _, err := os.Stat(unreadableFilePath); err == nil {
			t.Skip("Skipping test: directory permissions are not enforced on this platform/filesystem.")
		}

		server := NewServer(Config{
			Storage: Storage{
				Local: &LocalStorage{
					Path: tmpDir,
				},
			},
		})

		req := httptest.NewRequest("GET", "/perm_dir/file", nil)
		rr := httptest.NewRecorder()
		server.HandleRequestWithCache(rr, req)

		if rr.Code != http.StatusForbidden {
			t.Errorf("Expected status code %d, got %d", http.StatusForbidden, rr.Code)
		}
	})
}
// TestIntegration_HandleRequestWithCache 包含了集成测试。
// 这部分测试的是核心的业务流程,将多个单元组合起来进行验证。
func TestIntegration_HandleRequestWithCache(t *testing.T) {
// 4.1. 首次请求(冷缓存):
// 目的: 验证首次请求能成功从上游下载、响应客户端并写入缓存。
// 流程:
// - 保证缓存目录为空。
// - 发起 HTTP 请求到 HandleRequestWithCache。
// - 检查客户端响应和本地缓存文件。
t.Run("FirstRequestColdCache", func(t *testing.T) {
expectedBody := "from upstream"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "13")
w.WriteHeader(http.StatusOK)
w.Write([]byte(expectedBody))
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Misc: MiscConfig{
FirstChunkBytes: 1024,
ChunkBytes: 1024,
},
})
req := httptest.NewRequest("GET", "/test.txt", nil)
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
}
if rr.Body.String() != expectedBody {
t.Errorf("Expected body '%s', got '%s'", expectedBody, rr.Body.String())
}
// 等待后台缓存写入完成
time.Sleep(100 * time.Millisecond)
cachedFilePath := filepath.Join(tmpDir, "test.txt")
cachedData, err := os.ReadFile(cachedFilePath)
if err != nil {
t.Fatalf("Failed to read cached file: %v", err)
}
if string(cachedData) != expectedBody {
t.Errorf("Expected cached file content '%s', got '%s'", expectedBody, string(cachedData))
}
})
// 4.2. 缓存命中:
// 目的: 验证对已缓存文件的请求,能直接从本地提供服务。
// 流程:
// - 预先在缓存目录中放置文件。
// - 模拟一个会失败的上游(确保没有网络调用)。
// - 发起请求并检查响应。
t.Run("CacheHit", func(t *testing.T) {
cachedBody := "from cache"
// Mock upstream that will fail if called
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("Upstream should not be called on cache hit")
w.WriteHeader(http.StatusInternalServerError)
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
// Pre-populate the cache
cachedFilePath := filepath.Join(tmpDir, "cached.txt")
if err := os.WriteFile(cachedFilePath, []byte(cachedBody), 0644); err != nil {
t.Fatalf("Failed to write pre-populated cache file: %v", err)
}
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Cache: Cache{
RefreshAfter: time.Hour, // Ensure cache is valid
},
})
req := httptest.NewRequest("GET", "/cached.txt", nil)
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
}
if rr.Body.String() != cachedBody {
t.Errorf("Expected body '%s', got '%s'", cachedBody, rr.Body.String())
}
})
// 4.3. 并发请求同一文件:
// 目的: 验证多个并发请求同一文件时,只会触发一次上游下载。
// 流程:
// - 使用 goroutine 模拟并发请求。
// - 模拟一个能记录调用次数的上游。
// - 检查所有客户端的响应和上游调用次数。
t.Run("ConcurrentRequestsSameFile", func(t *testing.T) {
expectedBody := "concurrent download"
var upstreamRequestCount int32
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&upstreamRequestCount, 1)
// Simulate a slow download
time.Sleep(200 * time.Millisecond)
w.Header().Set("Content-Length", "19")
w.WriteHeader(http.StatusOK)
w.Write([]byte(expectedBody))
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Misc: MiscConfig{
FirstChunkBytes: 1024,
ChunkBytes: 1024,
},
})
var wg sync.WaitGroup
concurrentRequests := 5
wg.Add(concurrentRequests)
for i := 0; i < concurrentRequests; i++ {
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/concurrent.txt", nil)
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
}
if rr.Body.String() != expectedBody {
t.Errorf("Expected body '%s', got '%s'", expectedBody, rr.Body.String())
}
}()
}
wg.Wait()
finalCount := atomic.LoadInt32(&upstreamRequestCount)
if finalCount != 1 {
t.Errorf("Expected upstream to be called once, but it was called %d times", finalCount)
}
})
// 4.4. 上游下载中途失败:
// 目的: 验证当上游下载中断时,不缓存不完整的文件。
// 流程:
// - 模拟一个只发送部分数据就断开连接的上游。
// - 检查本地没有生成(或留存)不完整的缓存文件。
t.Run("UpstreamDownloadFailsMidway", func(t *testing.T) {
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", "100") // Pretend we'll send 100 bytes
w.WriteHeader(http.StatusOK)
w.Write([]byte("only part of the data")) // Send only a part
// And then abruptly close the connection
conn, _, _ := w.(http.Hijacker).Hijack()
conn.Close()
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Misc: MiscConfig{
FirstChunkBytes: 10,
ChunkBytes: 10,
},
})
req := httptest.NewRequest("GET", "/incomplete.txt", nil)
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
// The client will receive a response with a body that unexpectedly ends.
// The key is to check that the incomplete file is not cached.
time.Sleep(100 * time.Millisecond) // Wait for async cache write to fail
cachedFilePath := filepath.Join(tmpDir, "incomplete.txt")
if _, err := os.Stat(cachedFilePath); !os.IsNotExist(err) {
t.Errorf("Expected incomplete file not to be cached, but it exists.")
}
})
// 4.5. Range 请求(冷缓存):
// 目的: 验证 Range 请求在冷缓存时,能先完整下载再响应。
// 流程:
// - 发起一个带 Range 头的 HTTP 请求。
// - 检查文件是否被完整下载,以及客户端是否收到 206 响应。
t.Run("RangeRequestColdCache", func(t *testing.T) {
fullBody := "this is the full content for range request"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// The upstream should not receive the Range header, as the proxy should fetch the full file.
if r.Header.Get("Range") != "" {
t.Error("Upstream received unexpected Range header")
}
w.Header().Set("Content-Length", strconv.Itoa(len(fullBody)))
w.WriteHeader(http.StatusOK)
w.Write([]byte(fullBody))
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Misc: MiscConfig{
FirstChunkBytes: 1024,
ChunkBytes: 1024,
},
})
req := httptest.NewRequest("GET", "/range.txt", nil)
req.Header.Set("Range", "bytes=5-15") // Request a sub-range
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
if rr.Code != http.StatusPartialContent {
t.Errorf("Expected status code %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedPartialBody := "is the full" // bytes from index 5 to 15
if rr.Body.String() != expectedPartialBody {
t.Errorf("Expected partial body '%s', got '%s'", expectedPartialBody, rr.Body.String())
}
expectedContentRange := "bytes 5-15/42"
if rr.Header().Get("Content-Range") != expectedContentRange {
t.Errorf("Expected Content-Range header '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range"))
}
// Wait for the background caching to complete
time.Sleep(100 * time.Millisecond)
// Verify that the full file was cached
cachedFilePath := filepath.Join(tmpDir, "range.txt")
cachedData, err := os.ReadFile(cachedFilePath)
if err != nil {
t.Fatalf("Failed to read cached file: %v", err)
}
if string(cachedData) != fullBody {
t.Errorf("Expected full file to be cached, content was '%s'", string(cachedData))
}
})
// 4.6. X-Accel-Redirect:
// 目的: 验证 X-Accel-Redirect 功能是否按预期工作。
// 流程:
// - 启用 Accel 配置并发起带特定头的请求。
// - 检查响应头是否包含正确的重定向路径。
t.Run("XAccelRedirect", func(t *testing.T) {
// 1. Setup
cachedBody := "this will not be sent"
tmpDir := t.TempDir()
// Pre-populate the cache
cachedFilePath := filepath.Join(tmpDir, "accel.txt")
if err := os.WriteFile(cachedFilePath, []byte(cachedBody), 0644); err != nil {
t.Fatalf("Failed to write pre-populated cache file: %v", err)
}
server := NewServer(Config{
// No upstreams needed as we are hitting the cache
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
Accel: Accel{
EnableByHeader: "X-Sendfile-Type",
RespondWithHeaders: []string{"X-Accel-Redirect"},
},
},
},
Cache: Cache{
RefreshAfter: time.Hour, // Ensure cache is valid
},
})
// 2. Make request
req := httptest.NewRequest("GET", "/accel.txt", nil)
// This header enables the accel logic and provides the base path for the redirect
req.Header.Set("X-Sendfile-Type", "/internal_redirect/")
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
// 3. Assertions
if rr.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code)
}
if rr.Body.Len() > 0 {
t.Errorf("Expected empty body for accel redirect, but got body: %s", rr.Body.String())
}
expectedRedirectPath := "/internal_redirect/accel.txt"
redirectHeader := rr.Header().Get("X-Accel-Redirect")
if redirectHeader != expectedRedirectPath {
t.Errorf("Expected X-Accel-Redirect header to be '%s', got '%s'", expectedRedirectPath, redirectHeader)
}
})
// 4.7. 路径穿越攻击:
// 目的: 验证系统能阻止非法的路径访问。
// 流程:
// - 发起一个包含 ../ 的恶意路径请求。
// - 检查是否返回 400 Bad Request。
t.Run("PathTraversalAttack", func(t *testing.T) {
tmpDir := t.TempDir()
server := NewServer(Config{
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
})
req := httptest.NewRequest("GET", "/../../../../etc/passwd", nil)
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
if rr.Code != http.StatusBadRequest {
t.Errorf("Expected status code %d for path traversal, got %d", http.StatusBadRequest, rr.Code)
}
if !strings.Contains(rr.Body.String(), "crossing local directory boundary") {
t.Errorf("Expected error message about crossing boundary, got: %s", rr.Body.String())
}
})
// 4.8. Range 请求竞态条件 (Race Condition on Range Request):
// 目的: 专门测试在冷缓存上处理范围请求时,生产者和消费者之间的竞态条件。
// 当上游响应非常快时,生产者 goroutine 可能会在消费者 goroutine
// 有机会复制dup文件句柄之前就过早地关闭并重命名了临时文件
// 导致消费者持有一个无效的文件句柄。
// 流程:
// - 模拟一个极快响应的上游。
// - 多次并发地发起范围请求。
// - 验证即使在这种高速场景下,消费者依然能返回正确的局部内容。
t.Run("RaceConditionOnRangeRequest", func(t *testing.T) {
for i := 0; i < 20; i++ { // 多次运行以增加命中竞态条件的概率
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Parallel() // 并行运行测试以增加调度压力
fullBody := "race condition test body content"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", strconv.Itoa(len(fullBody)))
w.WriteHeader(http.StatusOK)
w.Write([]byte(fullBody))
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Misc: MiscConfig{
FirstChunkBytes: 1024, // 确保一次性读完
ChunkBytes: 1024,
},
})
req := httptest.NewRequest("GET", "/race.txt", nil)
req.Header.Set("Range", "bytes=5-15")
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
// 在有问题的实现中,这里会因为文件句柄被过早关闭而失败
if rr.Code != http.StatusPartialContent {
t.Fatalf("Expected status code %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedPartialBody := fullBody[5:16]
if rr.Body.String() != expectedPartialBody {
t.Fatalf("Expected partial body '%s', got '%s'", expectedPartialBody, rr.Body.String())
}
})
}
})
}
// isTimeoutError reports whether err, or any error it wraps, exposes a
// Timeout() bool method that returns true (as net.Error implementations and
// context.DeadlineExceeded do). Unlike a bare type assertion, errors.As walks
// the wrap chain, so a timeout wrapped with fmt.Errorf("...: %w", err) is
// still recognized. A nil err reports false.
func isTimeoutError(err error) bool {
	var te interface{ Timeout() bool }
	return errors.As(err, &te) && te.Timeout()
}
// TestServeRangedRequest 包含了针对 serveRangedRequest 函数的单元测试。
// 这些测试独立地验证消费者逻辑,通过模拟一个正在进行中的下载流 (StreamObject)
// 来覆盖各种客户端请求场景。
func TestServeRangedRequest(t *testing.T) {
// 5.1. 请求的范围已完全可用:
// 目的: 验证当客户端请求的数据范围在下载完成前已经完全在临时文件中时,
// 函数能立即响应并返回正确的数据,而无需等待整个文件下载完毕。
t.Run("RequestRangeFullyAvailable", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 1000 // 文件总大小为 1000
// 创建并写入临时文件
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
initialData := make([]byte, 200) // 文件已有 200 字节
if _, err := tmpFile.Write(initialData); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: int64(len(initialData)),
Done: false, // 下载仍在进行
Error: nil,
mu: mu,
cond: cond,
}
// 将模拟的 stream object 放入 server 的 map 中
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 2. 执行
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=50-150")
rr := httptest.NewRecorder()
server.serveRangedRequest(rr, req, "/testfile", time.Time{})
// 3. 断言
if rr.Code != http.StatusPartialContent {
t.Errorf("Expected status %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedBodyLen := 101 // 150 - 50 + 1
if rr.Body.Len() != expectedBodyLen {
t.Errorf("Expected body length %d, got %d", expectedBodyLen, rr.Body.Len())
}
expectedContentRange := "bytes 50-150/1000"
if rr.Header().Get("Content-Range") != expectedContentRange {
t.Errorf("Expected Content-Range '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range"))
}
})
// 5.2. 请求的范围需要等待后续数据:
// 目的: 验证当请求的数据部分可用时,函数能先发送可用的部分,
// 然后等待生产者(下载 goroutine提供更多数据并最终完成响应。
t.Run("RequestWaitsForData", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 200 // 文件总大小为 200
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
// 初始只有 100 字节
initialData := make([]byte, 100)
for i := 0; i < 100; i++ {
initialData[i] = byte(i)
}
if _, err := tmpFile.Write(initialData); err != nil {
t.Fatalf("Failed to write initial data: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: int64(len(initialData)),
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 2. 执行
req := httptest.NewRequest("GET", "/testfile", nil)
// 请求 150 字节,但目前只有 100 字节可用
req.Header.Set("Range", "bytes=0-149")
rr := httptest.NewRecorder()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
server.serveRangedRequest(rr, req, "/testfile", time.Time{})
}()
// 模拟生产者在短暂延迟后写入更多数据
go func() {
time.Sleep(50 * time.Millisecond) // 等待消费者进入等待状态
streamObj.mu.Lock()
defer streamObj.mu.Unlock()
// 写入剩下的 100 字节
remainingData := make([]byte, 100)
for i := 0; i < 100; i++ {
remainingData[i] = byte(100 + i)
}
n, err := streamObj.TempFile.Write(remainingData)
if err != nil {
t.Errorf("Producer failed to write: %v", err)
return
}
streamObj.Offset += int64(n)
streamObj.Done = true // 下载完成
streamObj.cond.Broadcast() // 唤醒消费者
}()
// 等待 serveRangedRequest 完成
wg.Wait()
// 3. 断言
if rr.Code != http.StatusPartialContent {
t.Errorf("Expected status %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedBodyLen := 150
if rr.Body.Len() != expectedBodyLen {
t.Errorf("Expected body length %d, got %d", expectedBodyLen, rr.Body.Len())
}
// 验证内容是否正确
body := rr.Body.Bytes()
for i := 0; i < 150; i++ {
if body[i] != byte(i) {
t.Fatalf("Body content mismatch at index %d: expected %d, got %d", i, i, body[i])
}
}
expectedContentRange := "bytes 0-149/200"
if rr.Header().Get("Content-Range") != expectedContentRange {
t.Errorf("Expected Content-Range '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range"))
}
})
// 5.3. 多个客户端并发请求:
// 目的: 验证多个客户端可以同时从同一个流中读取数据,即使它们请求的是不同或重叠的范围。
t.Run("MultipleConcurrentClients", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 300
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
// 初始数据
initialData := make([]byte, 100)
if _, err := tmpFile.Write(initialData); err != nil {
t.Fatalf("Failed to write initial data: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: int64(len(initialData)),
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 模拟生产者持续写入数据
go func() {
// 写入第二部分
time.Sleep(50 * time.Millisecond)
streamObj.mu.Lock()
n, _ := streamObj.TempFile.Write(make([]byte, 100))
streamObj.Offset += int64(n)
streamObj.cond.Broadcast()
streamObj.mu.Unlock()
// 写入最后一部分并完成
time.Sleep(50 * time.Millisecond)
streamObj.mu.Lock()
n, _ = streamObj.TempFile.Write(make([]byte, 100))
streamObj.Offset += int64(n)
streamObj.Done = true
streamObj.cond.Broadcast()
streamObj.mu.Unlock()
}()
var wg sync.WaitGroup
wg.Add(2)
// 客户端 1: 请求 [50, 150]
rr1 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=50-150")
server.serveRangedRequest(rr1, req, "/testfile", time.Time{})
}()
// 客户端 2: 请求 [180, 280]
rr2 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=180-280")
server.serveRangedRequest(rr2, req, "/testfile", time.Time{})
}()
wg.Wait()
// 断言客户端 1
if rr1.Code != http.StatusPartialContent {
t.Errorf("Client 1: Expected status %d, got %d", http.StatusPartialContent, rr1.Code)
}
if rr1.Body.Len() != 101 {
t.Errorf("Client 1: Expected body length 101, got %d", rr1.Body.Len())
}
if rr1.Header().Get("Content-Range") != "bytes 50-150/300" {
t.Errorf("Client 1: Incorrect Content-Range header: %s", rr1.Header().Get("Content-Range"))
}
// 断言客户端 2
if rr2.Code != http.StatusPartialContent {
t.Errorf("Client 2: Expected status %d, got %d", http.StatusPartialContent, rr2.Code)
}
if rr2.Body.Len() != 101 {
t.Errorf("Client 2: Expected body length 101, got %d", rr2.Body.Len())
}
if rr2.Header().Get("Content-Range") != "bytes 180-280/300" {
t.Errorf("Client 2: Incorrect Content-Range header: %s", rr2.Header().Get("Content-Range"))
}
})
// 5.4. 下载过程中发生错误:
// 目的: 验证当生产者(下载 goroutine遇到错误时所有等待的消费者都会被正确通知并中断
// 并向各自的客户端返回一个错误。
t.Run("DownloadFailsWhileWaiting", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 200
downloadError := errors.New("network connection lost")
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
// 初始有 50 字节
if _, err := tmpFile.Write(make([]byte, 50)); err != nil {
t.Fatalf("Failed to write initial data: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: 50,
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 模拟生产者在短暂延迟后遇到错误
go func() {
time.Sleep(50 * time.Millisecond)
streamObj.mu.Lock()
defer streamObj.mu.Unlock()
streamObj.Error = downloadError
streamObj.Done = true // 标记为完成(虽然是失败)
streamObj.cond.Broadcast() // 唤醒所有等待的消费者
}()
var wg sync.WaitGroup
wg.Add(2)
// 客户端 1: 请求 [100, 150], 需要等待
rr1 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=100-150")
server.serveRangedRequest(rr1, req, "/testfile", time.Time{})
}()
// 客户端 2: 也请求 [100, 150], 同样等待
rr2 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=100-150")
server.serveRangedRequest(rr2, req, "/testfile", time.Time{})
}()
wg.Wait()
// 断言两个客户端都收到了内部服务器错误
if rr1.Code != http.StatusInternalServerError {
t.Errorf("Client 1: Expected status %d, got %d", http.StatusInternalServerError, rr1.Code)
}
if !strings.Contains(rr1.Body.String(), downloadError.Error()) {
t.Errorf("Client 1: Expected error message '%s', got '%s'", downloadError.Error(), rr1.Body.String())
}
if rr2.Code != http.StatusInternalServerError {
t.Errorf("Client 2: Expected status %d, got %d", http.StatusInternalServerError, rr2.Code)
}
if !strings.Contains(rr2.Body.String(), downloadError.Error()) {
t.Errorf("Client 2: Expected error message '%s', got '%s'", downloadError.Error(), rr2.Body.String())
}
})
// 5.5. 上游无 Content-Length:
// 目的: 验证当上游响应没有 Content-Length 头时,系统如何优雅地处理。
// 在当前实现中,缺少 Content-Length 会导致无法计算范围,应返回错误。
t.Run("NoContentLength", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
streamObj := &StreamObject{
Headers: http.Header{}, // 没有 Content-Length
TempFile: tmpFile,
Offset: 50,
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 2. 执行
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=0-10")
rr := httptest.NewRecorder()
// 在这个场景下serveRangedRequest 应该会立即返回错误,无需 goroutine
server.serveRangedRequest(rr, req, "/testfile", time.Time{})
// 3. 断言
// 因为没有 Content-Length, totalSize 会是 0, 导致 "Range not satisfiable"
if rr.Code != http.StatusRequestedRangeNotSatisfiable {
t.Errorf("Expected status %d, got %d", http.StatusRequestedRangeNotSatisfiable, rr.Code)
}
})
}