package main import ( "bytes" "context" "flag" "fmt" "io" "net/http" "os" "path/filepath" "regexp" "sync" "time" _ "net/http/pprof" "github.com/getsentry/sentry-go" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) var zeroTime time.Time type UpstreamMatch struct { Match string `yaml:"match"` Replace string `yaml:"replace"` } type Upstream struct { Server string `yaml:"server"` Match UpstreamMatch `yaml:"match"` } func (upstream Upstream) GetPath(orig string) (string, bool, error) { if upstream.Match.Match == "" || upstream.Match.Replace == "" { return orig, true, nil } matcher, err := regexp.Compile(upstream.Match.Match) if err != nil { return "", false, err } return matcher.ReplaceAllString(orig, upstream.Match.Replace), matcher.MatchString(orig), nil } type LocalStorage struct { Path string `yaml:"path"` } type Storage struct { Type string `yaml:"type"` Local *LocalStorage `yaml:"local"` } type Cache struct { Timeout time.Duration `yaml:"timeout"` } type MiscConfig struct { FirstChunkBytes uint64 `yaml:"first-chunk-bytes"` ChunkBytes uint64 `yaml:"chunk-bytes"` } type Config struct { Upstreams []Upstream `yaml:"upstream"` Storage Storage `yaml:"storage"` Cache Cache `yaml:"cache"` Misc MiscConfig `yaml:"misc"` } type StreamObject struct { Headers http.Header Buffer *bytes.Buffer Offset int ctx context.Context wg *sync.WaitGroup } func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error { defer wg.Done() offset := 0 if w == nil { w = io.Discard } OUTER: for { select { case <-memoryObject.ctx.Done(): break OUTER default: } newOffset := memoryObject.Offset if newOffset == offset { time.Sleep(time.Millisecond) continue } bytes := memoryObject.Buffer.Bytes()[offset:newOffset] written, err := w.Write(bytes) if err != nil { return err } offset += written } time.Sleep(time.Millisecond) logrus.WithFields(logrus.Fields{ "start": offset, "end": memoryObject.Buffer.Len(), "n": memoryObject.Buffer.Len() - offset, }).Trace("remain bytes") _, err := w.Write(memoryObject.Buffer.Bytes()[offset:]) return err } type Server struct { Config lu *sync.Mutex o map[string]*StreamObject } type Chunk struct { buffer []byte error error } func configFromFile(path string) (*Config, error) { file, err := os.Open(path) if err != nil { return nil, err } defer file.Close() config := &Config{ Upstreams: []Upstream{ { Server: "https://mirrors.ustc.edu.cn", }, }, Storage: Storage{ Type: "local", Local: &LocalStorage{ Path: "./data", }, }, Misc: MiscConfig{ FirstChunkBytes: 1024 * 1024 * 50, ChunkBytes: 1024 * 1024, }, Cache: Cache{ Timeout: time.Hour, }, } if err := yaml.NewDecoder(file).Decode(&config); err != nil { return nil, err } return config, nil } func (server *Server) handleRequest(w http.ResponseWriter, r *http.Request) { fullpath := filepath.Join(server.Storage.Local.Path, r.URL.Path) fullpath, err := filepath.Abs(fullpath) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } ranged := r.Header.Get("Range") != "" localStatus, mtime, err := server.checkLocal(w, r, fullpath) if os.IsPermission(err) { http.Error(w, err.Error(), http.StatusForbidden) } else if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } else if localStatus != localNotExists { if localStatus == localExistsButNeedHead { if ranged { server.streamOnline(nil, r, mtime, fullpath) http.ServeFile(w, r, fullpath) } else { server.streamOnline(w, r, mtime, fullpath) } } else { http.ServeFile(w, r, fullpath) } } else { if ranged { server.streamOnline(nil, r, mtime, fullpath) http.ServeFile(w, r, fullpath) } else { server.streamOnline(w, r, mtime, fullpath) } } } type localStatus int const ( localNotExists localStatus = iota localExists localExistsButNeedHead ) func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) { if stat, err := os.Stat(key); err == nil { if mtime := stat.ModTime(); mtime.Add(server.Cache.Timeout).Before(time.Now()) { return localExistsButNeedHead, mtime.In(time.UTC), nil } return localExists, mtime.In(time.UTC), nil } else if os.IsPermission(err) { http.Error(w, err.Error(), http.StatusForbidden) } else if !os.IsNotExist(err) { return localNotExists, zeroTime, err } return localNotExists, zeroTime, nil } func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) { memoryObject, exists := server.o[r.URL.Path] locked := false defer func() { if locked { server.lu.Unlock() locked = false } }() if !exists { server.lu.Lock() locked = true } memoryObject, exists = server.o[r.URL.Path] if exists { if locked { server.lu.Unlock() locked = false } if w != nil { memoryObject.wg.Add(1) for k := range memoryObject.Headers { v := memoryObject.Headers.Get(k) w.Header().Set(k, v) } if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil { logrus.WithError(err).Warn("failed to stream response with existing memory object") } } } else { logrus.WithField("mtime", mtime).Trace("checking fastest upstream") selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime) logrus.WithFields(logrus.Fields{ "upstreamIdx": selectedIdx, }).Trace("fastest upstream") if chunks == nil && mtime != zeroTime { logrus.WithFields(logrus.Fields{"upstreamIdx": selectedIdx, "key": key}).Trace("not modified. using local version") if w != nil { http.ServeFile(w, r, key) } return } if err != nil { logrus.WithError(err).Warn("failed to select fastest upstream") http.Error(w, err.Error(), http.StatusInternalServerError) return } if selectedIdx == -1 || response == nil || chunks == nil { logrus.Trace("no upstream is selected") http.NotFound(w, r) return } if response.StatusCode == http.StatusNotModified { logrus.WithField("upstreamIdx", selectedIdx).Trace("not modified. using local version") os.Chtimes(key, zeroTime, time.Now()) http.ServeFile(w, r, key) return } buffer := &bytes.Buffer{} ctx, cancel := context.WithCancel(r.Context()) defer cancel() memoryObject = &StreamObject{ Headers: response.Header, Buffer: buffer, ctx: ctx, wg: &sync.WaitGroup{}, } server.o[r.URL.Path] = memoryObject server.lu.Unlock() locked = false if w != nil { memoryObject.wg.Add(1) for k := range memoryObject.Headers { v := memoryObject.Headers.Get(k) w.Header().Set(k, v) } go memoryObject.StreamTo(w, memoryObject.wg) } err = nil for chunk := range chunks { if chunk.error != nil { err = chunk.error logrus.WithError(err).Warn("failed to read from upstream") } if chunk.buffer == nil { break } n, _ := buffer.Write(chunk.buffer) memoryObject.Offset += n } cancel() memoryObject.wg.Wait() if err != nil { logrus.WithError(err).WithField("upstreamIdx", selectedIdx).Error("something happened during download. will not cache this response") return } go func() { logrus.Trace("preparing to release memory object") mtime := zeroTime lastModifiedHeader := response.Header.Get("Last-Modified") if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil { logrus.WithError(err).WithFields(logrus.Fields{ "value": lastModifiedHeader, "url": response.Request.URL, }).Trace("failed to parse last modified header value") } else { mtime = lastModified } if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil { logrus.Warn(err) } fp, err := os.CreateTemp(server.Storage.Local.Path, "temp.*") name := fp.Name() if err != nil { logrus.WithFields(logrus.Fields{ "key": key, "path": server.Storage.Local.Path, "pattern": "temp.*", }).WithError(err).Warn("ftime.Time{}ailed to create template file") } else if _, err := fp.Write(buffer.Bytes()); err != nil { fp.Close() os.Remove(name) logrus.WithError(err).Warn("failed to write into template file") } else if err := fp.Close(); err != nil { os.Remove(name) logrus.WithError(err).Warn("failed to close template file") } else { os.Chtimes(name, zeroTime, mtime) dirname := filepath.Dir(key) os.MkdirAll(dirname, 0755) os.Remove(key) os.Rename(name, key) } server.lu.Lock() defer server.lu.Unlock() delete(server.o, r.URL.Path) logrus.Trace("memory object released") }() } } func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) { returnLock := &sync.Mutex{} upstreams := len(server.Upstreams) cancelFuncs := make([]func(), upstreams) selectedCh := make(chan int, 1) selectedOnce := &sync.Once{} wg := &sync.WaitGroup{} wg.Add(len(server.Upstreams)) logrus.WithField("size", len(server.Upstreams)).Trace("wg") defer close(selectedCh) for idx := range server.Upstreams { idx := idx ctx, cancel := context.WithCancel(context.Background()) cancelFuncs[idx] = cancel logger := logrus.WithField("upstreamIdx", idx) go func() { defer wg.Done() response, ch, err := server.tryUpstream(ctx, idx, r, lastModified) if err == context.Canceled { // others returned logger.Trace("context canceled") return } if err != nil { if err != context.Canceled && err != context.DeadlineExceeded { logger.WithError(err).Warn("upstream has error") } return } if response == nil { return } locked := returnLock.TryLock() if !locked { return } defer returnLock.Unlock() selectedOnce.Do(func() { resultResponse, resultCh, resultErr = response, ch, err selectedCh <- idx for cancelIdx, cancel := range cancelFuncs { if cancelIdx == idx { logrus.WithField("upstreamIdx", cancelIdx).Trace("selected thus not canceled") continue } logrus.WithField("upstreamIdx", cancelIdx).Trace("not selected and thus canceled") cancel() } logger.Trace("upstream is selected") }) logger.Trace("voted") return }() } wg.Wait() logrus.Trace("all upstream tried") resultIdx = -1 select { case idx := <-selectedCh: resultIdx = idx default: } return } func (server *Server) tryUpstream(ctx context.Context, upstreamIdx int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) { upstream := server.Upstreams[upstreamIdx] logger := logrus.WithField("upstreamIdx", upstreamIdx) newpath, matched, err := upstream.GetPath(r.URL.Path) logger.WithFields(logrus.Fields{ "path": newpath, "matched": matched, }).Trace("trying upstream") if err != nil { return nil, nil, err } if !matched { return nil, nil, nil } newurl := upstream.Server + newpath method := r.Method if lastModified != zeroTime { method = http.MethodGet } request, err := http.NewRequestWithContext(ctx, method, newurl, nil) if err != nil { return nil, nil, err } if lastModified != zeroTime { logger.WithFields(logrus.Fields{ "mtime": lastModified.Format(time.RFC1123), }).Trace("check modified since") request.Header.Set("If-Modified-Since", lastModified.Format(time.RFC1123)) } for _, k := range []string{"User-Agent"} { if _, exists := request.Header[k]; exists { request.Header.Set(k, r.Header.Get(k)) } } response, err = http.DefaultClient.Do(request) if err != nil { return nil, nil, err } logrus.WithField("status", response.StatusCode).Trace("responded") if response.StatusCode == http.StatusNotModified { return response, nil, nil } if response.StatusCode >= 400 && response.StatusCode < 500 { return nil, nil, nil } if response.StatusCode < 200 || response.StatusCode >= 500 { logrus.WithFields(logrus.Fields{ "url": newurl, "status": response.StatusCode, }).Warn("unexpected status") return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response) } ch := make(chan Chunk, 1024) buffer := make([]byte, server.Misc.FirstChunkBytes) start := time.Now() n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) if err != nil { if n == 0 { return response, nil, err } } logger.WithField("duration", time.Now().Sub(start)).Tracef("first %v bytes", n) ch <- Chunk{buffer: buffer[:n]} go func() { defer close(ch) for { buffer := make([]byte, server.Misc.ChunkBytes) n, err := io.ReadAtLeast(response.Body, buffer, len(buffer)) if n > 0 { ch <- Chunk{buffer: buffer[:n]} } if err == io.EOF { logger.Trace("done") return } if err != nil { ch <- Chunk{error: err} logger.WithError(err).Trace("failed") return } } }() return response, ch, nil } var ( configFilePath = "config.yaml" logLevel = "info" sentrydsn = "" ) func init() { if v, ok := os.LookupEnv("CONFIG_PATH"); ok { configFilePath = v } if v, ok := os.LookupEnv("LOG_LEVEL"); ok { logLevel = v } if v, ok := os.LookupEnv("SENTRY_DSN"); ok { sentrydsn = v } flag.StringVar(&configFilePath, "config", configFilePath, "path to config file") flag.StringVar(&logLevel, "log-level", logLevel, "log level. (trace, debug, info, warn, error)") flag.StringVar(&sentrydsn, "sentry", sentrydsn, "sentry dsn to report errors") } func main() { flag.Parse() if lvl, err := logrus.ParseLevel(logLevel); err != nil { logrus.WithError(err).Panic("failed to parse log level") } else { logrus.SetLevel(lvl) } if sentrydsn != "" { if err := sentry.Init(sentry.ClientOptions{ Dsn: sentrydsn, }); err != nil { logrus.WithField("dsn", sentrydsn).WithError(err).Panic("failed to setup sentry") } defer sentry.Flush(time.Second * 3) } logrus.SetFormatter(&logrus.TextFormatter{ FullTimestamp: true, TimestampFormat: "2006-01-02T15:04:05.000", }) config, err := configFromFile(configFilePath) if err != nil { panic(err) } ch := make(chan any, 10) for idx := 0; idx < 10; idx += 1 { ch <- idx } server := Server{ Config: *config, lu: &sync.Mutex{}, o: make(map[string]*StreamObject), } http.HandleFunc("GET /{path...}", server.handleRequest) logrus.WithFields(logrus.Fields{"addr": ":8881"}).Info("serving app") http.ListenAndServe(":8881", nil) }