690 lines
16 KiB
Go
Raw Normal View History

2024-12-16 21:11:32 +08:00
package main
import (
"bytes"
"context"
2024-12-18 10:47:18 +08:00
"flag"
2024-12-16 21:11:32 +08:00
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
2024-12-19 00:03:22 +08:00
"slices"
2024-12-18 17:16:55 +08:00
"strings"
2024-12-16 21:11:32 +08:00
"sync"
"time"
2024-12-18 10:47:18 +08:00
"github.com/getsentry/sentry-go"
2024-12-16 21:11:32 +08:00
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
)
// zeroTime is the zero time.Time value. It is used throughout the file as the
// "no local modification time known" sentinel.
var zeroTime = time.Time{}
type UpstreamMatch struct {
Match string `yaml:"match"`
Replace string `yaml:"replace"`
}
type Upstream struct {
Server string `yaml:"server"`
Match UpstreamMatch `yaml:"match"`
}
func (upstream Upstream) GetPath(orig string) (string, bool, error) {
if upstream.Match.Match == "" || upstream.Match.Replace == "" {
return orig, true, nil
}
matcher, err := regexp.Compile(upstream.Match.Match)
if err != nil {
return "", false, err
}
return matcher.ReplaceAllString(orig, upstream.Match.Replace), matcher.MatchString(orig), nil
}
type (
	// LocalStorage caches downloaded objects on the local filesystem under
	// Path. configFromFile turns Path into an absolute path.
	LocalStorage struct {
		Path string `yaml:"path"`
	}

	// Accel configures handing file delivery off to a fronting proxy instead
	// of sending the file from Go (see serveFile).
	Accel struct {
		// EnableByHeader names the request header that turns delegation on.
		// The header's value is used as the prefix of the delegated path.
		EnableByHeader string `yaml:"enable-by-header"`
		// ResponseWithHeaders lists the response headers that are set to the
		// delegated path.
		ResponseWithHeaders []string `yaml:"response-with-headers"`
	}

	// Storage describes where cached objects live. configFromFile defaults
	// Type to "local".
	Storage struct {
		Type  string        `yaml:"type"`
		Local *LocalStorage `yaml:"local"`
		Accel Accel         `yaml:"accel"`
	}
)
2024-12-19 00:03:22 +08:00
type (
	// CachePolicyOnPath overrides the refresh interval for cached files whose
	// local path matches the regular expression Match. RefreshAfter is a Go
	// duration string or one of "always" / "never".
	CachePolicyOnPath struct {
		Match        string `yaml:"match"`
		RefreshAfter string `yaml:"refresh-after"`
	}

	// Cache holds the cache-freshness settings.
	Cache struct {
		// RefreshAfter is how long after its mtime a cached file is
		// revalidated upstream when no policy applies.
		RefreshAfter time.Duration `yaml:"refresh-after"`
		// Policies are checked in order; the first match is used.
		Policies []CachePolicyOnPath `yaml:"policies"`
	}

	// MiscConfig tunes how upstream response bodies are read.
	MiscConfig struct {
		// FirstChunkBytes is read from an upstream before it can be selected.
		FirstChunkBytes uint64 `yaml:"first-chunk-bytes"`
		// ChunkBytes is the size of each later read.
		ChunkBytes uint64 `yaml:"chunk-bytes"`
	}
)
// Config is the whole YAML configuration file. configFromFile pre-populates
// defaults for keys the file leaves out.
type Config struct {
	// Upstreams lists the mirrors to fetch from (YAML key "upstream").
	Upstreams []Upstream `yaml:"upstream"`
	// Storage selects where downloaded objects are cached.
	Storage Storage `yaml:"storage"`
	// Cache controls when cached objects are revalidated.
	Cache Cache `yaml:"cache"`
	// Misc holds download tuning knobs.
	Misc MiscConfig `yaml:"misc"`
}
// StreamObject is an in-flight download kept in memory, so several requests
// for the same path can be served from a single upstream fetch.
type StreamObject struct {
	// Headers are the upstream response headers, copied to every client.
	Headers http.Header
	// Buffer accumulates the response body as it is downloaded.
	Buffer *bytes.Buffer
	// Offset is how many bytes of Buffer have been appended so far.
	Offset int
	// ctx is cancelled by the downloader once the download loop ends.
	ctx context.Context
	// wg counts StreamTo calls that are still writing to clients.
	wg *sync.WaitGroup
}
// StreamTo copies the object's body to w while it is being downloaded.
//
// New data is detected by polling Offset once per millisecond. StreamTo
// returns when a write to w fails. Otherwise it returns after the object's
// context is cancelled and everything in Buffer has been written. A nil w
// discards the data. wg.Done is called on return.
//
// NOTE(review): Offset and Buffer are mutated by the downloading goroutine
// without synchronisation, so these reads race with those writes.
func (so *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error {
	defer wg.Done()
	if w == nil {
		w = io.Discard
	}
	sent := 0
	for so.ctx.Err() == nil {
		ready := so.Offset
		if ready == sent {
			time.Sleep(time.Millisecond)
			continue
		}
		n, err := w.Write(so.Buffer.Bytes()[sent:ready])
		if err != nil {
			return err
		}
		sent += n
	}
	// Pause briefly, then flush whatever is left in the buffer.
	time.Sleep(time.Millisecond)
	logrus.WithFields(logrus.Fields{
		"start": sent,
		"end":   so.Buffer.Len(),
		"n":     so.Buffer.Len() - sent,
	}).Trace("remain bytes")
	_, err := w.Write(so.Buffer.Bytes()[sent:])
	return err
}
// Server handles mirror requests. It embeds the loaded Config and tracks the
// downloads that are currently in progress.
type Server struct {
	Config
	// lu is the lock taken around updates of o.
	lu *sync.Mutex
	// o maps a request URL path to its in-flight download.
	o map[string]*StreamObject
}
// Chunk is one message from an upstream body reader to its consumer. It
// carries either a slice of body bytes or, with a nil buffer, a read error
// that ends the stream.
type Chunk struct {
	buffer []byte
	error  error
}
func configFromFile(path string) (*Config, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
config := &Config{
Upstreams: []Upstream{
{
Server: "https://mirrors.ustc.edu.cn",
},
},
Storage: Storage{
Type: "local",
Local: &LocalStorage{
Path: "./data",
},
2024-12-18 17:16:55 +08:00
Accel: Accel{
ResponseWithHeaders: []string{"X-Sendfile", "X-Accel-Redirect"},
},
2024-12-16 21:11:32 +08:00
},
Misc: MiscConfig{
FirstChunkBytes: 1024 * 1024 * 50,
ChunkBytes: 1024 * 1024,
},
Cache: Cache{
2024-12-19 00:03:22 +08:00
RefreshAfter: time.Hour,
2024-12-16 21:11:32 +08:00
},
}
if err := yaml.NewDecoder(file).Decode(&config); err != nil {
return nil, err
}
2024-12-18 17:08:47 +08:00
if config.Storage.Local != nil {
localPath, err := filepath.Abs(config.Storage.Local.Path)
if err != nil {
return nil, err
}
config.Storage.Local.Path = localPath
}
2024-12-16 21:11:32 +08:00
return config, nil
}
2024-12-18 17:16:55 +08:00
func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) {
if location := r.Header.Get(server.Storage.Accel.EnableByHeader); server.Storage.Accel.EnableByHeader != "" && location != "" {
relPath, err := filepath.Rel(server.Storage.Local.Path, path)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
2024-12-19 01:27:14 +08:00
return
2024-12-18 17:16:55 +08:00
}
accelPath := filepath.Join(location, relPath)
for _, headerKey := range server.Storage.Accel.ResponseWithHeaders {
w.Header().Set(headerKey, accelPath)
}
return
}
http.ServeFile(w, r, path)
}
2024-12-16 21:11:32 +08:00
func (server *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
fullpath := filepath.Join(server.Storage.Local.Path, r.URL.Path)
fullpath, err := filepath.Abs(fullpath)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
2024-12-18 17:08:47 +08:00
if !strings.HasPrefix(fullpath, server.Storage.Local.Path) {
http.Error(w, "crossing local directory boundary", http.StatusBadRequest)
return
}
2024-12-16 21:11:32 +08:00
ranged := r.Header.Get("Range") != ""
localStatus, mtime, err := server.checkLocal(w, r, fullpath)
if os.IsPermission(err) {
http.Error(w, err.Error(), http.StatusForbidden)
} else if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else if localStatus != localNotExists {
if localStatus == localExistsButNeedHead {
if ranged {
server.streamOnline(nil, r, mtime, fullpath)
2024-12-18 17:16:55 +08:00
server.serveFile(w, r, fullpath)
2024-12-16 21:11:32 +08:00
} else {
server.streamOnline(w, r, mtime, fullpath)
}
} else {
2024-12-18 17:16:55 +08:00
server.serveFile(w, r, fullpath)
2024-12-16 21:11:32 +08:00
}
} else {
if ranged {
server.streamOnline(nil, r, mtime, fullpath)
2024-12-18 17:16:55 +08:00
server.serveFile(w, r, fullpath)
2024-12-16 21:11:32 +08:00
} else {
server.streamOnline(w, r, mtime, fullpath)
}
}
}
// localStatus is the state of the on-disk cache for a requested path, as
// reported by checkLocal.
type localStatus int

