package cacheproxy

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"
	"sync"
	"time"
)

type reqCtxKey int

const (
	// reqCtxAllowedRedirect carries an upstream's AllowedRedirect regular
	// expression on the outgoing request context so CheckRedirect can read it.
	reqCtxAllowedRedirect reqCtxKey = iota
)

// zeroTime is the zero time.Time, used for "no local copy" and for an
// unknown upstream modification time.
var zeroTime time.Time

var (
	// httpClient is shared by all upstream requests. A redirect is followed
	// only when the request context carries an AllowedRedirect pattern that
	// matches the redirect target URL; otherwise http.ErrUseLastResponse makes
	// the client return the redirect response itself.
	httpClient = http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			lastRequest := via[len(via)-1]
			if allowedRedirect, ok := lastRequest.Context().Value(reqCtxAllowedRedirect).(string); ok {
				if matched, err := regexp.MatchString(allowedRedirect, req.URL.String()); err != nil {
					return err
				} else if !matched {
					return http.ErrUseLastResponse
				}
				return nil
			}
			return http.ErrUseLastResponse
		},
	}
)

// StreamObject is an in-flight upstream download held in memory so that
// several clients can be served from a single upstream response.
//
// One producer appends to Buffer and advances Offset; readers follow along
// with StreamTo. While the download runs, Buffer and Offset are guarded by mu
// and are only accessed through publish and pendingFrom.
type StreamObject struct {
	Headers http.Header
	Buffer  *bytes.Buffer
	Offset  int

	mu   sync.Mutex
	ctx  context.Context // cancelled once the producer stops publishing
	wg   *sync.WaitGroup // tracks the producing request's own StreamTo goroutine
	done chan struct{}   // closed after the download was persisted (or skipped) and unregistered
}

// publish appends p to the buffer, makes it visible to readers and returns
// the number of bytes appended.
func (memoryObject *StreamObject) publish(p []byte) int {
	memoryObject.mu.Lock()
	defer memoryObject.mu.Unlock()
	n, _ := memoryObject.Buffer.Write(p) // bytes.Buffer.Write always returns a nil error
	memoryObject.Offset += n
	return n
}

// pendingFrom returns a copy of the published bytes after offset, or nil when
// nothing new has been published.
func (memoryObject *StreamObject) pendingFrom(offset int) []byte {
	memoryObject.mu.Lock()
	defer memoryObject.mu.Unlock()
	if offset >= memoryObject.Offset {
		return nil
	}
	// Copy while holding the lock: a slice from Bytes is only guaranteed to
	// stay valid until the buffer's next modification.
	return bytes.Clone(memoryObject.Buffer.Bytes()[offset:memoryObject.Offset])
}

// StreamTo copies the object's bytes to w as they are published. It returns
// once the producer has finished and every byte has been written, or as soon
// as a write fails. A nil w drains into io.Discard. wg.Done is called on
// return.
func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error {
	defer wg.Done()
	if w == nil {
		w = io.Discard
	}

	offset := 0
	for memoryObject.ctx.Err() == nil {
		pending := memoryObject.pendingFrom(offset)
		if len(pending) == 0 {
			time.Sleep(time.Millisecond)
			continue
		}
		written, err := w.Write(pending)
		if err != nil {
			return err
		}
		offset += written
	}

	// The producer publishes every chunk before cancelling ctx, so this final
	// read observes the rest of the body.
	rest := memoryObject.pendingFrom(offset)
	slog.With(
		"start", offset,
		"end", offset+len(rest),
		"n", len(rest),
	).Debug("remain bytes")
	_, err := w.Write(rest)
	return err
}

// Server is a caching HTTP proxy: it serves files from the local storage
// directory and fills misses from the configured upstreams.
type Server struct {
	Config

	lu *sync.Mutex              // guards o
	o  map[string]*StreamObject // in-flight downloads keyed by request URL path
}

// NewServer returns a Server for config with no downloads in flight.
func NewServer(config Config) *Server {
	return &Server{
		Config: config,
		lu:     &sync.Mutex{},
		o:      make(map[string]*StreamObject),
	}
}

// Chunk is one piece of an upstream body produced by tryUpstream: either a
// non-nil buffer or, as the final element, a nil buffer with an error.
type Chunk struct {
	buffer []byte
	error  error
}

