package cacheproxy

import (
	"context"
	"errors"
	"io"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"
)

// reqCtxKey is the private key type for values this package stores in a
// request context.
type reqCtxKey int

const (
	// reqCtxAllowedRedirect keys the regular expression (as a string) that a
	// redirect target must match for the HTTP client to follow it.
	reqCtxAllowedRedirect reqCtxKey = iota
)

// rangeCopyBufferBytes bounds how much a ranged consumer reads from the
// temporary file per iteration, so memory use does not scale with file size.
const rangeCopyBufferBytes = 32 * 1024

// zeroTime is the zero time.Time, used as the "no timestamp" sentinel.
var zeroTime time.Time

// preclosedChan is an already-closed channel returned when a stream could not
// be started, so that callers waiting on it never block.
var preclosedChan = make(chan struct{})

func init() {
	close(preclosedChan)
}

var (
	// httpClient is the shared client for upstream requests. Redirects are
	// followed only when the request context carries an allowed-redirect
	// pattern and the target URL matches it; otherwise the redirect response
	// itself is returned to the caller.
	httpClient = http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			allowedRedirect, ok := req.Context().Value(reqCtxAllowedRedirect).(string)
			if !ok {
				return http.ErrUseLastResponse
			}
			matched, err := regexp.MatchString(allowedRedirect, req.URL.String())
			if err != nil {
				return err
			}
			if !matched {
				return http.ErrUseLastResponse
			}
			return nil
		},
	}
)

// StreamObject is the shared state of one in-flight download. A single
// producer goroutine appends to TempFile and advances Offset; consumers wait
// on cond (guarded by mu) for more data, for Done, or for Error.
type StreamObject struct {
	Headers  http.Header
	TempFile *os.File // The temporary file holding the download content.
	Offset   int64    // The number of bytes written to TempFile.
	Done     bool
	Error    error

	mu   *sync.Mutex
	cond *sync.Cond

	fileWrittenCh chan struct{} // Closed when the file is fully written and renamed.
}

// Server is the caching proxy. lu guards o, the map of in-flight downloads
// keyed by request URL path.
type Server struct {
	Config

	lu *sync.Mutex
	o  map[string]*StreamObject
}

// NewServer returns a Server using config with an empty set of in-flight
// downloads.
func NewServer(config Config) *Server {
	return &Server{
		Config: config,

		lu: &sync.Mutex{},
		o:  make(map[string]*StreamObject),
	}
}

// serveFile responds with the local file at path. When an accel header name is
// configured and the request carries a non-empty value for it, no body is
// written; instead each configured response header is set to that value joined
// with path's location relative to the local storage root.
func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) {
	if location := r.Header.Get(server.Storage.Local.Accel.EnableByHeader); server.Storage.Local.Accel.EnableByHeader != "" && location != "" {
		relPath, err := filepath.Rel(server.Storage.Local.Path, path)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		accelPath := filepath.Join(location, relPath)

		for _, headerKey := range server.Storage.Local.Accel.RespondWithHeaders {
			w.Header().Set(headerKey, accelPath)
		}

		return
	}

	http.ServeFile(w, r, path)
}