const (
	// localNotExists: there is no cached copy.
	localNotExists localStatus = 0
	// localExists: the cached copy is served as is.
	localExists localStatus = 1
	// localExistsButNeedHead: a cached copy exists but is revalidated upstream.
	localExistsButNeedHead localStatus = 2
)
func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
if stat, err := os.Stat(key); err == nil {
2024-12-19 00:03:22 +08:00
refreshAfter := server.Cache.RefreshAfter
refresh := ""
for _, policy := range server.Cache.Policies {
if match, err := regexp.MatchString(policy.Match, key); err != nil {
return 0, zeroTime, err
} else if match {
if dur, err := time.ParseDuration(policy.RefreshAfter); err != nil {
if slices.Contains([]string{"always", "never"}, policy.RefreshAfter) {
refresh = policy.RefreshAfter
} else {
return 0, zeroTime, err
}
} else {
refreshAfter = dur
}
break
}
}
if mtime := stat.ModTime(); mtime.Add(refreshAfter).Before(time.Now()) || refresh == "always" && refresh != "never" {
2024-12-16 21:11:32 +08:00
return localExistsButNeedHead, mtime.In(time.UTC), nil
}
2024-12-18 14:17:59 +08:00
return localExists, mtime.In(time.UTC), nil
2024-12-16 21:11:32 +08:00
} else if os.IsPermission(err) {
http.Error(w, err.Error(), http.StatusForbidden)
} else if !os.IsNotExist(err) {
return localNotExists, zeroTime, err
}
return localNotExists, zeroTime, nil
}
func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) {
memoryObject, exists := server.o[r.URL.Path]
locked := false
defer func() {
if locked {
server.lu.Unlock()
locked = false
}
}()
if !exists {
server.lu.Lock()
locked = true
}
memoryObject, exists = server.o[r.URL.Path]
if exists {
if locked {
server.lu.Unlock()
locked = false
}
if w != nil {
memoryObject.wg.Add(1)
for k := range memoryObject.Headers {
v := memoryObject.Headers.Get(k)
w.Header().Set(k, v)
}
if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil {
2024-12-18 14:17:59 +08:00
logrus.WithError(err).Warn("failed to stream response with existing memory object")
2024-12-16 21:11:32 +08:00
}
}
} else {
logrus.WithField("mtime", mtime).Trace("checking fastest upstream")
selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime)
logrus.WithFields(logrus.Fields{
"upstreamIdx": selectedIdx,
}).Trace("fastest upstream")
if chunks == nil && mtime != zeroTime {
logrus.WithFields(logrus.Fields{"upstreamIdx": selectedIdx, "key": key}).Trace("not modified. using local version")
if w != nil {
2024-12-18 17:16:55 +08:00
server.serveFile(w, r, key)
2024-12-16 21:11:32 +08:00
}
return
}
if err != nil {
logrus.WithError(err).Warn("failed to select fastest upstream")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if selectedIdx == -1 || response == nil || chunks == nil {
logrus.Trace("no upstream is selected")
http.NotFound(w, r)
return
}
if response.StatusCode == http.StatusNotModified {
logrus.WithField("upstreamIdx", selectedIdx).Trace("not modified. using local version")
2024-12-18 10:47:18 +08:00
os.Chtimes(key, zeroTime, time.Now())
2024-12-18 17:16:55 +08:00
server.serveFile(w, r, key)
2024-12-16 21:11:32 +08:00
return
}
buffer := &bytes.Buffer{}
ctx, cancel := context.WithCancel(r.Context())
defer cancel()
memoryObject = &StreamObject{
Headers: response.Header,
Buffer: buffer,
ctx: ctx,
wg: &sync.WaitGroup{},
}
server.o[r.URL.Path] = memoryObject
server.lu.Unlock()
locked = false
err = nil
2024-12-16 21:11:32 +08:00
if w != nil {
memoryObject.wg.Add(1)
for k := range memoryObject.Headers {
v := memoryObject.Headers.Get(k)
w.Header().Set(k, v)
}
go memoryObject.StreamTo(w, memoryObject.wg)
}
for chunk := range chunks {
if chunk.error != nil {
2024-12-18 14:17:59 +08:00
err = chunk.error
2024-12-16 21:11:32 +08:00
logrus.WithError(err).Warn("failed to read from upstream")
}
if chunk.buffer == nil {
break
}
n, _ := buffer.Write(chunk.buffer)
memoryObject.Offset += n
}
cancel()
memoryObject.wg.Wait()
2024-12-18 14:17:59 +08:00
if err != nil {
logrus.WithError(err).WithField("upstreamIdx", selectedIdx).Error("something happened during download. will not cache this response")
}
2024-12-16 21:11:32 +08:00
go func() {
if err == nil {
logrus.Trace("preparing to release memory object")
mtime := zeroTime
lastModifiedHeader := response.Header.Get("Last-Modified")
if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"value": lastModifiedHeader,
"url": response.Request.URL,
}).Trace("failed to parse last modified header value")
} else {
mtime = lastModified
}
if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil {
logrus.Warn(err)
}
fp, err := os.CreateTemp(server.Storage.Local.Path, "temp.*")
name := fp.Name()
if err != nil {
logrus.WithFields(logrus.Fields{
"key": key,
"path": server.Storage.Local.Path,
"pattern": "temp.*",
}).WithError(err).Warn("ftime.Time{}ailed to create template file")
} else if _, err := fp.Write(buffer.Bytes()); err != nil {
fp.Close()
os.Remove(name)
logrus.WithError(err).Warn("failed to write into template file")
} else if err := fp.Close(); err != nil {
os.Remove(name)
logrus.WithError(err).Warn("failed to close template file")
} else {
os.Chtimes(name, zeroTime, mtime)
dirname := filepath.Dir(key)
os.MkdirAll(dirname, 0755)
os.Remove(key)
os.Rename(name, key)
}
2024-12-16 21:11:32 +08:00
}
server.lu.Lock()
defer server.lu.Unlock()
delete(server.o, r.URL.Path)
logrus.Trace("memory object released")
}()
2024-12-16 21:11:32 +08:00
}
}
func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) {
returnLock := &sync.Mutex{}
upstreams := len(server.Upstreams)
cancelFuncs := make([]func(), upstreams)
selectedCh := make(chan int, 1)
selectedOnce := &sync.Once{}
wg := &sync.WaitGroup{}
wg.Add(len(server.Upstreams))
logrus.WithField("size", len(server.Upstreams)).Trace("wg")
defer close(selectedCh)
for idx := range server.Upstreams {
idx := idx
ctx, cancel := context.WithCancel(context.Background())
cancelFuncs[idx] = cancel
logger := logrus.WithField("upstreamIdx", idx)
go func() {
defer wg.Done()
response, ch, err := server.tryUpstream(ctx, idx, r, lastModified)
if err == context.Canceled { // others returned
logger.Trace("context canceled")
return
}
if err != nil {
2024-12-18 14:17:59 +08:00
if err != context.Canceled && err != context.DeadlineExceeded {
logger.WithError(err).Warn("upstream has error")
}
2024-12-16 21:11:32 +08:00
return
}
if response == nil {
return
}
locked := returnLock.TryLock()
if !locked {
return
}
defer returnLock.Unlock()
selectedOnce.Do(func() {
resultResponse, resultCh, resultErr = response, ch, err
selectedCh <- idx
2024-12-18 14:17:59 +08:00
for cancelIdx, cancel := range cancelFuncs {
if cancelIdx == idx {
logrus.WithField("upstreamIdx", cancelIdx).Trace("selected thus not canceled")
2024-12-16 21:11:32 +08:00
continue
}
2024-12-18 14:17:59 +08:00
logrus.WithField("upstreamIdx", cancelIdx).Trace("not selected and thus canceled")
2024-12-16 21:11:32 +08:00
cancel()
}
logger.Trace("upstream is selected")
})
logger.Trace("voted")
return
}()
}
wg.Wait()
logrus.Trace("all upstream tried")
resultIdx = -1
select {
case idx := <-selectedCh:
resultIdx = idx
default:
}
return
}
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) {
upstream := server.Upstreams[upstreamIdx]
logger := logrus.WithField("upstreamIdx", upstreamIdx)
newpath, matched, err := upstream.GetPath(r.URL.Path)
logger.WithFields(logrus.Fields{
"path": newpath,
"matched": matched,
}).Trace("trying upstream")
if err != nil {
return nil, nil, err
}
if !matched {
return nil, nil, nil
}
newurl := upstream.Server + newpath
method := r.Method
if lastModified != zeroTime {
method = http.MethodGet
}
request, err := http.NewRequestWithContext(ctx, method, newurl, nil)
if err != nil {
return nil, nil, err
}
if lastModified != zeroTime {
logger.WithFields(logrus.Fields{
"mtime": lastModified.Format(time.RFC1123),
}).Trace("check modified since")
request.Header.Set("If-Modified-Since", lastModified.Format(time.RFC1123))
}
for _, k := range []string{"User-Agent"} {
if _, exists := request.Header[k]; exists {
request.Header.Set(k, r.Header.Get(k))
}
}
response, err = http.DefaultClient.Do(request)
if err != nil {
return nil, nil, err
}
logrus.WithField("status", response.StatusCode).Trace("responded")
if response.StatusCode == http.StatusNotModified {
return response, nil, nil
}
if response.StatusCode >= 400 && response.StatusCode < 500 {
return nil, nil, nil
}
if response.StatusCode < 200 || response.StatusCode >= 500 {
logrus.WithFields(logrus.Fields{
"url": newurl,
"status": response.StatusCode,
}).Warn("unexpected status")
return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response)
}
var currentOffset int64
2024-12-16 21:11:32 +08:00
ch := make(chan Chunk, 1024)
buffer := make([]byte, server.Misc.FirstChunkBytes)
start := time.Now()
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
if err != nil {
if n == 0 {
return response, nil, err
}
}
logger.WithField("duration", time.Now().Sub(start)).Tracef("first %v bytes", n)
ch <- Chunk{buffer: buffer[:n]}
go func() {
defer close(ch)
for {
buffer := make([]byte, server.Misc.ChunkBytes)
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
if n > 0 {
ch <- Chunk{buffer: buffer[:n]}
currentOffset += int64(n)
2024-12-16 21:11:32 +08:00
}
if response.ContentLength > 0 && currentOffset == response.ContentLength && err == io.EOF || err == io.ErrUnexpectedEOF {
2024-12-16 21:11:32 +08:00
logger.Trace("done")
return
}
if err != nil {
ch <- Chunk{error: err}
return
}
}
}()
return response, ch, nil
}
2024-12-18 10:47:18 +08:00
// Built-in defaults for the command-line settings. init overrides them from
// the environment and registers them as flags.

