From 45f69d994eb86d9fd39e87bc58a033c1dbe04999 Mon Sep 17 00:00:00 2001 From: guochao Date: Thu, 9 Jan 2025 23:30:42 +0800 Subject: [PATCH] move server into separate package --- cmd/proxy/main.go | 694 ++++------------------------------------------ config.go | 65 +++++ server.go | 548 ++++++++++++++++++++++++++++++++++++ 3 files changed, 663 insertions(+), 644 deletions(-) create mode 100644 config.go create mode 100644 server.go diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 779f8b2..d74c40c 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -1,654 +1,18 @@ package main import ( - "bytes" - "context" - "errors" "flag" - "fmt" - "io" "log/slog" "net/http" "os" "path/filepath" - "regexp" - "slices" - "strings" - "sync" "time" + cacheproxy "git.jeffthecoder.xyz/guochao/cache-proxy" "github.com/getsentry/sentry-go" - "gopkg.in/yaml.v3" ) -var zeroTime time.Time - -type UpstreamMatch struct { - Match string `yaml:"match"` - Replace string `yaml:"replace"` -} - -type Upstream struct { - Server string `yaml:"server"` - Match UpstreamMatch `yaml:"match"` -} - -func (upstream Upstream) GetPath(orig string) (string, bool, error) { - if upstream.Match.Match == "" || upstream.Match.Replace == "" { - return orig, true, nil - } - matcher, err := regexp.Compile(upstream.Match.Match) - if err != nil { - return "", false, err - } - return matcher.ReplaceAllString(orig, upstream.Match.Replace), matcher.MatchString(orig), nil -} - -type LocalStorage struct { - Path string `yaml:"path"` - TemporaryFilePattern string `yaml:"temporary-file-pattern"` -} - -type Accel struct { - EnableByHeader string `yaml:"enable-by-header"` - ResponseWithHeaders []string `yaml:"response-with-headers"` -} - -type Storage struct { - Type string `yaml:"type"` - Local *LocalStorage `yaml:"local"` - Accel Accel `yaml:"accel"` -} - -type CachePolicyOnPath struct { - Match string `yaml:"match"` - RefreshAfter string `yaml:"refresh-after"` -} - -type Cache struct { - RefreshAfter time.Duration 
`yaml:"refresh-after"` - Policies []CachePolicyOnPath `yaml:"policies"` -} - -type MiscConfig struct { - FirstChunkBytes uint64 `yaml:"first-chunk-bytes"` - ChunkBytes uint64 `yaml:"chunk-bytes"` -} - -type Config struct { - Upstreams []Upstream `yaml:"upstream"` - Storage Storage `yaml:"storage"` - Cache Cache `yaml:"cache"` - Misc MiscConfig `yaml:"misc"` -} - -type StreamObject struct { - Headers http.Header - Buffer *bytes.Buffer - Offset int - - ctx context.Context - wg *sync.WaitGroup -} - -func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error { - defer wg.Done() - offset := 0 - if w == nil { - w = io.Discard - } -OUTER: - for { - select { - case <-memoryObject.ctx.Done(): - break OUTER - default: - } - - newOffset := memoryObject.Offset - if newOffset == offset { - time.Sleep(time.Millisecond) - continue - } - bytes := memoryObject.Buffer.Bytes()[offset:newOffset] - written, err := w.Write(bytes) - if err != nil { - return err - } - - offset += written - } - time.Sleep(time.Millisecond) - slog.With( - "start", offset, - "end", memoryObject.Buffer.Len(), - "n", memoryObject.Buffer.Len()-offset, - ).Debug("remain bytes") - - _, err := w.Write(memoryObject.Buffer.Bytes()[offset:]) - return err -} - -type Server struct { - Config - - lu *sync.Mutex - o map[string]*StreamObject -} - -type Chunk struct { - buffer []byte - error error -} - -func configFromFile(path string) (*Config, error) { - file, err := os.Open(path) - if err != nil { - return nil, err - } - defer file.Close() - - config := &Config{ - Upstreams: []Upstream{ - { - Server: "https://mirrors.ustc.edu.cn", - }, - }, - Storage: Storage{ - Type: "local", - Local: &LocalStorage{ - Path: "./data", - TemporaryFilePattern: "temp.*", - }, - Accel: Accel{ - ResponseWithHeaders: []string{"X-Sendfile", "X-Accel-Redirect"}, - }, - }, - Misc: MiscConfig{ - FirstChunkBytes: 1024 * 1024 * 50, - ChunkBytes: 1024 * 1024, - }, - Cache: Cache{ - RefreshAfter: time.Hour, - }, - } - - if err 
:= yaml.NewDecoder(file).Decode(&config); err != nil { - return nil, err - } - - if config.Storage.Local != nil { - localPath, err := filepath.Abs(config.Storage.Local.Path) - if err != nil { - return nil, err - } - config.Storage.Local.Path = localPath - } - - return config, nil -} - -func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) { - if location := r.Header.Get(server.Storage.Accel.EnableByHeader); server.Storage.Accel.EnableByHeader != "" && location != "" { - relPath, err := filepath.Rel(server.Storage.Local.Path, path) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - accelPath := filepath.Join(location, relPath) - - for _, headerKey := range server.Storage.Accel.ResponseWithHeaders { - w.Header().Set(headerKey, accelPath) - } - - return - } - - http.ServeFile(w, r, path) -} - -func (server *Server) handleRequest(w http.ResponseWriter, r *http.Request) { - fullpath := filepath.Join(server.Storage.Local.Path, r.URL.Path) - fullpath, err := filepath.Abs(fullpath) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !strings.HasPrefix(fullpath, server.Storage.Local.Path) { - http.Error(w, "crossing local directory boundary", http.StatusBadRequest) - return - } - - ranged := r.Header.Get("Range") != "" - - localStatus, mtime, err := server.checkLocal(w, r, fullpath) - slog.With("status", localStatus, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked") - if os.IsPermission(err) { - http.Error(w, err.Error(), http.StatusForbidden) - } else if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - } else if localStatus != localNotExists { - if localStatus == localExistsButNeedHead { - if ranged { - server.streamOnline(nil, r, mtime, fullpath) - server.serveFile(w, r, fullpath) - } else { - server.streamOnline(w, r, mtime, fullpath) - } - } else { - server.serveFile(w, r, fullpath) - } - } else { - if ranged { - 
server.streamOnline(nil, r, mtime, fullpath) - server.serveFile(w, r, fullpath) - } else { - server.streamOnline(w, r, mtime, fullpath) - } - } -} - -type localStatus int - -const ( - localNotExists localStatus = iota - localExists - localExistsButNeedHead -) - -func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) { - if stat, err := os.Stat(key); err == nil { - refreshAfter := server.Cache.RefreshAfter - refresh := "" - - for _, policy := range server.Cache.Policies { - if match, err := regexp.MatchString(policy.Match, key); err != nil { - return 0, zeroTime, err - } else if match { - if dur, err := time.ParseDuration(policy.RefreshAfter); err != nil { - if slices.Contains([]string{"always", "never"}, policy.RefreshAfter) { - refresh = policy.RefreshAfter - } else { - return 0, zeroTime, err - } - } else { - refreshAfter = dur - } - break - } - } - mtime := stat.ModTime() - slog.With("policy", refresh, "after", refreshAfter, "mtime", mtime, "key", key).Debug("refresh policy checked") - if (mtime.Add(refreshAfter).Before(time.Now()) || refresh == "always") && refresh != "never" { - return localExistsButNeedHead, mtime.In(time.UTC), nil - } - return localExists, mtime.In(time.UTC), nil - } else if os.IsPermission(err) { - http.Error(w, err.Error(), http.StatusForbidden) - } else if !os.IsNotExist(err) { - return localNotExists, zeroTime, err - } - - return localNotExists, zeroTime, nil -} - -func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) { - memoryObject, exists := server.o[r.URL.Path] - locked := false - defer func() { - if locked { - server.lu.Unlock() - locked = false - } - }() - if !exists { - server.lu.Lock() - locked = true - - memoryObject, exists = server.o[r.URL.Path] - } - if exists { - if locked { - server.lu.Unlock() - locked = false - } - - if w != nil { - memoryObject.wg.Add(1) - for k := range memoryObject.Headers { 
- v := memoryObject.Headers.Get(k) - w.Header().Set(k, v) - } - - if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil { - slog.With("error", err).Warn("failed to stream response with existing memory object") - } - } - } else { - slog.With("mtime", mtime).Debug("checking fastest upstream") - selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime) - if chunks == nil && mtime != zeroTime { - slog.With("upstreamIdx", selectedIdx, "key", key).Debug("not modified. using local version") - if w != nil { - server.serveFile(w, r, key) - } - return - } - - if err != nil { - slog.With("error", err).Warn("failed to select fastest upstream") - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if selectedIdx == -1 || response == nil || chunks == nil { - slog.Debug("no upstream is selected") - http.NotFound(w, r) - return - } - if response.StatusCode == http.StatusNotModified { - slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version") - os.Chtimes(key, zeroTime, time.Now()) - server.serveFile(w, r, key) - return - } - - slog.With( - "upstreamIdx", selectedIdx, - ).Debug("found fastest upstream") - - buffer := &bytes.Buffer{} - ctx, cancel := context.WithCancel(r.Context()) - defer cancel() - - memoryObject = &StreamObject{ - Headers: response.Header, - Buffer: buffer, - - ctx: ctx, - - wg: &sync.WaitGroup{}, - } - - server.o[r.URL.Path] = memoryObject - server.lu.Unlock() - locked = false - - err = nil - - if w != nil { - memoryObject.wg.Add(1) - - for k := range memoryObject.Headers { - v := memoryObject.Headers.Get(k) - w.Header().Set(k, v) - } - - go memoryObject.StreamTo(w, memoryObject.wg) - } - - for chunk := range chunks { - if chunk.error != nil { - err = chunk.error - slog.With("error", err).Warn("failed to read from upstream") - } - if chunk.buffer == nil { - break - } - n, _ := buffer.Write(chunk.buffer) - memoryObject.Offset += n - } - cancel() - - memoryObject.wg.Wait() - - if 
response.ContentLength > 0 { - if memoryObject.Offset == int(response.ContentLength) && err != nil { - if err != io.EOF { - slog.With("length", memoryObject.Offset, "error", err, "upstreamIdx", selectedIdx).Debug("something happened during download. but response body is read as whole. so error is reset to nil") - } - err = nil - } - } else if err == io.EOF { - err = nil - } - - if err != nil { - slog.With("error", err, "upstreamIdx", selectedIdx).Error("something happened during download. will not cache this response") - } - go func() { - defer func() { - server.lu.Lock() - defer server.lu.Unlock() - - delete(server.o, r.URL.Path) - slog.Debug("memory object released") - }() - - if err == nil { - slog.Debug("preparing to release memory object") - mtime := zeroTime - lastModifiedHeader := response.Header.Get("Last-Modified") - if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil { - slog.With( - "error", err, - "value", lastModifiedHeader, - "url", response.Request.URL, - ).Debug("failed to parse last modified header value. 
set modified time to now") - } else { - slog.With( - "header", lastModifiedHeader, - "value", lastModified, - "url", response.Request.URL, - ).Debug("found modified time") - mtime = lastModified - } - if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil { - slog.With("error", err).Warn("failed to create local storage path") - } - - if server.Config.Storage.Local.TemporaryFilePattern == "" { - if err := os.WriteFile(key, buffer.Bytes(), 0644); err != nil { - slog.With("error", err).Warn("failed to write file") - os.Remove(key) - } - return - } - - fp, err := os.CreateTemp(server.Storage.Local.Path, server.Storage.Local.TemporaryFilePattern) - if err != nil { - slog.With( - "key", key, - "path", server.Storage.Local.Path, - "pattern", server.Storage.Local.TemporaryFilePattern, - "error", err, - ).Warn("failed to create template file") - return - } - - name := fp.Name() - - if _, err := fp.Write(buffer.Bytes()); err != nil { - fp.Close() - os.Remove(name) - - slog.With("error", err).Warn("failed to write into template file") - } else if err := fp.Close(); err != nil { - os.Remove(name) - - slog.With("error", err).Warn("failed to close template file") - } else { - os.Chtimes(name, zeroTime, mtime) - dirname := filepath.Dir(key) - os.MkdirAll(dirname, 0755) - os.Remove(key) - os.Rename(name, key) - } - } - }() - - } -} - -func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) { - returnLock := &sync.Mutex{} - upstreams := len(server.Upstreams) - cancelFuncs := make([]func(), upstreams) - selectedCh := make(chan int, 1) - selectedOnce := &sync.Once{} - wg := &sync.WaitGroup{} - wg.Add(len(server.Upstreams)) - - defer close(selectedCh) - for idx := range server.Upstreams { - idx := idx - ctx, cancel := context.WithCancel(context.Background()) - cancelFuncs[idx] = cancel - - logger := slog.With("upstreamIdx", idx) - - go func() { - defer wg.Done() - 
response, ch, err := server.tryUpstream(ctx, idx, r, lastModified) - if err == context.Canceled { // others returned - logger.Debug("context canceled") - return - } - - if err != nil { - if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { - logger.With("error", err).Warn("upstream has error") - } - return - } - if response == nil { - return - } - locked := returnLock.TryLock() - if !locked { - return - } - defer returnLock.Unlock() - - selectedOnce.Do(func() { - resultResponse, resultCh, resultErr = response, ch, err - selectedCh <- idx - - for cancelIdx, cancel := range cancelFuncs { - if cancelIdx == idx { - continue - } - cancel() - } - }) - - logger.Debug("voted") - }() - } - - wg.Wait() - - resultIdx = -1 - select { - case idx := <-selectedCh: - resultIdx = idx - slog.With("upstreamIdx", resultIdx).Debug("upstream selected") - default: - slog.Debug("no valid upstream found") - } - - return -} - -func (server *Server) tryUpstream(ctx context.Context, upstreamIdx int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) { - upstream := server.Upstreams[upstreamIdx] - - newpath, matched, err := upstream.GetPath(r.URL.Path) - if err != nil { - return nil, nil, err - } - if !matched { - return nil, nil, nil - } - logger := slog.With("upstreamIdx", upstreamIdx, "server", upstream.Server, "path", newpath) - - logger.With( - "matched", matched, - ).Debug("trying upstream") - - newurl := upstream.Server + newpath - method := r.Method - if lastModified != zeroTime { - method = http.MethodGet - } - request, err := http.NewRequestWithContext(ctx, method, newurl, nil) - if err != nil { - return nil, nil, err - } - if lastModified != zeroTime { - request.Header.Set("If-Modified-Since", lastModified.Format(time.RFC1123)) - } - - for _, k := range []string{"User-Agent"} { - if _, exists := request.Header[k]; exists { - request.Header.Set(k, r.Header.Get(k)) - } - } - response, err = 
http.DefaultClient.Do(request) - if err != nil { - return nil, nil, err - } - if response.StatusCode == http.StatusNotModified { - return response, nil, nil - } - if response.StatusCode >= 400 && response.StatusCode < 500 { - return nil, nil, nil - } - if response.StatusCode < 200 || response.StatusCode >= 500 { - logger.With( - "url", newurl, - "status", response.StatusCode, - ).Warn("unexpected status") - return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response) - } - - var currentOffset int64 - - ch := make(chan Chunk, 1024) - - buffer := make([]byte, server.Misc.FirstChunkBytes) - n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) - - if err != nil { - if n == 0 { - return response, nil, err - } - } - ch <- Chunk{buffer: buffer[:n]} - - go func() { - defer close(ch) - - for { - buffer := make([]byte, server.Misc.ChunkBytes) - n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) - if n > 0 { - ch <- Chunk{buffer: buffer[:n]} - currentOffset += int64(n) - } - if response.ContentLength > 0 && currentOffset == response.ContentLength && err == io.EOF || err == io.ErrUnexpectedEOF { - return - } - if err != nil { - ch <- Chunk{error: err} - return - } - } - }() - - return response, ch, nil -} - var ( configFilePath = "config.yaml" logLevel = "info" @@ -670,6 +34,53 @@ func init() { flag.StringVar(&sentrydsn, "sentry", sentrydsn, "sentry dsn to report errors") } +func configFromFile(path string) (*cacheproxy.Config, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + config := &cacheproxy.Config{ + Upstreams: []cacheproxy.Upstream{ + { + Server: "https://mirrors.ustc.edu.cn", + }, + }, + Storage: cacheproxy.Storage{ + Type: "local", + Local: &cacheproxy.LocalStorage{ + Path: "./data", + TemporaryFilePattern: "temp.*", + }, + Accel: cacheproxy.Accel{ + ResponseWithHeaders: []string{"X-Sendfile", "X-Accel-Redirect"}, + }, + }, + Misc: 
cacheproxy.MiscConfig{ + FirstChunkBytes: 1024 * 1024 * 50, + ChunkBytes: 1024 * 1024, + }, + Cache: cacheproxy.Cache{ + RefreshAfter: time.Hour, + }, + } + + if err := yaml.NewDecoder(file).Decode(&config); err != nil { + return nil, err + } + + if config.Storage.Local != nil { + localPath, err := filepath.Abs(config.Storage.Local.Path) + if err != nil { + return nil, err + } + config.Storage.Local.Path = localPath + } + + return config, nil +} + func main() { flag.Parse() @@ -702,14 +113,9 @@ func main() { ch <- idx } - server := Server{ - Config: *config, + server := cacheproxy.NewServer(*config) - lu: &sync.Mutex{}, - o: make(map[string]*StreamObject), - } - - http.HandleFunc("GET /{path...}", server.handleRequest) + http.HandleFunc("GET /{path...}", server.HandleRequestWithCache) slog.With("addr", ":8881").Info("serving app") if err := http.ListenAndServe(":8881", nil); err != nil { slog.With("error", err).Error("failed to start server") diff --git a/config.go b/config.go new file mode 100644 index 0000000..7ae196b --- /dev/null +++ b/config.go @@ -0,0 +1,65 @@ +package cacheproxy + +import ( + "regexp" + "time" +) + +type UpstreamMatch struct { + Match string `yaml:"match"` + Replace string `yaml:"replace"` +} + +type Upstream struct { + Server string `yaml:"server"` + Match UpstreamMatch `yaml:"match"` +} + +func (upstream Upstream) GetPath(orig string) (string, bool, error) { + if upstream.Match.Match == "" || upstream.Match.Replace == "" { + return orig, true, nil + } + matcher, err := regexp.Compile(upstream.Match.Match) + if err != nil { + return "", false, err + } + return matcher.ReplaceAllString(orig, upstream.Match.Replace), matcher.MatchString(orig), nil +} + +type LocalStorage struct { + Path string `yaml:"path"` + TemporaryFilePattern string `yaml:"temporary-file-pattern"` +} + +type Accel struct { + EnableByHeader string `yaml:"enable-by-header"` + ResponseWithHeaders []string `yaml:"response-with-headers"` +} + +type Storage struct { + Type 
string `yaml:"type"` + Local *LocalStorage `yaml:"local"` + Accel Accel `yaml:"accel"` +} + +type CachePolicyOnPath struct { + Match string `yaml:"match"` + RefreshAfter string `yaml:"refresh-after"` +} + +type Cache struct { + RefreshAfter time.Duration `yaml:"refresh-after"` + Policies []CachePolicyOnPath `yaml:"policies"` +} + +type MiscConfig struct { + FirstChunkBytes uint64 `yaml:"first-chunk-bytes"` + ChunkBytes uint64 `yaml:"chunk-bytes"` +} + +type Config struct { + Upstreams []Upstream `yaml:"upstream"` + Storage Storage `yaml:"storage"` + Cache Cache `yaml:"cache"` + Misc MiscConfig `yaml:"misc"` +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..a090feb --- /dev/null +++ b/server.go @@ -0,0 +1,548 @@ +package cacheproxy + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "regexp" + "slices" + "strings" + "sync" + "time" +) + +var zeroTime time.Time + +type StreamObject struct { + Headers http.Header + Buffer *bytes.Buffer + Offset int + + ctx context.Context + wg *sync.WaitGroup +} + +func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error { + defer wg.Done() + offset := 0 + if w == nil { + w = io.Discard + } +OUTER: + for { + select { + case <-memoryObject.ctx.Done(): + break OUTER + default: + } + + newOffset := memoryObject.Offset + if newOffset == offset { + time.Sleep(time.Millisecond) + continue + } + bytes := memoryObject.Buffer.Bytes()[offset:newOffset] + written, err := w.Write(bytes) + if err != nil { + return err + } + + offset += written + } + time.Sleep(time.Millisecond) + slog.With( + "start", offset, + "end", memoryObject.Buffer.Len(), + "n", memoryObject.Buffer.Len()-offset, + ).Debug("remain bytes") + + _, err := w.Write(memoryObject.Buffer.Bytes()[offset:]) + return err +} + +type Server struct { + Config + + lu *sync.Mutex + o map[string]*StreamObject +} + +func NewServer(config Config) *Server { + return &Server{ + Config: 
config, + + lu: &sync.Mutex{}, + o: make(map[string]*StreamObject), + } +} + +type Chunk struct { + buffer []byte + error error +} + +func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) { + if location := r.Header.Get(server.Storage.Accel.EnableByHeader); server.Storage.Accel.EnableByHeader != "" && location != "" { + relPath, err := filepath.Rel(server.Storage.Local.Path, path) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + accelPath := filepath.Join(location, relPath) + + for _, headerKey := range server.Storage.Accel.ResponseWithHeaders { + w.Header().Set(headerKey, accelPath) + } + + return + } + + http.ServeFile(w, r, path) +} + +func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Request) { + fullpath := filepath.Join(server.Storage.Local.Path, r.URL.Path) + fullpath, err := filepath.Abs(fullpath) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if !strings.HasPrefix(fullpath, server.Storage.Local.Path) { + http.Error(w, "crossing local directory boundary", http.StatusBadRequest) + return + } + + ranged := r.Header.Get("Range") != "" + + localStatus, mtime, err := server.checkLocal(w, r, fullpath) + slog.With("status", localStatus, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked") + if os.IsPermission(err) { + http.Error(w, err.Error(), http.StatusForbidden) + } else if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else if localStatus != localNotExists { + if localStatus == localExistsButNeedHead { + if ranged { + server.streamOnline(nil, r, mtime, fullpath) + server.serveFile(w, r, fullpath) + } else { + server.streamOnline(w, r, mtime, fullpath) + } + } else { + server.serveFile(w, r, fullpath) + } + } else { + if ranged { + server.streamOnline(nil, r, mtime, fullpath) + server.serveFile(w, r, fullpath) + } else { + server.streamOnline(w, r, mtime, fullpath) + } + } +} + 
+type localStatus int + +const ( + localNotExists localStatus = iota + localExists + localExistsButNeedHead +) + +func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) { + if stat, err := os.Stat(key); err == nil { + refreshAfter := server.Cache.RefreshAfter + refresh := "" + + for _, policy := range server.Cache.Policies { + if match, err := regexp.MatchString(policy.Match, key); err != nil { + return 0, zeroTime, err + } else if match { + if dur, err := time.ParseDuration(policy.RefreshAfter); err != nil { + if slices.Contains([]string{"always", "never"}, policy.RefreshAfter) { + refresh = policy.RefreshAfter + } else { + return 0, zeroTime, err + } + } else { + refreshAfter = dur + } + break + } + } + mtime := stat.ModTime() + slog.With("policy", refresh, "after", refreshAfter, "mtime", mtime, "key", key).Debug("refresh policy checked") + if (mtime.Add(refreshAfter).Before(time.Now()) || refresh == "always") && refresh != "never" { + return localExistsButNeedHead, mtime.In(time.UTC), nil + } + return localExists, mtime.In(time.UTC), nil + } else if os.IsPermission(err) { + http.Error(w, err.Error(), http.StatusForbidden) + } else if !os.IsNotExist(err) { + return localNotExists, zeroTime, err + } + + return localNotExists, zeroTime, nil +} + +func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) { + memoryObject, exists := server.o[r.URL.Path] + locked := false + defer func() { + if locked { + server.lu.Unlock() + locked = false + } + }() + if !exists { + server.lu.Lock() + locked = true + + memoryObject, exists = server.o[r.URL.Path] + } + if exists { + if locked { + server.lu.Unlock() + locked = false + } + + if w != nil { + memoryObject.wg.Add(1) + for k := range memoryObject.Headers { + v := memoryObject.Headers.Get(k) + w.Header().Set(k, v) + } + + if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil { + 
slog.With("error", err).Warn("failed to stream response with existing memory object") + } + } + } else { + slog.With("mtime", mtime).Debug("checking fastest upstream") + selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime) + if chunks == nil && mtime != zeroTime { + slog.With("upstreamIdx", selectedIdx, "key", key).Debug("not modified. using local version") + if w != nil { + server.serveFile(w, r, key) + } + return + } + + if err != nil { + slog.With("error", err).Warn("failed to select fastest upstream") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if selectedIdx == -1 || response == nil || chunks == nil { + slog.Debug("no upstream is selected") + http.NotFound(w, r) + return + } + if response.StatusCode == http.StatusNotModified { + slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version") + os.Chtimes(key, zeroTime, time.Now()) + server.serveFile(w, r, key) + return + } + + slog.With( + "upstreamIdx", selectedIdx, + ).Debug("found fastest upstream") + + buffer := &bytes.Buffer{} + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + memoryObject = &StreamObject{ + Headers: response.Header, + Buffer: buffer, + + ctx: ctx, + + wg: &sync.WaitGroup{}, + } + + server.o[r.URL.Path] = memoryObject + server.lu.Unlock() + locked = false + + err = nil + + if w != nil { + memoryObject.wg.Add(1) + + for k := range memoryObject.Headers { + v := memoryObject.Headers.Get(k) + w.Header().Set(k, v) + } + + go memoryObject.StreamTo(w, memoryObject.wg) + } + + for chunk := range chunks { + if chunk.error != nil { + err = chunk.error + slog.With("error", err).Warn("failed to read from upstream") + } + if chunk.buffer == nil { + break + } + n, _ := buffer.Write(chunk.buffer) + memoryObject.Offset += n + } + cancel() + + memoryObject.wg.Wait() + + if response.ContentLength > 0 { + if memoryObject.Offset == int(response.ContentLength) && err != nil { + if err != io.EOF { + slog.With("length", 
memoryObject.Offset, "error", err, "upstreamIdx", selectedIdx).Debug("something happened during download. but response body is read as whole. so error is reset to nil") + } + err = nil + } + } else if err == io.EOF { + err = nil + } + + if err != nil { + slog.With("error", err, "upstreamIdx", selectedIdx).Error("something happened during download. will not cache this response") + } + go func() { + defer func() { + server.lu.Lock() + defer server.lu.Unlock() + + delete(server.o, r.URL.Path) + slog.Debug("memory object released") + }() + + if err == nil { + slog.Debug("preparing to release memory object") + mtime := zeroTime + lastModifiedHeader := response.Header.Get("Last-Modified") + if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil { + slog.With( + "error", err, + "value", lastModifiedHeader, + "url", response.Request.URL, + ).Debug("failed to parse last modified header value. set modified time to now") + } else { + slog.With( + "header", lastModifiedHeader, + "value", lastModified, + "url", response.Request.URL, + ).Debug("found modified time") + mtime = lastModified + } + if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil { + slog.With("error", err).Warn("failed to create local storage path") + } + + if server.Config.Storage.Local.TemporaryFilePattern == "" { + if err := os.WriteFile(key, buffer.Bytes(), 0644); err != nil { + slog.With("error", err).Warn("failed to write file") + os.Remove(key) + } + return + } + + fp, err := os.CreateTemp(server.Storage.Local.Path, server.Storage.Local.TemporaryFilePattern) + if err != nil { + slog.With( + "key", key, + "path", server.Storage.Local.Path, + "pattern", server.Storage.Local.TemporaryFilePattern, + "error", err, + ).Warn("failed to create template file") + return + } + + name := fp.Name() + + if _, err := fp.Write(buffer.Bytes()); err != nil { + fp.Close() + os.Remove(name) + + slog.With("error", err).Warn("failed to write into template file") + } else if err := 
fp.Close(); err != nil { + os.Remove(name) + + slog.With("error", err).Warn("failed to close template file") + } else { + os.Chtimes(name, zeroTime, mtime) + dirname := filepath.Dir(key) + os.MkdirAll(dirname, 0755) + os.Remove(key) + os.Rename(name, key) + } + } + }() + + } +} + +func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) { + returnLock := &sync.Mutex{} + upstreams := len(server.Upstreams) + cancelFuncs := make([]func(), upstreams) + selectedCh := make(chan int, 1) + selectedOnce := &sync.Once{} + wg := &sync.WaitGroup{} + wg.Add(len(server.Upstreams)) + + defer close(selectedCh) + for idx := range server.Upstreams { + idx := idx + ctx, cancel := context.WithCancel(context.Background()) + cancelFuncs[idx] = cancel + + logger := slog.With("upstreamIdx", idx) + + go func() { + defer wg.Done() + response, ch, err := server.tryUpstream(ctx, idx, r, lastModified) + if err == context.Canceled { // others returned + logger.Debug("context canceled") + return + } + + if err != nil { + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + logger.With("error", err).Warn("upstream has error") + } + return + } + if response == nil { + return + } + locked := returnLock.TryLock() + if !locked { + return + } + defer returnLock.Unlock() + + selectedOnce.Do(func() { + resultResponse, resultCh, resultErr = response, ch, err + selectedCh <- idx + + for cancelIdx, cancel := range cancelFuncs { + if cancelIdx == idx { + continue + } + cancel() + } + }) + + logger.Debug("voted") + }() + } + + wg.Wait() + + resultIdx = -1 + select { + case idx := <-selectedCh: + resultIdx = idx + slog.With("upstreamIdx", resultIdx).Debug("upstream selected") + default: + slog.Debug("no valid upstream found") + } + + return +} + +func (server *Server) tryUpstream(ctx context.Context, upstreamIdx int, r *http.Request, lastModified time.Time) (response 
*http.Response, chunks chan Chunk, err error) { + upstream := server.Upstreams[upstreamIdx] + + newpath, matched, err := upstream.GetPath(r.URL.Path) + if err != nil { + return nil, nil, err + } + if !matched { + return nil, nil, nil + } + logger := slog.With("upstreamIdx", upstreamIdx, "server", upstream.Server, "path", newpath) + + logger.With( + "matched", matched, + ).Debug("trying upstream") + + newurl := upstream.Server + newpath + method := r.Method + if lastModified != zeroTime { + method = http.MethodGet + } + request, err := http.NewRequestWithContext(ctx, method, newurl, nil) + if err != nil { + return nil, nil, err + } + if lastModified != zeroTime { + request.Header.Set("If-Modified-Since", lastModified.Format(time.RFC1123)) + } + + for _, k := range []string{"User-Agent"} { + if _, exists := request.Header[k]; exists { + request.Header.Set(k, r.Header.Get(k)) + } + } + response, err = http.DefaultClient.Do(request) + if err != nil { + return nil, nil, err + } + if response.StatusCode == http.StatusNotModified { + return response, nil, nil + } + if response.StatusCode >= 400 && response.StatusCode < 500 { + return nil, nil, nil + } + if response.StatusCode < 200 || response.StatusCode >= 500 { + logger.With( + "url", newurl, + "status", response.StatusCode, + ).Warn("unexpected status") + return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response) + } + + var currentOffset int64 + + ch := make(chan Chunk, 1024) + + buffer := make([]byte, server.Misc.FirstChunkBytes) + n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) + + if err != nil { + if n == 0 { + return response, nil, err + } + } + ch <- Chunk{buffer: buffer[:n]} + + go func() { + defer close(ch) + + for { + buffer := make([]byte, server.Misc.ChunkBytes) + n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) + if n > 0 { + ch <- Chunk{buffer: buffer[:n]} + currentOffset += int64(n) + } + if response.ContentLength > 0 && 
currentOffset == response.ContentLength && err == io.EOF || err == io.ErrUnexpectedEOF { + return + } + if err != nil { + ch <- Chunk{error: err} + return + } + } + }() + + return response, ch, nil +}