// isWithinDir reports whether target is root itself or lies beneath it. It
// compares path components, so a sibling such as "/data-evil" is not treated
// as being inside "/data" the way a plain string-prefix test would.
func isWithinDir(root, target string) bool {
	rel, err := filepath.Rel(root, target)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// HandleRequestWithCache serves r from local storage when a fresh copy exists
// and otherwise goes to the upstreams. Ranged requests are streamed while the
// download is in progress; full requests wait for the download to finish and
// are then served from the local file.
func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Request) {
	fullpath, err := filepath.Abs(filepath.Join(server.Storage.Local.Path, r.URL.Path))
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if !isWithinDir(server.Storage.Local.Path, fullpath) {
		http.Error(w, "crossing local directory boundary", http.StatusBadRequest)
		return
	}

	ranged := r.Header.Get("Range") != ""

	status, mtime, err := server.checkLocal(w, r, fullpath)
	slog.With("status", status, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked")
	if os.IsPermission(err) {
		http.Error(w, err.Error(), http.StatusForbidden)
		return
	} else if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	if status == localExists {
		server.serveFile(w, r, fullpath)
		return
	}

	// localExistsButNeedHead or localNotExists: both need to go online.
	if ranged {
		server.serveRangedRequest(w, r, fullpath, mtime)
		return
	}

	// For full requests, wait for the download to complete and then serve the
	// file from local storage.
	if ch := server.startOrJoinStream(r, mtime, fullpath); ch != nil {
		<-ch
	}
	server.serveFile(w, r, fullpath)
}

// localStatus describes what checkLocal found for a cache key.
type localStatus int

const (
	localNotExists         localStatus = iota // No local copy.
	localExists                               // Local copy exists and may be served as is.
	localExistsButNeedHead                    // Local copy exists but is due for revalidation upstream.
)

// checkLocal inspects the local copy at key and applies the refresh policy.
//
// The first cache policy whose Match pattern matches key wins; its
// RefreshAfter is either a duration or one of "always" / "never". Without a
// matching policy the global Cache.RefreshAfter applies.
//
// It never writes to w: all errors, including permission errors, are returned
// so the caller produces exactly one response. A missing file is not an error
// and yields (localNotExists, zeroTime, nil). The returned mtime is in UTC.
func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
	stat, statErr := os.Stat(key)
	if statErr != nil {
		if os.IsNotExist(statErr) {
			return localNotExists, zeroTime, nil
		}
		return localNotExists, zeroTime, statErr
	}

	refreshAfter := server.Cache.RefreshAfter
	refresh := ""
	for _, policy := range server.Cache.Policies {
		match, matchErr := regexp.MatchString(policy.Match, key)
		if matchErr != nil {
			return localNotExists, zeroTime, matchErr
		}
		if !match {
			continue
		}
		dur, parseErr := time.ParseDuration(policy.RefreshAfter)
		switch {
		case parseErr == nil:
			refreshAfter = dur
		case slices.Contains([]string{"always", "never"}, policy.RefreshAfter):
			refresh = policy.RefreshAfter
		default:
			return localNotExists, zeroTime, parseErr
		}
		break
	}

	modTime := stat.ModTime()
	slog.With("policy", refresh, "after", refreshAfter, "mtime", modTime, "key", key).Debug("refresh policy checked")

	expired := modTime.Add(refreshAfter).Before(time.Now())
	if (expired || refresh == "always") && refresh != "never" {
		return localExistsButNeedHead, modTime.In(time.UTC), nil
	}
	return localExists, modTime.In(time.UTC), nil
}

