cline optimization on range request and memory usage
All checks were successful
build container / build-container (push) Successful in 29m33s
run go test / test (push) Successful in 26m15s

This commit is contained in:
2025-06-10 17:44:44 +08:00
parent 147659b0da
commit 80560f7408
6 changed files with 952 additions and 440 deletions

View File

@ -3,6 +3,7 @@ package cacheproxy
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
@ -47,45 +48,28 @@ func TestTryUpstream(t *testing.T) {
// 调用被测函数
req, _ := http.NewRequest("GET", "/testfile", nil)
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
// 断言
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("tryUpstream() unexpected error: %v", err)
}
if resp == nil {
t.Fatal("tryUpstream() response is nil")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, resp.StatusCode)
}
if chunks == nil {
t.Fatal("tryUpstream() chunks channel is nil")
if firstChunk == nil {
t.Fatal("tryUpstream() firstChunk is nil")
}
var receivedBody bytes.Buffer
var finalErr error
LOOP:
for {
select {
case chunk, ok := <-chunks:
if !ok {
break LOOP
}
if chunk.error != nil {
finalErr = chunk.error
break LOOP
}
receivedBody.Write(chunk.buffer)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for chunk")
}
}
if finalErr != nil && finalErr != io.EOF {
t.Errorf("Expected final error to be io.EOF, got %v", finalErr)
}
receivedBody.Write(firstChunk)
rest, _ := io.ReadAll(resp.Body)
receivedBody.Write(rest)
if receivedBody.String() != expectedBody {
t.Errorf("Expected body '%s', got '%s'", expectedBody, receivedBody.String())
@ -119,7 +103,7 @@ func TestTryUpstream(t *testing.T) {
// 调用被测函数
req, _ := http.NewRequest("GET", "/testfile", nil)
lastModified := time.Now().UTC()
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, lastModified)
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, lastModified)
// 断言
if err != nil {
@ -131,8 +115,8 @@ func TestTryUpstream(t *testing.T) {
if resp.StatusCode != http.StatusNotModified {
t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode)
}
if chunks != nil {
t.Error("Expected chunks channel to be nil for 304 response")
if firstChunk != nil {
t.Error("Expected firstChunk to be nil for 304 response")
}
})
@ -158,7 +142,7 @@ func TestTryUpstream(t *testing.T) {
// 调用被测函数
req, _ := http.NewRequest("GET", "/testfile", nil)
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
// 断言
if err != nil {
@ -166,8 +150,8 @@ func TestTryUpstream(t *testing.T) {
}
// 对于一个明确的失败如404我们期望函数返回 nil, nil, nil
// 因为 response checker 默认只接受 200
if resp != nil || chunks != nil {
t.Errorf("Expected response and chunks to be nil, got resp=%v, chunks=%v", resp, chunks)
if resp != nil || firstChunk != nil {
t.Errorf("Expected response and firstChunk to be nil, got resp=%v, chunks=%v", resp, firstChunk)
}
})
@ -243,17 +227,15 @@ func TestTryUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("Expected no error, got %v", err)
}
if resp == nil || chunks == nil {
t.Fatal("Expected successful response and chunks, but got nil")
}
// drain the channel
for range chunks {
if resp == nil || firstChunk == nil {
t.Fatal("Expected successful response and firstChunk, but got nil")
}
resp.Body.Close()
})
// Case B: Checker 失败,应该返回 nil
@ -281,13 +263,13 @@ func TestTryUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if resp != nil || chunks != nil {
t.Fatal("Expected response and chunks to be nil due to checker failure")
if resp != nil || firstChunk != nil {
t.Fatal("Expected response and firstChunk to be nil due to checker failure")
}
})
})
@ -326,20 +308,18 @@ func TestTryUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("Expected no error, got %v", err)
}
if resp == nil || chunks == nil {
t.Fatal("Expected successful response and chunks, but got nil")
if resp == nil || firstChunk == nil {
t.Fatal("Expected successful response and firstChunk, but got nil")
}
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status OK, got %d", resp.StatusCode)
}
// drain the channel
for range chunks {
}
resp.Body.Close()
})
// Case B: 重定向被禁止
@ -355,7 +335,7 @@ func TestTryUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
resp, chunks, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
resp, firstChunk, err := server.tryUpstream(context.Background(), 0, 0, req, time.Time{})
// 当重定向规则不匹配时,自定义的 CheckRedirect 函数会返回 http.ErrUseLastResponse。
// 这使得 http.Client 直接返回 302 响应,而不是跟随跳转。
@ -364,8 +344,8 @@ func TestTryUpstream(t *testing.T) {
if err != nil {
t.Fatalf("Expected no error for a checker failure, got %v", err)
}
if resp != nil || chunks != nil {
t.Fatalf("Expected response and chunks to be nil for a disallowed redirect, got resp=%v, chunks=%v", resp, chunks)
if resp != nil || firstChunk != nil {
t.Fatalf("Expected response and firstChunk to be nil for a disallowed redirect, got resp=%v, chunks=%v", resp, firstChunk)
}
})
})
@ -413,38 +393,27 @@ func TestFastestUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})
upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("fastestUpstream() unexpected error: %v", err)
}
if resp == nil {
t.Fatal("fastestUpstream() response is nil")
}
defer resp.Body.Close()
if upstreamIndex != 1 {
t.Errorf("Expected fastest upstream index to be 1 (fastServer), got %d", upstreamIndex)
}
if chunks == nil {
t.Fatal("fastestUpstream() chunks channel is nil")
if firstChunk == nil {
t.Fatal("fastestUpstream() firstChunk is nil")
}
var receivedBody bytes.Buffer
LOOP:
for {
select {
case chunk, ok := <-chunks:
if !ok {
break LOOP
}
if chunk.error != nil && chunk.error != io.EOF {
t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
}
receivedBody.Write(chunk.buffer)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for chunk")
}
}
receivedBody.Write(firstChunk)
rest, _ := io.ReadAll(resp.Body)
receivedBody.Write(rest)
if receivedBody.String() != "fast" {
t.Errorf("Expected body from fast server, got '%s'", receivedBody.String())
@ -490,29 +459,27 @@ func TestFastestUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})
upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("fastestUpstream() unexpected error: %v", err)
}
if resp == nil {
t.Fatal("fastestUpstream() response is nil")
}
defer resp.Body.Close()
if upstreamIndex != 1 {
t.Errorf("Expected fastest upstream index to be 1 (workingServer), got %d", upstreamIndex)
}
if chunks == nil {
t.Fatal("fastestUpstream() chunks channel is nil")
if firstChunk == nil {
t.Fatal("fastestUpstream() firstChunk is nil")
}
var receivedBody bytes.Buffer
for chunk := range chunks {
if chunk.error != nil && chunk.error != io.EOF {
t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
}
receivedBody.Write(chunk.buffer)
}
receivedBody.Write(firstChunk)
rest, _ := io.ReadAll(resp.Body)
receivedBody.Write(rest)
if receivedBody.String() != "working" {
t.Errorf("Expected body from working server, got '%s'", receivedBody.String())
@ -546,7 +513,7 @@ func TestFastestUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})
upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})
if err != nil {
t.Fatalf("fastestUpstream() unexpected error: %v", err)
@ -554,8 +521,8 @@ func TestFastestUpstream(t *testing.T) {
if upstreamIndex != -1 {
t.Errorf("Expected upstream index to be -1, got %d", upstreamIndex)
}
if resp != nil || chunks != nil {
t.Errorf("Expected response and chunks to be nil, got resp=%v, chunks=%v", resp, chunks)
if resp != nil || firstChunk != nil {
t.Errorf("Expected response and firstChunk to be nil, got resp=%v, chunks=%v", resp, firstChunk)
}
})
@ -585,7 +552,7 @@ func TestFastestUpstream(t *testing.T) {
req, _ := http.NewRequest("GET", "/testfile", nil)
lastModified := time.Now().UTC()
upstreamIndex, resp, chunks, err := server.fastestUpstream(req, lastModified)
upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, lastModified)
if err != nil {
t.Fatalf("fastestUpstream() unexpected error: %v", err)
@ -599,8 +566,8 @@ func TestFastestUpstream(t *testing.T) {
if resp.StatusCode != http.StatusNotModified {
t.Errorf("Expected status code %d, got %d", http.StatusNotModified, resp.StatusCode)
}
if chunks != nil {
t.Error("Expected chunks channel to be nil for 304 response")
if firstChunk != nil {
t.Error("Expected firstChunk to be nil for 304 response")
}
})
@ -645,29 +612,27 @@ func TestFastestUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})
upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("fastestUpstream() unexpected error: %v", err)
}
if resp == nil {
t.Fatal("fastestUpstream() response is nil")
}
defer resp.Body.Close()
if upstreamIndex != 1 {
t.Errorf("Expected upstream index to be 1 (slow-high), got %d", upstreamIndex)
}
if chunks == nil {
t.Fatal("fastestUpstream() chunks channel is nil")
if firstChunk == nil {
t.Fatal("fastestUpstream() firstChunk is nil")
}
var receivedBody bytes.Buffer
for chunk := range chunks {
if chunk.error != nil && chunk.error != io.EOF {
t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
}
receivedBody.Write(chunk.buffer)
}
receivedBody.Write(firstChunk)
rest, _ := io.ReadAll(resp.Body)
receivedBody.Write(rest)
if receivedBody.String() != "slow-high" {
t.Errorf("Expected body from slow-high server, got '%s'", receivedBody.String())
@ -713,29 +678,27 @@ func TestFastestUpstream(t *testing.T) {
})
req, _ := http.NewRequest("GET", "/testfile", nil)
upstreamIndex, resp, chunks, err := server.fastestUpstream(req, time.Time{})
upstreamIndex, resp, firstChunk, err := server.fastestUpstream(req, time.Time{})
if err != nil {
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
t.Fatalf("fastestUpstream() unexpected error: %v", err)
}
if resp == nil {
t.Fatal("fastestUpstream() response is nil")
}
defer resp.Body.Close()
if upstreamIndex != 0 {
t.Errorf("Expected upstream index to be 0 (working-low), got %d", upstreamIndex)
}
if chunks == nil {
t.Fatal("fastestUpstream() chunks channel is nil")
if firstChunk == nil {
t.Fatal("fastestUpstream() firstChunk is nil")
}
var receivedBody bytes.Buffer
for chunk := range chunks {
if chunk.error != nil && chunk.error != io.EOF {
t.Fatalf("unexpected error from chunk channel: %v", chunk.error)
}
receivedBody.Write(chunk.buffer)
}
receivedBody.Write(firstChunk)
rest, _ := io.ReadAll(resp.Body)
receivedBody.Write(rest)
if receivedBody.String() != "working-low" {
t.Errorf("Expected body from working-low server, got '%s'", receivedBody.String())
@ -1288,9 +1251,465 @@ func TestIntegration_HandleRequestWithCache(t *testing.T) {
t.Errorf("Expected error message about crossing boundary, got: %s", rr.Body.String())
}
})
// 4.8. Range 请求竞态条件 (Race Condition on Range Request):
// 目的: 专门测试在冷缓存上处理范围请求时,生产者和消费者之间的竞态条件。
// 当上游响应非常快时,生产者 goroutine 可能会在消费者 goroutine
// 有机会复制dup文件句柄之前就过早地关闭并重命名了临时文件
// 导致消费者持有一个无效的文件句柄。
// 流程:
// - 模拟一个极快响应的上游。
// - 多次并发地发起范围请求。
// - 验证即使在这种高速场景下,消费者依然能返回正确的局部内容。
t.Run("RaceConditionOnRangeRequest", func(t *testing.T) {
for i := 0; i < 20; i++ { // 多次运行以增加命中竞态条件的概率
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Parallel() // 并行运行测试以增加调度压力
fullBody := "race condition test body content"
mockUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Length", strconv.Itoa(len(fullBody)))
w.WriteHeader(http.StatusOK)
w.Write([]byte(fullBody))
}))
defer mockUpstream.Close()
tmpDir := t.TempDir()
server := NewServer(Config{
Upstreams: []Upstream{
{Server: mockUpstream.URL},
},
Storage: Storage{
Local: &LocalStorage{
Path: tmpDir,
},
},
Misc: MiscConfig{
FirstChunkBytes: 1024, // 确保一次性读完
ChunkBytes: 1024,
},
})
req := httptest.NewRequest("GET", "/race.txt", nil)
req.Header.Set("Range", "bytes=5-15")
rr := httptest.NewRecorder()
server.HandleRequestWithCache(rr, req)
// 在有问题的实现中,这里会因为文件句柄被过早关闭而失败
if rr.Code != http.StatusPartialContent {
t.Fatalf("Expected status code %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedPartialBody := fullBody[5:16]
if rr.Body.String() != expectedPartialBody {
t.Fatalf("Expected partial body '%s', got '%s'", expectedPartialBody, rr.Body.String())
}
})
}
})
}
func isTimeoutError(err error) bool {
e, ok := err.(interface{ Timeout() bool })
return ok && e.Timeout()
}
// TestServeRangedRequest 包含了针对 serveRangedRequest 函数的单元测试。
// 这些测试独立地验证消费者逻辑,通过模拟一个正在进行中的下载流 (StreamObject)
// 来覆盖各种客户端请求场景。
func TestServeRangedRequest(t *testing.T) {
// 5.1. 请求的范围已完全可用:
// 目的: 验证当客户端请求的数据范围在下载完成前已经完全在临时文件中时,
// 函数能立即响应并返回正确的数据,而无需等待整个文件下载完毕。
t.Run("RequestRangeFullyAvailable", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 1000 // 文件总大小为 1000
// 创建并写入临时文件
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
initialData := make([]byte, 200) // 文件已有 200 字节
if _, err := tmpFile.Write(initialData); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: int64(len(initialData)),
Done: false, // 下载仍在进行
Error: nil,
mu: mu,
cond: cond,
}
// 将模拟的 stream object 放入 server 的 map 中
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 2. 执行
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=50-150")
rr := httptest.NewRecorder()
server.serveRangedRequest(rr, req, "/testfile", time.Time{})
// 3. 断言
if rr.Code != http.StatusPartialContent {
t.Errorf("Expected status %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedBodyLen := 101 // 150 - 50 + 1
if rr.Body.Len() != expectedBodyLen {
t.Errorf("Expected body length %d, got %d", expectedBodyLen, rr.Body.Len())
}
expectedContentRange := "bytes 50-150/1000"
if rr.Header().Get("Content-Range") != expectedContentRange {
t.Errorf("Expected Content-Range '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range"))
}
})
// 5.2. 请求的范围需要等待后续数据:
// 目的: 验证当请求的数据部分可用时,函数能先发送可用的部分,
// 然后等待生产者(下载 goroutine提供更多数据并最终完成响应。
t.Run("RequestWaitsForData", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 200 // 文件总大小为 200
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
// 初始只有 100 字节
initialData := make([]byte, 100)
for i := 0; i < 100; i++ {
initialData[i] = byte(i)
}
if _, err := tmpFile.Write(initialData); err != nil {
t.Fatalf("Failed to write initial data: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: int64(len(initialData)),
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 2. 执行
req := httptest.NewRequest("GET", "/testfile", nil)
// 请求 150 字节,但目前只有 100 字节可用
req.Header.Set("Range", "bytes=0-149")
rr := httptest.NewRecorder()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
server.serveRangedRequest(rr, req, "/testfile", time.Time{})
}()
// 模拟生产者在短暂延迟后写入更多数据
go func() {
time.Sleep(50 * time.Millisecond) // 等待消费者进入等待状态
streamObj.mu.Lock()
defer streamObj.mu.Unlock()
// 写入剩下的 100 字节
remainingData := make([]byte, 100)
for i := 0; i < 100; i++ {
remainingData[i] = byte(100 + i)
}
n, err := streamObj.TempFile.Write(remainingData)
if err != nil {
t.Errorf("Producer failed to write: %v", err)
return
}
streamObj.Offset += int64(n)
streamObj.Done = true // 下载完成
streamObj.cond.Broadcast() // 唤醒消费者
}()
// 等待 serveRangedRequest 完成
wg.Wait()
// 3. 断言
if rr.Code != http.StatusPartialContent {
t.Errorf("Expected status %d, got %d", http.StatusPartialContent, rr.Code)
}
expectedBodyLen := 150
if rr.Body.Len() != expectedBodyLen {
t.Errorf("Expected body length %d, got %d", expectedBodyLen, rr.Body.Len())
}
// 验证内容是否正确
body := rr.Body.Bytes()
for i := 0; i < 150; i++ {
if body[i] != byte(i) {
t.Fatalf("Body content mismatch at index %d: expected %d, got %d", i, i, body[i])
}
}
expectedContentRange := "bytes 0-149/200"
if rr.Header().Get("Content-Range") != expectedContentRange {
t.Errorf("Expected Content-Range '%s', got '%s'", expectedContentRange, rr.Header().Get("Content-Range"))
}
})
// 5.3. 多个客户端并发请求:
// 目的: 验证多个客户端可以同时从同一个流中读取数据,即使它们请求的是不同或重叠的范围。
t.Run("MultipleConcurrentClients", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 300
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
// 初始数据
initialData := make([]byte, 100)
if _, err := tmpFile.Write(initialData); err != nil {
t.Fatalf("Failed to write initial data: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: int64(len(initialData)),
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 模拟生产者持续写入数据
go func() {
// 写入第二部分
time.Sleep(50 * time.Millisecond)
streamObj.mu.Lock()
n, _ := streamObj.TempFile.Write(make([]byte, 100))
streamObj.Offset += int64(n)
streamObj.cond.Broadcast()
streamObj.mu.Unlock()
// 写入最后一部分并完成
time.Sleep(50 * time.Millisecond)
streamObj.mu.Lock()
n, _ = streamObj.TempFile.Write(make([]byte, 100))
streamObj.Offset += int64(n)
streamObj.Done = true
streamObj.cond.Broadcast()
streamObj.mu.Unlock()
}()
var wg sync.WaitGroup
wg.Add(2)
// 客户端 1: 请求 [50, 150]
rr1 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=50-150")
server.serveRangedRequest(rr1, req, "/testfile", time.Time{})
}()
// 客户端 2: 请求 [180, 280]
rr2 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=180-280")
server.serveRangedRequest(rr2, req, "/testfile", time.Time{})
}()
wg.Wait()
// 断言客户端 1
if rr1.Code != http.StatusPartialContent {
t.Errorf("Client 1: Expected status %d, got %d", http.StatusPartialContent, rr1.Code)
}
if rr1.Body.Len() != 101 {
t.Errorf("Client 1: Expected body length 101, got %d", rr1.Body.Len())
}
if rr1.Header().Get("Content-Range") != "bytes 50-150/300" {
t.Errorf("Client 1: Incorrect Content-Range header: %s", rr1.Header().Get("Content-Range"))
}
// 断言客户端 2
if rr2.Code != http.StatusPartialContent {
t.Errorf("Client 2: Expected status %d, got %d", http.StatusPartialContent, rr2.Code)
}
if rr2.Body.Len() != 101 {
t.Errorf("Client 2: Expected body length 101, got %d", rr2.Body.Len())
}
if rr2.Header().Get("Content-Range") != "bytes 180-280/300" {
t.Errorf("Client 2: Incorrect Content-Range header: %s", rr2.Header().Get("Content-Range"))
}
})
// 5.4. 下载过程中发生错误:
// 目的: 验证当生产者(下载 goroutine遇到错误时所有等待的消费者都会被正确通知并中断
// 并向各自的客户端返回一个错误。
t.Run("DownloadFailsWhileWaiting", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
fileSize := 200
downloadError := errors.New("network connection lost")
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
// 初始有 50 字节
if _, err := tmpFile.Write(make([]byte, 50)); err != nil {
t.Fatalf("Failed to write initial data: %v", err)
}
streamObj := &StreamObject{
Headers: http.Header{"Content-Length": []string{strconv.Itoa(fileSize)}},
TempFile: tmpFile,
Offset: 50,
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 模拟生产者在短暂延迟后遇到错误
go func() {
time.Sleep(50 * time.Millisecond)
streamObj.mu.Lock()
defer streamObj.mu.Unlock()
streamObj.Error = downloadError
streamObj.Done = true // 标记为完成(虽然是失败)
streamObj.cond.Broadcast() // 唤醒所有等待的消费者
}()
var wg sync.WaitGroup
wg.Add(2)
// 客户端 1: 请求 [100, 150], 需要等待
rr1 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=100-150")
server.serveRangedRequest(rr1, req, "/testfile", time.Time{})
}()
// 客户端 2: 也请求 [100, 150], 同样等待
rr2 := httptest.NewRecorder()
go func() {
defer wg.Done()
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=100-150")
server.serveRangedRequest(rr2, req, "/testfile", time.Time{})
}()
wg.Wait()
// 断言两个客户端都收到了内部服务器错误
if rr1.Code != http.StatusInternalServerError {
t.Errorf("Client 1: Expected status %d, got %d", http.StatusInternalServerError, rr1.Code)
}
if !strings.Contains(rr1.Body.String(), downloadError.Error()) {
t.Errorf("Client 1: Expected error message '%s', got '%s'", downloadError.Error(), rr1.Body.String())
}
if rr2.Code != http.StatusInternalServerError {
t.Errorf("Client 2: Expected status %d, got %d", http.StatusInternalServerError, rr2.Code)
}
if !strings.Contains(rr2.Body.String(), downloadError.Error()) {
t.Errorf("Client 2: Expected error message '%s', got '%s'", downloadError.Error(), rr2.Body.String())
}
})
// 5.5. 上游无 Content-Length:
// 目的: 验证当上游响应没有 Content-Length 头时,系统如何优雅地处理。
// 在当前实现中,缺少 Content-Length 会导致无法计算范围,应返回错误。
t.Run("NoContentLength", func(t *testing.T) {
// 1. 设置
server := NewServer(Config{})
mu := &sync.Mutex{}
cond := sync.NewCond(mu)
tmpFile, err := os.CreateTemp(t.TempDir(), "test-")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
defer tmpFile.Close()
streamObj := &StreamObject{
Headers: http.Header{}, // 没有 Content-Length
TempFile: tmpFile,
Offset: 50,
Done: false,
Error: nil,
mu: mu,
cond: cond,
}
server.lu.Lock()
server.o["/testfile"] = streamObj
server.lu.Unlock()
// 2. 执行
req := httptest.NewRequest("GET", "/testfile", nil)
req.Header.Set("Range", "bytes=0-10")
rr := httptest.NewRecorder()
// 在这个场景下serveRangedRequest 应该会立即返回错误,无需 goroutine
server.serveRangedRequest(rr, req, "/testfile", time.Time{})
// 3. 断言
// 因为没有 Content-Length, totalSize 会是 0, 导致 "Range not satisfiable"
if rr.Code != http.StatusRequestedRangeNotSatisfiable {
t.Errorf("Expected status %d, got %d", http.StatusRequestedRangeNotSatisfiable, rr.Code)
}
})
}