// serveFile sends the cached file at path to the client.
//
// When Storage.Accel.EnableByHeader is set and the request carries that
// header, no body is sent. Instead every header named in
// Storage.Accel.ResponseWithHeaders is set to the request header's value
// joined with path relative to the storage root. This is presumably for a
// fronting reverse proxy (X-Accel-Redirect style) to serve the file.
func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) {
	if location := r.Header.Get(server.Storage.Accel.EnableByHeader); server.Storage.Accel.EnableByHeader != "" && location != "" {
		relPath, err := filepath.Rel(server.Storage.Local.Path, path)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		accelPath := filepath.Join(location, relPath)
		for _, headerKey := range server.Storage.Accel.ResponseWithHeaders {
			w.Header().Set(headerKey, accelPath)
		}
		return
	}
	http.ServeFile(w, r, path)
}

// withinLocalDir reports whether path is root itself or lies beneath it. It
// compares path components, so a sibling such as "/data-other" is not
// considered inside "/data". Paths that cannot be related (for example a
// relative root against an absolute path) are reported as outside.
func withinLocalDir(root, path string) bool {
	rel, err := filepath.Rel(root, path)
	if err != nil {
		return false
	}
	return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}

// copyUpstreamHeaders copies every value of every header in src into dst,
// replacing any values dst already had for those keys.
func copyUpstreamHeaders(dst, src http.Header) {
	for key, values := range src {
		dst[http.CanonicalHeaderKey(key)] = slices.Clone(values)
	}
}

// HandleRequestWithCache serves r.URL.Path from the local storage directory.
// When the file is missing, or the cache policy says it must be revalidated,
// it is fetched from the upstreams first.
//
// Paths that resolve outside the storage directory are rejected with 400.
// Range requests are answered from disk after the download attempt has
// finished.
func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Request) {
	fullpath, err := filepath.Abs(filepath.Join(server.Storage.Local.Path, r.URL.Path))
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if !withinLocalDir(server.Storage.Local.Path, fullpath) {
		http.Error(w, "crossing local directory boundary", http.StatusBadRequest)
		return
	}

	ranged := r.Header.Get("Range") != ""
	status, mtime, err := server.checkLocal(w, r, fullpath)
	slog.With("status", status, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked")

	switch {
	case os.IsPermission(err):
		http.Error(w, err.Error(), http.StatusForbidden)
	case err != nil:
		http.Error(w, err.Error(), http.StatusInternalServerError)
	case status == localExists:
		server.serveFile(w, r, fullpath)
	case ranged:
		// Download (or revalidate) into the cache first, then let
		// http.ServeFile answer the Range request from disk.
		server.streamOnline(nil, r, mtime, fullpath)
		server.serveFile(w, r, fullpath)
	default:
		server.streamOnline(w, r, mtime, fullpath)
	}
}

// localStatus describes the state of a path in the local cache.
type localStatus int

const (
	localNotExists         localStatus = iota // no local copy
	localExists                               // local copy can be served as-is
	localExistsButNeedHead                    // local copy must be revalidated upstream first
)

// checkLocal reports whether key exists in the local cache and, if so,
// whether it must be revalidated upstream.
//
// The first Cache.Policies entry whose Match regular expression matches key
// decides the refresh rule. A duration overrides Cache.RefreshAfter;
// "always" forces revalidation and "never" suppresses it.
//
// The returned mtime is the file's modification time in UTC, or zero when the
// file does not exist. Stat errors other than "not exist" are returned
// unchanged, permission errors included, so the caller can choose the HTTP
// status. w is not written to.
func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
	stat, err := os.Stat(key)
	if os.IsNotExist(err) {
		return localNotExists, zeroTime, nil
	}
	if err != nil {
		return localNotExists, zeroTime, err
	}

	refreshAfter := server.Cache.RefreshAfter
	refresh := ""
	for _, policy := range server.Cache.Policies {
		match, err := regexp.MatchString(policy.Match, key)
		if err != nil {
			return 0, zeroTime, err
		}
		if !match {
			continue
		}
		if dur, err := time.ParseDuration(policy.RefreshAfter); err == nil {
			refreshAfter = dur
		} else if slices.Contains([]string{"always", "never"}, policy.RefreshAfter) {
			refresh = policy.RefreshAfter
		} else {
			return 0, zeroTime, err
		}
		break
	}

	modTime := stat.ModTime()
	slog.With("policy", refresh, "after", refreshAfter, "mtime", modTime, "key", key).Debug("refresh policy checked")
	if (modTime.Add(refreshAfter).Before(time.Now()) || refresh == "always") && refresh != "never" {
		return localExistsButNeedHead, modTime.In(time.UTC), nil
	}
	return localExists, modTime.In(time.UTC), nil
}