// serveRangedRequest answers a single-range request ("bytes=start-end" or
// "bytes=start-") for key while its download may still be in progress, reading
// already-downloaded bytes from the producer's temporary file.
//
// It falls back to serving the local file when the stream cannot be followed
// byte by byte: the producer finished before the stream object could be looked
// up, upstream reported 304 Not Modified, or upstream sent no usable
// Content-Length.
//
// The stream mutex is held only while waiting for and sampling producer state,
// never during file reads or client writes, so a slow client cannot stall the
// producer or other consumers.
func (server *Server) serveRangedRequest(w http.ResponseWriter, r *http.Request, key string, mtime time.Time) {
	server.lu.Lock()
	memoryObject, exists := server.o[r.URL.Path]
	server.lu.Unlock()

	if !exists {
		// This is the first request for this file, so we start the download.
		done := server.startOrJoinStream(r, mtime, key)

		server.lu.Lock()
		memoryObject = server.o[r.URL.Path]
		server.lu.Unlock()

		if memoryObject == nil {
			// Either the stream could not be started (done is pre-closed) or
			// the producer already finished and removed itself from the map
			// (done closes right after). If a local copy exists, serve it.
			<-done
			if _, statErr := os.Stat(key); statErr == nil {
				server.serveFile(w, r, key)
				return
			}
			http.Error(w, "Failed to start download stream", http.StatusInternalServerError)
			return
		}
	}

	// Duplicate the producer's file descriptor so this consumer owns an
	// independent handle to the same temporary file. Reads below use ReadAt,
	// so the offset shared with the producer's descriptor is never touched.
	fd, err := syscall.Dup(int(memoryObject.TempFile.Fd()))
	if err != nil {
		http.Error(w, "Failed to duplicate file descriptor", http.StatusInternalServerError)
		return
	}
	consumerFile := os.NewFile(uintptr(fd), memoryObject.TempFile.Name())
	if consumerFile == nil {
		// Best-effort cleanup; there is nothing useful to do if Close fails.
		syscall.Close(fd)
		http.Error(w, "Failed to create file from duplicated descriptor", http.StatusInternalServerError)
		return
	}
	defer consumerFile.Close()

	rangeHeader := r.Header.Get("Range")
	if rangeHeader == "" {
		// This should not happen if called from HandleRequestWithCache, but as a safeguard:
		http.Error(w, "Range header is required", http.StatusBadRequest)
		return
	}

	// Parse the Range header. Only a single range like "bytes=start-end" is supported.
	var start, end int64
	parts := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
	if len(parts) != 2 {
		http.Error(w, "Invalid Range header", http.StatusBadRequest)
		return
	}
	start, err = strconv.ParseInt(parts[0], 10, 64)
	if err != nil {
		http.Error(w, "Invalid Range header", http.StatusBadRequest)
		return
	}
	if parts[1] != "" {
		end, err = strconv.ParseInt(parts[1], 10, 64)
		if err != nil {
			http.Error(w, "Invalid Range header", http.StatusBadRequest)
			return
		}
	} else {
		// An empty end means "to the end of the file"; resolved once the
		// total size is known.
		end = -1
	}

	// Wait until the producer has published the upstream headers or finished.
	memoryObject.mu.Lock()
	for memoryObject.Headers == nil && !memoryObject.Done {
		memoryObject.cond.Wait()
	}
	headers, streamErr := memoryObject.Headers, memoryObject.Error
	memoryObject.mu.Unlock()

	if streamErr != nil {
		http.Error(w, streamErr.Error(), http.StatusInternalServerError)
		return
	}
	if headers == nil {
		// The producer finished without an error and without headers: that is
		// its 304 Not Modified path, so the local copy is the current version.
		server.serveFile(w, r, key)
		return
	}

	contentLengthStr := headers.Get("Content-Length")
	totalSize, parseErr := strconv.ParseInt(contentLengthStr, 10, 64)
	if parseErr != nil || totalSize < 0 {
		// Without a usable Content-Length the range cannot be resolved while
		// streaming. Wait for the download and serve the finished file.
		<-memoryObject.fileWrittenCh
		server.serveFile(w, r, key)
		return
	}

	if end == -1 || end >= totalSize {
		end = totalSize - 1
	}
	if start >= totalSize || start > end {
		http.Error(w, "Range not satisfiable", http.StatusRequestedRangeNotSatisfiable)
		return
	}

	bytesToSend := end - start + 1
	var bytesSent int64
	headersWritten := false
	buffer := make([]byte, rangeCopyBufferBytes)

	for bytesSent < bytesToSend {
		neededStart := start + bytesSent

		// Wait for the producer to get past neededStart, or to finish.
		memoryObject.mu.Lock()
		for memoryObject.Offset <= neededStart && !memoryObject.Done {
			memoryObject.cond.Wait()
		}
		available := memoryObject.Offset - neededStart
		streamErr = memoryObject.Error
		memoryObject.mu.Unlock()

		if streamErr != nil {
			// Before the headers are written a 500 can still be sent; after
			// that, returning early leaves the response short of its declared
			// Content-Length.
			if !headersWritten {
				http.Error(w, streamErr.Error(), http.StatusInternalServerError)
			}
			return
		}
		if available <= 0 {
			// The stream is done without error but holds less data than the
			// range needs. No progress is possible, so stop instead of spinning.
			if !headersWritten {
				http.Error(w, "Error reading from cache stream", http.StatusInternalServerError)
			}
			return
		}

		// Never read past what the client asked for or what the buffer holds.
		readNow := min(available, bytesToSend-bytesSent, int64(len(buffer)))
		bytesRead, readErr := consumerFile.ReadAt(buffer[:readNow], neededStart)
		if (readErr != nil && readErr != io.EOF) || bytesRead == 0 {
			if !headersWritten {
				http.Error(w, "Error reading from cache stream", http.StatusInternalServerError)
			}
			return
		}

		if !headersWritten {
			w.Header().Set("Content-Range", "bytes "+strconv.FormatInt(start, 10)+"-"+strconv.FormatInt(end, 10)+"/"+contentLengthStr)
			w.Header().Set("Content-Length", strconv.FormatInt(bytesToSend, 10))
			w.Header().Set("Accept-Ranges", "bytes")
			w.WriteHeader(http.StatusPartialContent)
			headersWritten = true
		}

		n, writeErr := w.Write(buffer[:bytesRead])
		if writeErr != nil {
			// Client closed the connection; nothing more to do.
			return
		}
		bytesSent += int64(n)
	}
}