// configFilePath is the YAML configuration file to load.
var configFilePath = "config.yaml"

// logLevel is the logrus level name.
var logLevel = "info"

// sentrydsn is the Sentry DSN; Sentry is not set up when it is empty.
var sentrydsn = ""
// init applies the CONFIG_PATH, LOG_LEVEL and SENTRY_DSN environment
// variables on top of the built-in defaults. It then registers the matching
// flags with those values as their defaults, so an explicit flag wins over the
// environment once main calls flag.Parse.
func init() {
	overrides := []struct {
		env    string
		target *string
	}{
		{"CONFIG_PATH", &configFilePath},
		{"LOG_LEVEL", &logLevel},
		{"SENTRY_DSN", &sentrydsn},
	}
	for _, o := range overrides {
		if value, ok := os.LookupEnv(o.env); ok {
			*o.target = value
		}
	}
	flag.StringVar(&configFilePath, "config", configFilePath, "path to config file")
	flag.StringVar(&logLevel, "log-level", logLevel, "log level. (trace, debug, info, warn, error)")
	flag.StringVar(&sentrydsn, "sentry", sentrydsn, "sentry dsn to report errors")
}
2024-12-16 21:11:32 +08:00
func main() {
2024-12-18 10:47:18 +08:00
flag.Parse()
if lvl, err := logrus.ParseLevel(logLevel); err != nil {
logrus.WithError(err).Panic("failed to parse log level")
} else {
logrus.SetLevel(lvl)
}
if sentrydsn != "" {
if err := sentry.Init(sentry.ClientOptions{
Dsn: sentrydsn,
}); err != nil {
logrus.WithField("dsn", sentrydsn).WithError(err).Panic("failed to setup sentry")
}
defer sentry.Flush(time.Second * 3)
}
2024-12-16 21:11:32 +08:00
logrus.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02T15:04:05.000",
})
2024-12-18 10:47:18 +08:00
config, err := configFromFile(configFilePath)
2024-12-16 21:11:32 +08:00
if err != nil {
panic(err)
}
ch := make(chan any, 10)
for idx := 0; idx < 10; idx += 1 {
ch <- idx
}
server := Server{
Config: *config,
lu: &sync.Mutex{},
o: make(map[string]*StreamObject),
}
2024-12-18 14:30:14 +08:00
http.HandleFunc("GET /{path...}", server.handleRequest)
2024-12-16 21:11:32 +08:00
logrus.WithFields(logrus.Fields{"addr": ":8881"}).Info("serving app")
http.ListenAndServe(":8881", nil)
}