// streamOnline fetches r's path from the fastest upstream, streams it to w
// and caches it at key. Concurrent requests for the same URL path share one
// download through server.o. mtime is the local copy's modification time
// (zero when there is none) and is sent upstream as If-Modified-Since.
//
// A nil w means the caller will serve key from disk itself. In that case
// nothing is written to a client, and streamOnline only returns after the
// download has been persisted (or abandoned).
//
// When no upstream delivers a 200 and a local copy exists, the local copy is
// served instead.
func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) {
	server.lu.Lock()
	unlock := sync.OnceFunc(server.lu.Unlock)
	defer unlock()

	if memoryObject, exists := server.o[r.URL.Path]; exists {
		unlock()
		server.joinStream(w, memoryObject)
		return
	}

	// NOTE(review): lu stays held while the upstreams are probed, so a
	// concurrent miss for the same path joins this download instead of
	// starting another. This also serializes misses for unrelated paths.
	slog.With("mtime", mtime).Debug("checking fastest upstream")
	selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime)
	if chunks == nil && mtime != zeroTime {
		unlock()
		slog.With("upstreamIdx", selectedIdx, "key", key).Debug("not modified. using local version")
		if w != nil {
			server.serveFile(w, r, key)
		}
		return
	}
	if err != nil {
		unlock()
		slog.With("error", err).Warn("failed to select fastest upstream")
		if w != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}
	if selectedIdx == -1 || response == nil || chunks == nil {
		unlock()
		slog.Debug("no upstream is selected")
		if w != nil {
			http.NotFound(w, r)
		}
		return
	}
	if response.StatusCode == http.StatusNotModified {
		// NOTE(review): fastesUpstream only selects 200 responses, so this
		// branch is not reached at present.
		unlock()
		slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version")
		os.Chtimes(key, zeroTime, time.Now())
		if w != nil {
			server.serveFile(w, r, key)
		}
		return
	}

	slog.With(
		"upstreamIdx", selectedIdx,
	).Debug("found fastest upstream")

	// The stream lives as long as the download, not the initiating client:
	// requests that joined must still get the whole body if that client
	// disconnects.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	memoryObject := &StreamObject{
		Headers: response.Header,
		Buffer:  &bytes.Buffer{},
		ctx:     ctx,
		wg:      &sync.WaitGroup{},
		done:    make(chan struct{}),
	}
	server.o[r.URL.Path] = memoryObject
	unlock()

	if w != nil {
		copyUpstreamHeaders(w.Header(), memoryObject.Headers)
		memoryObject.wg.Add(1)
		go func() {
			if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil {
				slog.With("error", err).Debug("failed to stream response to client")
			}
		}()
	}

	received := 0
	for chunk := range chunks {
		if chunk.error != nil {
			err = chunk.error
			slog.With("error", err).Warn("failed to read from upstream")
		}
		if chunk.buffer == nil {
			break
		}
		received += memoryObject.publish(chunk.buffer)
	}
	cancel()
	// The handler must not return while its StreamTo goroutine still writes to w.
	memoryObject.wg.Wait()

	if response.ContentLength > 0 {
		if int64(received) == response.ContentLength && err != nil {
			if !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
				slog.With("read-length", received, "content-length", response.ContentLength, "error", err, "upstreamIdx", selectedIdx).Debug("something happened during download. but response body is read as whole. so error is reset to nil")
			}
			err = nil
		}
	} else if errors.Is(err, io.EOF) {
		err = nil
	}
	if err != nil {
		slog.With("error", err, "upstreamIdx", selectedIdx).Error("something happened during download. will not cache this response")
	}

	downloadErr := err
	release := func() {
		defer close(memoryObject.done)
		defer func() {
			server.lu.Lock()
			defer server.lu.Unlock()
			delete(server.o, r.URL.Path)
			slog.Debug("memory object released")
		}()
		if downloadErr == nil {
			slog.Debug("preparing to release memory object")
			server.persistDownload(key, response, memoryObject.Buffer.Bytes())
		}
	}
	if w == nil {
		// The caller serves key from disk next, so persist before returning.
		release()
	} else {
		go release()
	}
}

// joinStream serves a request for a path that another request is already
// downloading, using that request's in-memory object.
func (server *Server) joinStream(w http.ResponseWriter, memoryObject *StreamObject) {
	if w == nil {
		// The caller serves the file from disk next: wait until the producer
		// has finished persisting it.
		if memoryObject.done != nil {
			<-memoryObject.done
		}
		return
	}
	copyUpstreamHeaders(w.Header(), memoryObject.Headers)
	// A private WaitGroup: the producer may already be in memoryObject.wg.Wait,
	// and an Add must not race with that Wait.
	wg := &sync.WaitGroup{}
	wg.Add(1)
	if err := memoryObject.StreamTo(w, wg); err != nil {
		slog.With("error", err).Warn("failed to stream response with existing memory object")
	}
}