// startOrJoinStream ensures a download stream is active for the given key.
// If a stream already exists for the request path, it returns that stream's
// completion channel. Otherwise it registers a new StreamObject, starts a
// producer goroutine, and returns the new completion channel. When the
// temporary file cannot be set up, it returns an already-closed channel so
// callers do not block.
func (server *Server) startOrJoinStream(r *http.Request, mtime time.Time, key string) <-chan struct{} {
	server.lu.Lock()
	if existing, exists := server.o[r.URL.Path]; exists {
		// A download is already in progress. Return its completion channel.
		server.lu.Unlock()
		return existing.fileWrittenCh
	}

	// No active stream, create a new one.
	if err := os.MkdirAll(filepath.Dir(key), 0755); err != nil {
		slog.With("error", err).Warn("failed to create local storage path for temp file")
		server.lu.Unlock()
		return preclosedChan
	}

	tempFilePattern := server.Storage.Local.TemporaryFilePattern
	if tempFilePattern == "" {
		tempFilePattern = "cache-proxy-*"
	}
	tempFile, err := os.CreateTemp(filepath.Dir(key), tempFilePattern)
	if err != nil {
		slog.With("error", err).Warn("failed to create temporary file")
		server.lu.Unlock()
		return preclosedChan
	}

	mu := &sync.Mutex{}
	fileWrittenCh := make(chan struct{})
	memoryObject := &StreamObject{
		TempFile:      tempFile,
		mu:            mu,
		cond:          sync.NewCond(mu),
		fileWrittenCh: fileWrittenCh,
	}
	server.o[r.URL.Path] = memoryObject
	server.lu.Unlock()

	// This is the producer goroutine.
	go func(mo *StreamObject) {
		var err error
		downloadSucceeded := false
		tempFileName := mo.TempFile.Name()

		defer func() {
			// On completion (or error), publish the final state, wake up all
			// consumers, and then clean up.
			mo.mu.Lock()
			mo.Done = true
			mo.Error = err
			mo.cond.Broadcast()
			mo.mu.Unlock()

			// Close the temp file handle. On the success path it was already
			// closed before the rename; the second Close only returns an
			// error, which is deliberately ignored.
			mo.TempFile.Close()

			// If the download failed, remove the temp file.
			if !downloadSucceeded {
				os.Remove(tempFileName)
			}

			// The producer's job is done. Remove the object from the central
			// map. Consumers that already hold a reference keep using it.
			server.lu.Lock()
			delete(server.o, r.URL.Path)
			server.lu.Unlock()
			slog.Debug("memory object released by producer")

			close(mo.fileWrittenCh)
		}()

		slog.With("mtime", mtime).Debug("checking fastest upstream")
		selectedIdx, response, firstChunk, upstreamErr := server.fastestUpstream(r, mtime)
		if upstreamErr != nil {
			if !errors.Is(upstreamErr, io.EOF) && !errors.Is(upstreamErr, io.ErrUnexpectedEOF) {
				slog.With("error", upstreamErr).Warn("failed to select fastest upstream")
				err = upstreamErr
				return
			}
		}

		if selectedIdx == -1 || response == nil {
			slog.Debug("no upstream is selected")
			err = errors.New("no suitable upstream found")
			return
		}

		if response.StatusCode == http.StatusNotModified {
			slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version")
			// Bump the local file's mtime so the refresh interval restarts.
			if chErr := os.Chtimes(key, zeroTime, time.Now()); chErr != nil {
				slog.With("error", chErr, "key", key).Warn("failed to update modification time")
			}
			// There is no new file; the temp file is removed by the defer.
			return
		}
		defer response.Body.Close()

		slog.With("upstreamIdx", selectedIdx).Debug("found fastest upstream")

		mo.mu.Lock()
		mo.Headers = response.Header
		mo.cond.Broadcast() // Broadcast headers availability
		mo.mu.Unlock()

		// Write the first chunk that was already downloaded during selection.
		if len(firstChunk) > 0 {
			written, writeErr := mo.TempFile.Write(firstChunk)
			if writeErr != nil {
				err = writeErr
				return
			}
			mo.mu.Lock()
			mo.Offset += int64(written)
			mo.cond.Broadcast()
			mo.mu.Unlock()
		}

		// Download the rest of the file in chunks. A zero-length buffer would
		// make Read return (0, nil) forever, so fall back to a sane size when
		// the configured chunk size is zero.
		buffer := make([]byte, server.Misc.ChunkBytes)
		if len(buffer) == 0 {
			buffer = make([]byte, rangeCopyBufferBytes)
		}
		for {
			n, readErr := response.Body.Read(buffer)
			if n > 0 {
				written, writeErr := mo.TempFile.Write(buffer[:n])
				if writeErr != nil {
					err = writeErr
					break
				}
				mo.mu.Lock()
				mo.Offset += int64(written)
				mo.cond.Broadcast()
				mo.mu.Unlock()
			}
			if readErr != nil {
				if readErr != io.EOF {
					err = readErr
				}
				break
			}
		}

		if err != nil {
			logger := slog.With("upstreamIdx", selectedIdx)
			logger.Error("something happened during download. will not cache this response.", "error", err)
			return
		}

		// The download succeeded: stamp the upstream modification time and
		// move the temp file to its final destination. A zero time leaves the
		// file's modification time unchanged.
		modTime := zeroTime
		if lastModified, lmErr := http.ParseTime(response.Header.Get("Last-Modified")); lmErr == nil {
			modTime = lastModified
		}

		// Close the file before Chtimes and Rename.
		mo.TempFile.Close()
		if chErr := os.Chtimes(tempFileName, zeroTime, modTime); chErr != nil {
			slog.With("error", chErr, "file", tempFileName).Warn("failed to set modification time")
		}

		if renameErr := os.Rename(tempFileName, key); renameErr != nil {
			slog.With("error", renameErr, "from", tempFileName, "to", key).Warn("failed to rename temp file")
			err = renameErr
			os.Remove(tempFileName) // Attempt to clean up if rename fails
			return
		}
		downloadSucceeded = true
	}(memoryObject)

	return fileWrittenCh
}

