cline optimization on range request and memory usage
This commit is contained in:
686
server.go
686
server.go
@ -1,19 +1,19 @@
|
||||
package cacheproxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -49,44 +49,16 @@ var (
|
||||
)
|
||||
|
||||
type StreamObject struct {
|
||||
Headers http.Header
|
||||
Buffer *bytes.Buffer
|
||||
Offset int
|
||||
Headers http.Header
|
||||
TempFile *os.File // The temporary file holding the download content.
|
||||
Offset int64 // The number of bytes written to TempFile.
|
||||
Done bool
|
||||
Error error
|
||||
|
||||
ctx context.Context
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
mu *sync.Mutex
|
||||
cond *sync.Cond
|
||||
|
||||
func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error {
|
||||
defer wg.Done()
|
||||
offset := 0
|
||||
if w == nil {
|
||||
w = io.Discard
|
||||
}
|
||||
OUTER:
|
||||
for {
|
||||
select {
|
||||
case <-memoryObject.ctx.Done():
|
||||
break OUTER
|
||||
default:
|
||||
}
|
||||
|
||||
newOffset := memoryObject.Offset
|
||||
if newOffset == offset {
|
||||
time.Sleep(time.Millisecond)
|
||||
continue
|
||||
}
|
||||
bytes := memoryObject.Buffer.Bytes()[offset:newOffset]
|
||||
written, err := w.Write(bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset += written
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
_, err := w.Write(memoryObject.Buffer.Bytes()[offset:])
|
||||
return err
|
||||
fileWrittenCh chan struct{} // Closed when the file is fully written and renamed.
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
@ -105,11 +77,6 @@ func NewServer(config Config) *Server {
|
||||
}
|
||||
}
|
||||
|
||||
type Chunk struct {
|
||||
buffer []byte
|
||||
error error
|
||||
}
|
||||
|
||||
func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) {
|
||||
if location := r.Header.Get(server.Storage.Local.Accel.EnableByHeader); server.Storage.Local.Accel.EnableByHeader != "" && location != "" {
|
||||
relPath, err := filepath.Rel(server.Storage.Local.Path, path)
|
||||
@ -147,26 +114,29 @@ func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Requ
|
||||
slog.With("status", localStatus, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked")
|
||||
if os.IsPermission(err) {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
} else if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
} else if localStatus != localNotExists {
|
||||
if localStatus == localExistsButNeedHead {
|
||||
if ranged {
|
||||
<-server.streamOnline(nil, r, mtime, fullpath)
|
||||
server.serveFile(w, r, fullpath)
|
||||
} else {
|
||||
server.streamOnline(w, r, mtime, fullpath)
|
||||
}
|
||||
} else {
|
||||
server.serveFile(w, r, fullpath)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if localStatus == localExists {
|
||||
server.serveFile(w, r, fullpath)
|
||||
return
|
||||
}
|
||||
|
||||
// localExistsButNeedHead or localNotExists
|
||||
// Both need to go online.
|
||||
if ranged {
|
||||
server.serveRangedRequest(w, r, fullpath, mtime)
|
||||
} else {
|
||||
if ranged {
|
||||
<-server.streamOnline(nil, r, mtime, fullpath)
|
||||
server.serveFile(w, r, fullpath)
|
||||
} else {
|
||||
server.streamOnline(w, r, mtime, fullpath)
|
||||
// For full requests, we wait for the download to complete and then serve the file.
|
||||
// This maintains the original behavior for now.
|
||||
ch := server.startOrJoinStream(r, mtime, fullpath)
|
||||
if ch != nil {
|
||||
<-ch
|
||||
}
|
||||
server.serveFile(w, r, fullpath)
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,240 +184,348 @@ func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key str
|
||||
return localNotExists, zeroTime, nil
|
||||
}
|
||||
|
||||
func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) <-chan struct{} {
|
||||
func (server *Server) serveRangedRequest(w http.ResponseWriter, r *http.Request, key string, mtime time.Time) {
|
||||
server.lu.Lock()
|
||||
memoryObject, exists := server.o[r.URL.Path]
|
||||
locked := false
|
||||
defer func() {
|
||||
if locked {
|
||||
server.lu.Unlock()
|
||||
locked = false
|
||||
}
|
||||
}()
|
||||
if !exists {
|
||||
server.lu.Lock()
|
||||
locked = true
|
||||
|
||||
memoryObject, exists = server.o[r.URL.Path]
|
||||
}
|
||||
if exists {
|
||||
if locked {
|
||||
server.lu.Unlock()
|
||||
locked = false
|
||||
}
|
||||
|
||||
if w != nil {
|
||||
memoryObject.wg.Add(1)
|
||||
for k := range memoryObject.Headers {
|
||||
v := memoryObject.Headers.Get(k)
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil {
|
||||
slog.With("error", err).Warn("failed to stream response with existing memory object")
|
||||
}
|
||||
}
|
||||
return preclosedChan
|
||||
} else {
|
||||
slog.With("mtime", mtime).Debug("checking fastest upstream")
|
||||
selectedIdx, response, chunks, err := server.fastestUpstream(r, mtime)
|
||||
if chunks == nil && mtime != zeroTime {
|
||||
slog.With("upstreamIdx", selectedIdx, "key", key).Debug("not modified. using local version")
|
||||
if w != nil {
|
||||
server.serveFile(w, r, key)
|
||||
}
|
||||
return preclosedChan
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
slog.With("error", err).Warn("failed to select fastest upstream")
|
||||
if w != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return preclosedChan
|
||||
}
|
||||
if selectedIdx == -1 || response == nil || chunks == nil {
|
||||
slog.Debug("no upstream is selected")
|
||||
if w != nil {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
return preclosedChan
|
||||
}
|
||||
if response.StatusCode == http.StatusNotModified {
|
||||
slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version")
|
||||
os.Chtimes(key, zeroTime, time.Now())
|
||||
if w != nil {
|
||||
server.serveFile(w, r, key)
|
||||
}
|
||||
return preclosedChan
|
||||
}
|
||||
|
||||
slog.With(
|
||||
"upstreamIdx", selectedIdx,
|
||||
).Debug("found fastest upstream")
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
memoryObject = &StreamObject{
|
||||
Headers: response.Header,
|
||||
Buffer: buffer,
|
||||
|
||||
ctx: ctx,
|
||||
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
|
||||
server.o[r.URL.Path] = memoryObject
|
||||
server.lu.Unlock()
|
||||
locked = false
|
||||
// This is the first request for this file, so we start the download.
|
||||
server.startOrJoinStream(r, mtime, key)
|
||||
|
||||
err = nil
|
||||
|
||||
if w != nil {
|
||||
memoryObject.wg.Add(1)
|
||||
|
||||
for k := range memoryObject.Headers {
|
||||
v := memoryObject.Headers.Get(k)
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
go memoryObject.StreamTo(w, memoryObject.wg)
|
||||
// Re-acquire the lock to get the newly created stream object.
|
||||
server.lu.Lock()
|
||||
memoryObject = server.o[r.URL.Path]
|
||||
if memoryObject == nil {
|
||||
server.lu.Unlock()
|
||||
// This can happen if the upstream fails very quickly.
|
||||
http.Error(w, "Failed to start download stream", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
server.lu.Unlock()
|
||||
|
||||
for chunk := range chunks {
|
||||
if chunk.error != nil {
|
||||
err = chunk.error
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
slog.With("error", err).Warn("failed to read from upstream")
|
||||
}
|
||||
}
|
||||
if chunk.buffer == nil {
|
||||
break
|
||||
}
|
||||
n, _ := buffer.Write(chunk.buffer)
|
||||
memoryObject.Offset += n
|
||||
}
|
||||
cancel()
|
||||
// At this point, we have a memoryObject, and a producer goroutine is downloading the file.
|
||||
// Now we implement the consumer logic.
|
||||
|
||||
memoryObject.wg.Wait()
|
||||
// Duplicate the file descriptor from the producer's temp file. This creates a new,
|
||||
// independent file descriptor that points to the same underlying file description.
|
||||
// This is more efficient and robust than opening the file by path.
|
||||
fd, err := syscall.Dup(int(memoryObject.TempFile.Fd()))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to duplicate file descriptor", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// We create a new *os.File from the duplicated descriptor. The consumer is now
|
||||
// responsible for closing this new file.
|
||||
consumerFile := os.NewFile(uintptr(fd), memoryObject.TempFile.Name())
|
||||
if consumerFile == nil {
|
||||
syscall.Close(fd) // Clean up if NewFile fails
|
||||
http.Error(w, "Failed to create file from duplicated descriptor", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer consumerFile.Close()
|
||||
|
||||
if response.ContentLength > 0 {
|
||||
if memoryObject.Offset == int(response.ContentLength) && err != nil {
|
||||
if !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
|
||||
slog.With("read-length", memoryObject.Offset, "content-length", response.ContentLength, "error", err, "upstreamIdx", selectedIdx).Debug("something happened during download. but response body is read as whole. so error is reset to nil")
|
||||
}
|
||||
err = nil
|
||||
}
|
||||
} else if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
rangeHeader := r.Header.Get("Range")
|
||||
if rangeHeader == "" {
|
||||
// This should not happen if called from HandleRequestWithCache, but as a safeguard:
|
||||
http.Error(w, "Range header is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the Range header. We only support a single range like "bytes=start-end".
|
||||
var start, end int64
|
||||
parts := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
|
||||
if len(parts) != 2 {
|
||||
http.Error(w, "Invalid Range header", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
start, err = strconv.ParseInt(parts[0], 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid Range header", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if parts[1] != "" {
|
||||
end, err = strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
logger := slog.With("upstreamIdx", selectedIdx)
|
||||
logger.Error("something happened during download. will not cache this response. setting lingering to reset the connection.")
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
logger.Warn("response writer is not a hijacker. failed to set lingering")
|
||||
return preclosedChan
|
||||
}
|
||||
conn, _, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
logger.With("error", err).Warn("hijack failed. failed to set lingering")
|
||||
return preclosedChan
|
||||
}
|
||||
defer conn.Close()
|
||||
http.Error(w, "Invalid Range header", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// An empty end means "to the end of the file". We don't know the full size yet.
|
||||
// We'll have to handle this dynamically.
|
||||
end = -1 // Sentinel value
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
logger.With("error", err).Warn("connection is not a *net.TCPConn. failed to set lingering")
|
||||
return preclosedChan
|
||||
}
|
||||
if err := tcpConn.SetLinger(0); err != nil {
|
||||
logger.With("error", err).Warn("failed to set lingering")
|
||||
return preclosedChan
|
||||
}
|
||||
logger.Debug("connection set to linger. it will be reset once the conn.Close is called")
|
||||
memoryObject.mu.Lock()
|
||||
defer memoryObject.mu.Unlock()
|
||||
|
||||
// Wait until we have the headers.
|
||||
for memoryObject.Headers == nil && !memoryObject.Done {
|
||||
memoryObject.cond.Wait()
|
||||
}
|
||||
|
||||
if memoryObject.Error != nil {
|
||||
http.Error(w, memoryObject.Error.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
contentLengthStr := memoryObject.Headers.Get("Content-Length")
|
||||
totalSize, _ := strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
if end == -1 || end >= totalSize {
|
||||
end = totalSize - 1
|
||||
}
|
||||
|
||||
if start >= totalSize || start > end {
|
||||
http.Error(w, "Range not satisfiable", http.StatusRequestedRangeNotSatisfiable)
|
||||
return
|
||||
}
|
||||
|
||||
var bytesSent int64
|
||||
bytesToSend := end - start + 1
|
||||
headersWritten := false
|
||||
|
||||
for bytesSent < bytesToSend && memoryObject.Error == nil {
|
||||
// Calculate what we need.
|
||||
neededStart := start + bytesSent
|
||||
|
||||
// Wait for the data to be available.
|
||||
for memoryObject.Offset <= neededStart && !memoryObject.Done {
|
||||
memoryObject.cond.Wait()
|
||||
}
|
||||
|
||||
fileWrittenCh := make(chan struct{})
|
||||
go func() {
|
||||
defer func() {
|
||||
server.lu.Lock()
|
||||
defer server.lu.Unlock()
|
||||
|
||||
delete(server.o, r.URL.Path)
|
||||
slog.Debug("memory object released")
|
||||
close(fileWrittenCh)
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
slog.Debug("preparing to release memory object")
|
||||
mtime := zeroTime
|
||||
lastModifiedHeader := response.Header.Get("Last-Modified")
|
||||
if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil {
|
||||
slog.With(
|
||||
"error", err,
|
||||
"value", lastModifiedHeader,
|
||||
"url", response.Request.URL,
|
||||
).Debug("failed to parse last modified header value. set modified time to now")
|
||||
} else {
|
||||
slog.With(
|
||||
"header", lastModifiedHeader,
|
||||
"value", lastModified,
|
||||
"url", response.Request.URL,
|
||||
).Debug("found modified time")
|
||||
mtime = lastModified
|
||||
}
|
||||
if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil {
|
||||
slog.With("error", err).Warn("failed to create local storage path")
|
||||
}
|
||||
|
||||
if server.Config.Storage.Local.TemporaryFilePattern == "" {
|
||||
if err := os.WriteFile(key, buffer.Bytes(), 0644); err != nil {
|
||||
slog.With("error", err).Warn("failed to write file")
|
||||
os.Remove(key)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
fp, err := os.CreateTemp(server.Storage.Local.Path, server.Storage.Local.TemporaryFilePattern)
|
||||
if err != nil {
|
||||
slog.With(
|
||||
"key", key,
|
||||
"path", server.Storage.Local.Path,
|
||||
"pattern", server.Storage.Local.TemporaryFilePattern,
|
||||
"error", err,
|
||||
).Warn("failed to create template file")
|
||||
return
|
||||
}
|
||||
|
||||
name := fp.Name()
|
||||
|
||||
if _, err := fp.Write(buffer.Bytes()); err != nil {
|
||||
fp.Close()
|
||||
os.Remove(name)
|
||||
|
||||
slog.With("error", err).Warn("failed to write into template file")
|
||||
} else if err := fp.Close(); err != nil {
|
||||
os.Remove(name)
|
||||
|
||||
slog.With("error", err).Warn("failed to close template file")
|
||||
} else {
|
||||
os.Chtimes(name, zeroTime, mtime)
|
||||
dirname := filepath.Dir(key)
|
||||
os.MkdirAll(dirname, 0755)
|
||||
os.Remove(key)
|
||||
os.Rename(name, key)
|
||||
}
|
||||
// Check for error AFTER waiting. This is the critical fix.
|
||||
if memoryObject.Error != nil {
|
||||
// If headers haven't been written, we can send a 500 error.
|
||||
// If they have, it's too late, the connection will just be closed.
|
||||
if !headersWritten {
|
||||
http.Error(w, memoryObject.Error.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
return // Use return instead of break to exit immediately.
|
||||
}
|
||||
|
||||
return fileWrittenCh
|
||||
// If we are here, we have some data to send. Write headers if we haven't already.
|
||||
if !headersWritten {
|
||||
w.Header().Set("Content-Range", "bytes "+strconv.FormatInt(start, 10)+"-"+strconv.FormatInt(end, 10)+"/"+contentLengthStr)
|
||||
w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10))
|
||||
w.Header().Set("Accept-Ranges", "bytes")
|
||||
w.WriteHeader(http.StatusPartialContent)
|
||||
headersWritten = true
|
||||
}
|
||||
|
||||
// Data is available, read from the temporary file.
|
||||
// We calculate how much we can read in this iteration.
|
||||
readNow := memoryObject.Offset - neededStart
|
||||
if readNow <= 0 {
|
||||
// This can happen if we woke up but the data isn't what we need.
|
||||
// The loop will continue to wait.
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't read more than the client requested in total.
|
||||
remainingToSend := bytesToSend - bytesSent
|
||||
if readNow > remainingToSend {
|
||||
readNow = remainingToSend
|
||||
}
|
||||
|
||||
// Read the chunk from the file at the correct offset.
|
||||
buffer := make([]byte, readNow)
|
||||
bytesRead, err := consumerFile.ReadAt(buffer, neededStart)
|
||||
if err != nil && err != io.EOF {
|
||||
if !headersWritten {
|
||||
http.Error(w, "Error reading from cache stream", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if bytesRead > 0 {
|
||||
n, err := w.Write(buffer[:bytesRead])
|
||||
if err != nil {
|
||||
// Client closed connection, just return.
|
||||
return
|
||||
}
|
||||
bytesSent += int64(n)
|
||||
}
|
||||
|
||||
if memoryObject.Done && bytesSent >= bytesToSend {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) {
|
||||
// startOrJoinStream ensures a download stream is active for the given key.
|
||||
// If a stream already exists, it returns the channel that signals completion.
|
||||
// If not, it starts a new download producer and returns its completion channel.
|
||||
func (server *Server) startOrJoinStream(r *http.Request, mtime time.Time, key string) <-chan struct{} {
|
||||
server.lu.Lock()
|
||||
memoryObject, exists := server.o[r.URL.Path]
|
||||
if exists {
|
||||
// A download is already in progress. Return its completion channel.
|
||||
server.lu.Unlock()
|
||||
return memoryObject.fileWrittenCh
|
||||
}
|
||||
|
||||
// No active stream, create a new one.
|
||||
if err := os.MkdirAll(filepath.Dir(key), 0755); err != nil {
|
||||
slog.With("error", err).Warn("failed to create local storage path for temp file")
|
||||
server.lu.Unlock()
|
||||
return preclosedChan // Return a closed channel to prevent blocking
|
||||
}
|
||||
|
||||
tempFilePattern := server.Storage.Local.TemporaryFilePattern
|
||||
if tempFilePattern == "" {
|
||||
tempFilePattern = "cache-proxy-*"
|
||||
}
|
||||
tempFile, err := os.CreateTemp(filepath.Dir(key), tempFilePattern)
|
||||
if err != nil {
|
||||
slog.With("error", err).Warn("failed to create temporary file")
|
||||
server.lu.Unlock()
|
||||
return preclosedChan
|
||||
}
|
||||
|
||||
mu := &sync.Mutex{}
|
||||
fileWrittenCh := make(chan struct{})
|
||||
memoryObject = &StreamObject{
|
||||
TempFile: tempFile,
|
||||
mu: mu,
|
||||
cond: sync.NewCond(mu),
|
||||
fileWrittenCh: fileWrittenCh,
|
||||
}
|
||||
server.o[r.URL.Path] = memoryObject
|
||||
server.lu.Unlock()
|
||||
|
||||
// This is the producer goroutine
|
||||
go func(mo *StreamObject) {
|
||||
var err error
|
||||
downloadSucceeded := false
|
||||
tempFileName := mo.TempFile.Name()
|
||||
|
||||
defer func() {
|
||||
// On completion (or error), update the stream object,
|
||||
// wake up all consumers, and then clean up.
|
||||
mo.mu.Lock()
|
||||
mo.Done = true
|
||||
mo.Error = err
|
||||
mo.cond.Broadcast()
|
||||
mo.mu.Unlock()
|
||||
|
||||
// Close the temp file handle.
|
||||
mo.TempFile.Close()
|
||||
|
||||
// If download failed, remove the temp file.
|
||||
if !downloadSucceeded {
|
||||
os.Remove(tempFileName)
|
||||
}
|
||||
|
||||
// The producer's job is done. Remove the object from the central map.
|
||||
// Any existing consumers still hold a reference to the object,
|
||||
// so it won't be garbage collected until they are done.
|
||||
server.lu.Lock()
|
||||
delete(server.o, r.URL.Path)
|
||||
server.lu.Unlock()
|
||||
slog.Debug("memory object released by producer")
|
||||
close(mo.fileWrittenCh)
|
||||
}()
|
||||
|
||||
slog.With("mtime", mtime).Debug("checking fastest upstream")
|
||||
selectedIdx, response, firstChunk, upstreamErr := server.fastestUpstream(r, mtime)
|
||||
if upstreamErr != nil {
|
||||
if !errors.Is(upstreamErr, io.EOF) && !errors.Is(upstreamErr, io.ErrUnexpectedEOF) {
|
||||
slog.With("error", upstreamErr).Warn("failed to select fastest upstream")
|
||||
err = upstreamErr
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if selectedIdx == -1 || response == nil {
|
||||
slog.Debug("no upstream is selected")
|
||||
err = errors.New("no suitable upstream found")
|
||||
return
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusNotModified {
|
||||
slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version")
|
||||
os.Chtimes(key, zeroTime, time.Now())
|
||||
// In this case, we don't have a new file, so we just exit.
|
||||
// The temp file will be cleaned up by the defer.
|
||||
return
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
slog.With("upstreamIdx", selectedIdx).Debug("found fastest upstream")
|
||||
|
||||
mo.mu.Lock()
|
||||
mo.Headers = response.Header
|
||||
mo.cond.Broadcast() // Broadcast headers availability
|
||||
mo.mu.Unlock()
|
||||
|
||||
// Write the first chunk that we already downloaded.
|
||||
if len(firstChunk) > 0 {
|
||||
written, writeErr := mo.TempFile.Write(firstChunk)
|
||||
if writeErr != nil {
|
||||
err = writeErr
|
||||
return
|
||||
}
|
||||
mo.mu.Lock()
|
||||
mo.Offset += int64(written)
|
||||
mo.cond.Broadcast()
|
||||
mo.mu.Unlock()
|
||||
}
|
||||
|
||||
// Download the rest of the file in chunks
|
||||
buffer := make([]byte, server.Misc.ChunkBytes)
|
||||
for {
|
||||
n, readErr := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
written, writeErr := mo.TempFile.Write(buffer[:n])
|
||||
if writeErr != nil {
|
||||
err = writeErr
|
||||
break
|
||||
}
|
||||
mo.mu.Lock()
|
||||
mo.Offset += int64(written)
|
||||
mo.cond.Broadcast()
|
||||
mo.mu.Unlock()
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
err = readErr
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// After download, if no critical error, rename the temp file to its final destination.
|
||||
if err == nil {
|
||||
// Set modification time
|
||||
mtime := zeroTime
|
||||
lastModifiedHeader := response.Header.Get("Last-Modified")
|
||||
if lastModified, lmErr := time.Parse(time.RFC1123, lastModifiedHeader); lmErr == nil {
|
||||
mtime = lastModified
|
||||
}
|
||||
// Close file before Chtimes and Rename
|
||||
mo.TempFile.Close()
|
||||
|
||||
os.Chtimes(tempFileName, zeroTime, mtime)
|
||||
|
||||
// Rename the file
|
||||
if renameErr := os.Rename(tempFileName, key); renameErr != nil {
|
||||
slog.With("error", renameErr, "from", tempFileName, "to", key).Warn("failed to rename temp file")
|
||||
err = renameErr
|
||||
os.Remove(tempFileName) // Attempt to clean up if rename fails
|
||||
} else {
|
||||
downloadSucceeded = true
|
||||
}
|
||||
} else {
|
||||
logger := slog.With("upstreamIdx", selectedIdx)
|
||||
logger.Error("something happened during download. will not cache this response.", "error", err)
|
||||
}
|
||||
}(memoryObject)
|
||||
|
||||
return fileWrittenCh
|
||||
}
|
||||
|
||||
func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, firstChunk []byte, resultErr error) {
|
||||
returnLock := &sync.Mutex{}
|
||||
upstreams := len(server.Upstreams)
|
||||
cancelFuncs := make([]func(), upstreams)
|
||||
@ -458,6 +536,8 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
|
||||
resultIdx = -1
|
||||
|
||||
var resultFirstChunk []byte
|
||||
|
||||
defer close(updateCh)
|
||||
defer close(notModifiedCh)
|
||||
defer func() {
|
||||
@ -472,7 +552,8 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
groups := make(map[int][]int)
|
||||
for upstreamIdx, upstream := range server.Upstreams {
|
||||
if _, matched, err := upstream.GetPath(r.URL.Path); err != nil {
|
||||
return -1, nil, nil, err
|
||||
resultErr = err
|
||||
return
|
||||
} else if !matched {
|
||||
continue
|
||||
}
|
||||
@ -480,7 +561,8 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
priority := 0
|
||||
for _, priorityGroup := range upstream.PriorityGroups {
|
||||
if matched, err := regexp.MatchString(priorityGroup.Match, r.URL.Path); err != nil {
|
||||
return -1, nil, nil, err
|
||||
resultErr = err
|
||||
return
|
||||
} else if matched {
|
||||
priority = priorityGroup.Priority
|
||||
break
|
||||
@ -516,7 +598,7 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
response, ch, err := server.tryUpstream(ctx, idx, priority, r, lastModified)
|
||||
response, chunk, err := server.tryUpstream(ctx, idx, priority, r, lastModified)
|
||||
if err == context.Canceled { // others returned
|
||||
logger.Debug("context canceled")
|
||||
return
|
||||
@ -536,13 +618,16 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
}
|
||||
locked := returnLock.TryLock()
|
||||
if !locked {
|
||||
if response != nil {
|
||||
response.Body.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
defer returnLock.Unlock()
|
||||
|
||||
if response.StatusCode == http.StatusNotModified {
|
||||
notModifiedOnce.Do(func() {
|
||||
resultResponse, resultCh, resultErr = response, ch, err
|
||||
resultResponse, resultErr = response, err
|
||||
notModifiedCh <- idx
|
||||
})
|
||||
logger.Debug("voted not modified")
|
||||
@ -550,7 +635,7 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
}
|
||||
|
||||
updateOnce.Do(func() {
|
||||
resultResponse, resultCh, resultErr = response, ch, err
|
||||
resultResponse, resultFirstChunk, resultErr = response, chunk, err
|
||||
updateCh <- idx
|
||||
|
||||
for cancelIdx, cancel := range cancelFuncs {
|
||||
@ -570,6 +655,7 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
select {
|
||||
case idx := <-updateCh:
|
||||
resultIdx = idx
|
||||
firstChunk = resultFirstChunk
|
||||
logger.With("upstreamIdx", resultIdx).Debug("upstream selected")
|
||||
return
|
||||
default:
|
||||
@ -584,10 +670,10 @@ func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (
|
||||
}
|
||||
}
|
||||
|
||||
return -1, nil, nil, nil
|
||||
return
|
||||
}
|
||||
|
||||
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) {
|
||||
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, firstChunk []byte, err error) {
|
||||
upstream := server.Upstreams[upstreamIdx]
|
||||
|
||||
newpath, matched, err := upstream.GetPath(r.URL.Path)
|
||||
@ -633,9 +719,9 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
streaming := false
|
||||
shouldCloseBody := true
|
||||
defer func() {
|
||||
if !streaming && response != nil {
|
||||
if shouldCloseBody && response != nil {
|
||||
response.Body.Close()
|
||||
}
|
||||
}()
|
||||
@ -683,10 +769,6 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int
|
||||
}
|
||||
}
|
||||
|
||||
var currentOffset int64
|
||||
|
||||
ch := make(chan Chunk, 1024)
|
||||
|
||||
buffer := make([]byte, server.Misc.FirstChunkBytes)
|
||||
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
|
||||
|
||||
@ -694,30 +776,12 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int
|
||||
if n == 0 {
|
||||
return response, nil, err
|
||||
}
|
||||
}
|
||||
ch <- Chunk{buffer: buffer[:n]}
|
||||
|
||||
streaming = true
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer response.Body.Close()
|
||||
|
||||
for {
|
||||
buffer := make([]byte, server.Misc.ChunkBytes)
|
||||
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
|
||||
if n > 0 {
|
||||
ch <- Chunk{buffer: buffer[:n]}
|
||||
currentOffset += int64(n)
|
||||
}
|
||||
if response.ContentLength > 0 && currentOffset == response.ContentLength && err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
ch <- Chunk{error: err}
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return response, ch, nil
|
||||
shouldCloseBody = false
|
||||
return response, buffer[:n], err
|
||||
}
|
||||
|
Reference in New Issue
Block a user