// persistDownload writes data, a complete upstream body, to key.
//
// Without a TemporaryFilePattern the file is written in place. Otherwise it
// is written to a temporary file in the storage root, stamped with the
// upstream Last-Modified time when that header parses, and renamed over key.
func (server *Server) persistDownload(key string, response *http.Response, data []byte) {
	mtime := zeroTime
	lastModifiedHeader := response.Header.Get("Last-Modified")
	if lastModified, err := http.ParseTime(lastModifiedHeader); err != nil {
		slog.With(
			"error", err,
			"value", lastModifiedHeader,
			"url", response.Request.URL,
		).Debug("failed to parse last modified header value. set modified time to now")
	} else {
		slog.With(
			"header", lastModifiedHeader,
			"value", lastModified,
			"url", response.Request.URL,
		).Debug("found modified time")
		mtime = lastModified
	}

	if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil {
		slog.With("error", err).Warn("failed to create local storage path")
	}
	// key may live in a subdirectory of the storage root that does not exist
	// yet; both write strategies below need it.
	dirname := filepath.Dir(key)
	if err := os.MkdirAll(dirname, 0755); err != nil {
		slog.With("error", err, "path", dirname).Warn("failed to create local storage path")
	}

	if server.Storage.Local.TemporaryFilePattern == "" {
		if err := os.WriteFile(key, data, 0644); err != nil {
			slog.With("error", err).Warn("failed to write file")
			os.Remove(key)
		}
		return
	}

	fp, err := os.CreateTemp(server.Storage.Local.Path, server.Storage.Local.TemporaryFilePattern)
	if err != nil {
		slog.With(
			"key", key,
			"path", server.Storage.Local.Path,
			"pattern", server.Storage.Local.TemporaryFilePattern,
			"error", err,
		).Warn("failed to create template file")
		return
	}
	name := fp.Name()
	if _, err := fp.Write(data); err != nil {
		fp.Close()
		os.Remove(name)
		slog.With("error", err).Warn("failed to write into template file")
		return
	}
	if err := fp.Close(); err != nil {
		os.Remove(name)
		slog.With("error", err).Warn("failed to close template file")
		return
	}
	if !mtime.IsZero() {
		if err := os.Chtimes(name, zeroTime, mtime); err != nil {
			slog.With("error", err).Warn("failed to set modified time of template file")
		}
	}
	// os.Rename replaces an existing file at key, so key is not removed
	// first; removing it would leave a window where the cached file is missing.
	if err := os.Rename(name, key); err != nil {
		os.Remove(name)
		slog.With("error", err, "key", key).Warn("failed to move template file into place")
	}
}

// fastesUpstream queries every upstream concurrently. It returns the first
// one to deliver a 200 OK response and its first chunk, together with that
// response and its chunk channel. All other requests are cancelled once a
// winner is chosen.
//
// resultIdx is -1 when no upstream answered 200, for example because every
// upstream replied 304 Not Modified, 4xx, or failed. resultErr is always nil.
func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) {
	upstreams := len(server.Upstreams)

	// Create every context before launching any request: the winner cancels
	// all the others, so the slice must be fully populated by then.
	contexts := make([]context.Context, upstreams)
	cancelFuncs := make([]context.CancelFunc, upstreams)
	for idx := range contexts {
		contexts[idx], cancelFuncs[idx] = context.WithCancel(context.Background())
	}

	resultIdx = -1
	selectedOnce := &sync.Once{}
	wg := &sync.WaitGroup{}
	wg.Add(upstreams)
	for idx := range server.Upstreams {
		idx := idx
		logger := slog.With("upstreamIdx", idx)
		go func() {
			defer wg.Done()
			won := false
			defer func() {
				// Only the winner's request must stay alive. Cancelling the
				// rest also stops their body readers, which close the bodies.
				if !won {
					cancelFuncs[idx]()
				}
			}()

			response, ch, err := server.tryUpstream(contexts[idx], idx, r, lastModified)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					// another upstream won
					logger.Debug("context canceled")
				} else if !errors.Is(err, context.DeadlineExceeded) {
					logger.With("error", err).Warn("upstream has error")
				}
				return
			}
			if response == nil || response.StatusCode != http.StatusOK {
				return
			}

			selectedOnce.Do(func() {
				won = true
				resultIdx, resultResponse, resultCh = idx, response, ch
				for cancelIdx, cancel := range cancelFuncs {
					if cancelIdx != idx {
						cancel()
					}
				}
			})
			if won {
				logger.Debug("voted")
			}
		}()
	}
	wg.Wait()

	if resultIdx == -1 {
		slog.Debug("no valid upstream found")
	} else {
		slog.With("upstreamIdx", resultIdx).Debug("upstream selected")
	}
	return resultIdx, resultResponse, resultCh, resultErr
}