// fastestUpstream races the upstreams that match the request path and returns
// the index of the chosen upstream, its response, and the first chunk of the
// body that was read during selection.
//
// Upstreams are grouped by priority (highest first); a lower-priority group is
// tried only if the previous group produced neither a 200 nor a 304. Within a
// group, a 200 response is preferred over a 304 response. resultIdx is -1 when
// nothing was selected. For a selected 200 response the caller owns the body
// and must close it.
func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, firstChunk []byte, resultErr error) {
	returnLock := &sync.Mutex{}
	cancelFuncs := make([]func(), len(server.Upstreams))

	updateCh := make(chan int, 1)
	updateOnce := &sync.Once{}
	notModifiedCh := make(chan int, 1)
	notModifiedOnce := &sync.Once{}

	// The two kinds of vote are kept apart so that a late 304 vote can never
	// replace the response of an update that has already won.
	var (
		updateResponse      *http.Response
		updateFirstChunk    []byte
		notModifiedResponse *http.Response
	)

	resultIdx = -1

	defer close(updateCh)
	defer close(notModifiedCh)
	defer func() {
		// Cancel every request except the selected one, whose body the caller
		// still has to read.
		for cancelIdx, cancel := range cancelFuncs {
			if cancelIdx == resultIdx || cancel == nil {
				continue
			}
			cancel()
		}
	}()

	groups := make(map[int][]int)
	for upstreamIdx, upstream := range server.Upstreams {
		if _, matched, err := upstream.GetPath(r.URL.Path); err != nil {
			return -1, nil, nil, err
		} else if !matched {
			continue
		}

		priority := 0
		for _, priorityGroup := range upstream.PriorityGroups {
			if matched, err := regexp.MatchString(priorityGroup.Match, r.URL.Path); err != nil {
				return -1, nil, nil, err
			} else if matched {
				priority = priorityGroup.Priority
				break
			}
		}
		groups[priority] = append(groups[priority], upstreamIdx)
	}

	priorities := make([]int, 0, len(groups))
	for priority := range groups {
		priorities = append(priorities, priority)
	}
	slices.Sort(priorities)
	slices.Reverse(priorities)

	for _, priority := range priorities {
		members := groups[priority]

		logger := slog.With()
		if priority != 0 {
			logger = logger.With("priority", priority)
		}

		// Create every context of the group before starting any goroutine, so
		// cancelFuncs is not written while a running goroutine iterates it.
		ctxs := make([]context.Context, len(members))
		for i, idx := range members {
			ctx, cancel := context.WithCancel(context.Background())
			ctxs[i] = ctx
			cancelFuncs[idx] = cancel
		}

		wg := &sync.WaitGroup{}
		wg.Add(len(members))
		for i, idx := range members {
			idx := idx
			ctx := ctxs[i]
			logger := logger.With("upstreamIdx", idx)

			go func() {
				defer wg.Done()

				response, chunk, err := server.tryUpstream(ctx, idx, priority, r, lastModified)
				if err == context.Canceled { // others returned
					logger.Debug("context canceled")
					return
				}
				if err != nil {
					if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
						logger.With("error", err).Warn("upstream has error")
					}
					return
				}
				if response == nil {
					return
				}
				if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusNotModified {
					response.Body.Close()
					return
				}

				if !returnLock.TryLock() {
					// Another upstream is voting right now; drop this response.
					response.Body.Close()
					return
				}
				defer returnLock.Unlock()

				if response.StatusCode == http.StatusNotModified {
					notModifiedOnce.Do(func() {
						notModifiedResponse = response
						notModifiedCh <- idx
					})
					logger.Debug("voted not modified")
					return
				}

				won := false
				updateOnce.Do(func() {
					won = true
					updateResponse, updateFirstChunk = response, chunk
					updateCh <- idx

					for cancelIdx, cancel := range cancelFuncs {
						if cancelIdx == idx || cancel == nil {
							continue
						}
						cancel()
					}
				})
				if !won {
					// An earlier upstream already won; release this body.
					response.Body.Close()
					return
				}
				logger.Debug("voted update")
			}()
		}

		wg.Wait()

		// All goroutines of the group have finished, so the channels can be
		// polled without blocking. An update wins over "not modified".
		select {
		case idx := <-updateCh:
			resultIdx, resultResponse, firstChunk = idx, updateResponse, updateFirstChunk
			logger.With("upstreamIdx", resultIdx).Debug("upstream selected")
			return resultIdx, resultResponse, firstChunk, resultErr
		default:
		}

		select {
		case idx := <-notModifiedCh:
			resultIdx, resultResponse = idx, notModifiedResponse
			logger.With("upstreamIdx", resultIdx).Debug("all upstream not modified")
			return resultIdx, resultResponse, firstChunk, resultErr
		default:
			logger.Debug("no valid upstream found")
		}
	}

	return resultIdx, resultResponse, firstChunk, resultErr
}