// fillChunk reads from r until buf is full or r returns an error, and returns
// the number of bytes read.
//
// Unlike io.ReadAtLeast, a short read followed by io.EOF is not turned into
// io.ErrUnexpectedEOF. A clean end of stream is therefore reported as io.EOF,
// while a truncated body keeps the reader's own error.
func fillChunk(r io.Reader, buf []byte) (int, error) {
	n := 0
	for n < len(buf) {
		m, err := r.Read(buf[n:])
		n += m
		if err != nil {
			return n, err
		}
	}
	return n, nil
}

// tryUpstream sends r to the upstream at upstreamIdx, with the path rewritten
// by that upstream's GetPath. The result depends on the upstream's answer:
//   - path not handled, or a 4xx status: nil, nil, nil;
//   - 304 Not Modified: the response and nil chunks;
//   - a status below 200 or 500 and above: the response and an error;
//   - any other status: the response and a channel yielding the body.
//
// For the streamed case, the first chunk (Misc.FirstChunkBytes long) is read
// before returning, and later chunks are up to Misc.ChunkBytes long. A read
// failure arrives as a final Chunk with a nil buffer and non-nil error; a
// clean end of body just closes the channel. Once ctx is cancelled the reader
// stops and closes the body.
//
// Responses that are not streamed have their bodies closed before returning.
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) {
	upstream := server.Upstreams[upstreamIdx]
	newpath, matched, err := upstream.GetPath(r.URL.Path)
	if err != nil {
		return nil, nil, err
	}
	if !matched {
		return nil, nil, nil
	}
	logger := slog.With("upstreamIdx", upstreamIdx, "server", upstream.Server, "path", newpath)
	logger.With(
		"matched", matched,
	).Debug("trying upstream")

	newurl := upstream.Server + newpath
	method := r.Method
	if lastModified != zeroTime {
		method = http.MethodGet
	}
	if upstream.AllowedRedirect != nil {
		ctx = context.WithValue(ctx, reqCtxAllowedRedirect, *upstream.AllowedRedirect)
	}
	request, err := http.NewRequestWithContext(ctx, method, newurl, nil)
	if err != nil {
		return nil, nil, err
	}
	if lastModified != zeroTime {
		// HTTP dates must be written in GMT; time.RFC1123 would print "UTC"
		// for the UTC times checkLocal returns.
		request.Header.Set("If-Modified-Since", lastModified.UTC().Format(http.TimeFormat))
	}
	for _, k := range []string{"User-Agent", "Accept"} {
		if _, exists := r.Header[k]; exists {
			request.Header.Set(k, r.Header.Get(k))
		}
	}

	response, err = httpClient.Do(request)
	if err != nil {
		return nil, nil, err
	}
	switch {
	case response.StatusCode == http.StatusNotModified:
		response.Body.Close()
		return response, nil, nil
	case response.StatusCode >= 400 && response.StatusCode < 500:
		response.Body.Close()
		return nil, nil, nil
	case response.StatusCode < 200 || response.StatusCode >= 500:
		response.Body.Close()
		logger.With(
			"url", newurl,
			"status", response.StatusCode,
		).Warn("unexpected status")
		return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response)
	}

	first := make([]byte, server.Misc.FirstChunkBytes)
	n, firstErr := fillChunk(response.Body, first)
	if n == 0 && firstErr != nil {
		response.Body.Close()
		return response, nil, firstErr
	}

	ch := make(chan Chunk, 1024)
	ch <- Chunk{buffer: first[:n]}
	go func() {
		defer close(ch)
		defer response.Body.Close()
		// send gives up once ctx is cancelled, so the reader of a losing
		// upstream never blocks on a channel nobody drains.
		send := func(chunk Chunk) bool {
			select {
			case ch <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		readErr := firstErr
		for readErr == nil {
			buffer := make([]byte, server.Misc.ChunkBytes)
			var read int
			read, readErr = fillChunk(response.Body, buffer)
			if read > 0 && !send(Chunk{buffer: buffer[:read]}) {
				return
			}
		}
		// io.EOF is a clean end of body; anything else (including a
		// truncated body's io.ErrUnexpectedEOF) is reported to the consumer.
		if !errors.Is(readErr, io.EOF) {
			send(Chunk{error: readErr})
		}
	}()
	return response, ch, nil
}