// tryUpstream sends the request to a single upstream and validates the
// response against that upstream's checkers.
//
// Results:
//   - (nil, nil, nil) when the upstream does not match the path or the
//     response fails a checker;
//   - (response, nil, nil) for 304 Not Modified, with the body already closed;
//   - (response, firstChunk, nil) for an accepted response, where firstChunk
//     holds up to Misc.FirstChunkBytes of the body and the caller owns the
//     still-open body;
//   - (nil, nil, err) on any error, with the body closed.
//
// When lastModified is non-zero the request is always a GET carrying
// If-Modified-Since.
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, firstChunk []byte, err error) {
	upstream := server.Upstreams[upstreamIdx]

	newpath, matched, err := upstream.GetPath(r.URL.Path)
	if err != nil {
		return nil, nil, err
	}
	if !matched {
		return nil, nil, nil
	}

	logger := slog.With("upstreamIdx", upstreamIdx, "server", upstream.Server, "path", newpath)
	if priority != 0 {
		logger = logger.With("priority", priority)
	}

	logger.With(
		"matched", matched,
	).Debug("trying upstream")

	newurl := upstream.Server + newpath
	method := r.Method
	if !lastModified.IsZero() {
		method = http.MethodGet
	}
	if upstream.AllowedRedirect != nil {
		ctx = context.WithValue(ctx, reqCtxAllowedRedirect, *upstream.AllowedRedirect)
	}
	request, err := http.NewRequestWithContext(ctx, method, newurl, nil)
	if err != nil {
		return nil, nil, err
	}
	if !lastModified.IsZero() {
		// HTTP dates must be expressed in GMT; time.RFC1123 on a UTC time
		// would emit "UTC", which servers may ignore.
		request.Header.Set("If-Modified-Since", lastModified.UTC().Format(http.TimeFormat))
	}

	for _, k := range []string{"User-Agent", "Accept"} {
		if _, exists := r.Header[k]; exists {
			request.Header.Set(k, r.Header.Get(k))
		}
	}

	resp, err := httpClient.Do(request)
	if err != nil {
		return nil, nil, err
	}

	// The deferred close deliberately uses the local resp rather than the
	// named result: a "return nil, ..." sets the named result to nil before
	// deferred functions run, which would skip the close and leak the body.
	keepBodyOpen := false
	defer func() {
		if !keepBodyOpen {
			resp.Body.Close()
		}
	}()

	if resp.StatusCode == http.StatusNotModified {
		return resp, nil, nil
	}

	responseCheckers := upstream.Checkers
	if len(responseCheckers) == 0 {
		responseCheckers = append(responseCheckers, Checker{})
	}

	for _, checker := range responseCheckers {
		if len(checker.StatusCodes) == 0 {
			checker.StatusCodes = append(checker.StatusCodes, http.StatusOK)
		}

		if !slices.Contains(checker.StatusCodes, resp.StatusCode) {
			return nil, nil, nil
		}

		for _, headerChecker := range checker.Headers {
			if headerChecker.Match == nil {
				// check header exists
				if _, ok := resp.Header[headerChecker.Name]; !ok {
					logger.Debug("missing header", "header", headerChecker.Name)
					return nil, nil, nil
				}
				continue
			}

			// check header match
			value := resp.Header.Get(headerChecker.Name)
			headerMatched, matchErr := regexp.MatchString(*headerChecker.Match, value)
			if matchErr != nil {
				return nil, nil, matchErr
			}
			if !headerMatched {
				logger.Debug("invalid header value",
					"header", headerChecker.Name,
					"value", value,
					"matcher", *headerChecker.Match,
				)
				return nil, nil, nil
			}
		}
	}

	buffer := make([]byte, server.Misc.FirstChunkBytes)
	n, err := io.ReadAtLeast(resp.Body, buffer, len(buffer))
	if err != nil {
		if n == 0 {
			return nil, nil, err
		}
		// A body shorter than the first-chunk size is a complete, valid
		// response; any other error aborts this upstream.
		if !errors.Is(err, io.ErrUnexpectedEOF) && !errors.Is(err, io.EOF) {
			return nil, nil, err
		}
	}

	keepBodyOpen = true
	return resp, buffer[:n], nil
}