first commit
This commit is contained in:
575
cmd/proxy/main.go
Normal file
575
cmd/proxy/main.go
Normal file
@ -0,0 +1,575 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var zeroTime time.Time
|
||||
|
||||
type UpstreamMatch struct {
|
||||
Match string `yaml:"match"`
|
||||
Replace string `yaml:"replace"`
|
||||
}
|
||||
|
||||
type Upstream struct {
|
||||
Server string `yaml:"server"`
|
||||
Match UpstreamMatch `yaml:"match"`
|
||||
}
|
||||
|
||||
func (upstream Upstream) GetPath(orig string) (string, bool, error) {
|
||||
if upstream.Match.Match == "" || upstream.Match.Replace == "" {
|
||||
return orig, true, nil
|
||||
}
|
||||
matcher, err := regexp.Compile(upstream.Match.Match)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return matcher.ReplaceAllString(orig, upstream.Match.Replace), matcher.MatchString(orig), nil
|
||||
}
|
||||
|
||||
type LocalStorage struct {
|
||||
Path string `yaml:"path"`
|
||||
}
|
||||
|
||||
type Storage struct {
|
||||
Type string `yaml:"type"`
|
||||
Local *LocalStorage `yaml:"local"`
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
}
|
||||
|
||||
type MiscConfig struct {
|
||||
FirstChunkBytes uint64 `yaml:"first-chunk-bytes"`
|
||||
ChunkBytes uint64 `yaml:"chunk-bytes"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Upstreams []Upstream `yaml:"upstream"`
|
||||
Storage Storage `yaml:"storage"`
|
||||
Cache Cache `yaml:"cache"`
|
||||
Misc MiscConfig `yaml:"misc"`
|
||||
}
|
||||
|
||||
type StreamObject struct {
|
||||
Headers http.Header
|
||||
Buffer *bytes.Buffer
|
||||
Offset int
|
||||
|
||||
ctx context.Context
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error {
|
||||
defer wg.Done()
|
||||
offset := 0
|
||||
if w == nil {
|
||||
w = io.Discard
|
||||
}
|
||||
OUTER:
|
||||
for {
|
||||
select {
|
||||
case <-memoryObject.ctx.Done():
|
||||
logrus.Info("ctx done")
|
||||
break OUTER
|
||||
default:
|
||||
}
|
||||
|
||||
newOffset := memoryObject.Offset
|
||||
if newOffset == offset {
|
||||
time.Sleep(time.Millisecond)
|
||||
continue
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{"start": offset, "end": newOffset}).Info("writing")
|
||||
bytes := memoryObject.Buffer.Bytes()[offset:newOffset]
|
||||
written, err := w.Write(bytes)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Info("write failed")
|
||||
return err
|
||||
}
|
||||
logrus.WithFields(logrus.Fields{"n": written}).Info("written")
|
||||
|
||||
offset += written
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"start": offset,
|
||||
"end": memoryObject.Buffer.Len(),
|
||||
"n": memoryObject.Buffer.Len() - offset,
|
||||
}).Trace("remain bytes")
|
||||
|
||||
_, err := w.Write(memoryObject.Buffer.Bytes()[offset:])
|
||||
return err
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
Config
|
||||
|
||||
lu *sync.Mutex
|
||||
o map[string]*StreamObject
|
||||
}
|
||||
|
||||
type Chunk struct {
|
||||
buffer []byte
|
||||
error error
|
||||
}
|
||||
|
||||
func configFromFile(path string) (*Config, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
config := &Config{
|
||||
Upstreams: []Upstream{
|
||||
{
|
||||
Server: "https://mirrors.ustc.edu.cn",
|
||||
},
|
||||
},
|
||||
Storage: Storage{
|
||||
Type: "local",
|
||||
Local: &LocalStorage{
|
||||
Path: "./data",
|
||||
},
|
||||
},
|
||||
Misc: MiscConfig{
|
||||
FirstChunkBytes: 1024 * 1024 * 50,
|
||||
ChunkBytes: 1024 * 1024,
|
||||
},
|
||||
Cache: Cache{
|
||||
Timeout: time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
if err := yaml.NewDecoder(file).Decode(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (server *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
fullpath := filepath.Join(server.Storage.Local.Path, r.URL.Path)
|
||||
fullpath, err := filepath.Abs(fullpath)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ranged := r.Header.Get("Range") != ""
|
||||
|
||||
localStatus, mtime, err := server.checkLocal(w, r, fullpath)
|
||||
if os.IsPermission(err) {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
} else if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
} else if localStatus != localNotExists {
|
||||
if localStatus == localExistsButNeedHead {
|
||||
if ranged {
|
||||
server.streamOnline(nil, r, mtime, fullpath)
|
||||
http.ServeFile(w, r, fullpath)
|
||||
} else {
|
||||
server.streamOnline(w, r, mtime, fullpath)
|
||||
}
|
||||
} else {
|
||||
http.ServeFile(w, r, fullpath)
|
||||
}
|
||||
} else {
|
||||
if ranged {
|
||||
server.streamOnline(nil, r, mtime, fullpath)
|
||||
http.ServeFile(w, r, fullpath)
|
||||
} else {
|
||||
server.streamOnline(w, r, mtime, fullpath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type localStatus int
|
||||
|
||||
const (
|
||||
localNotExists localStatus = iota
|
||||
localExists
|
||||
localExistsButNeedHead
|
||||
)
|
||||
|
||||
func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
|
||||
if stat, err := os.Stat(key); err == nil {
|
||||
logrus.Println(stat.ModTime(), stat.ModTime().Add(server.Cache.Timeout), time.Now())
|
||||
if mtime := stat.ModTime(); mtime.Add(server.Cache.Timeout).Before(time.Now()) {
|
||||
return localExistsButNeedHead, mtime.In(time.UTC), nil
|
||||
}
|
||||
return localExists, zeroTime, nil
|
||||
} else if os.IsPermission(err) {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
} else if !os.IsNotExist(err) {
|
||||
return localNotExists, zeroTime, err
|
||||
}
|
||||
|
||||
return localNotExists, zeroTime, nil
|
||||
}
|
||||
|
||||
func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) {
|
||||
memoryObject, exists := server.o[r.URL.Path]
|
||||
locked := false
|
||||
defer func() {
|
||||
if locked {
|
||||
server.lu.Unlock()
|
||||
locked = false
|
||||
}
|
||||
}()
|
||||
if !exists {
|
||||
server.lu.Lock()
|
||||
locked = true
|
||||
}
|
||||
memoryObject, exists = server.o[r.URL.Path]
|
||||
if exists {
|
||||
if locked {
|
||||
server.lu.Unlock()
|
||||
locked = false
|
||||
}
|
||||
|
||||
if w != nil {
|
||||
memoryObject.wg.Add(1)
|
||||
for k := range memoryObject.Headers {
|
||||
v := memoryObject.Headers.Get(k)
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil {
|
||||
logrus.WithError(err).Warn("failed to stream response")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logrus.WithField("mtime", mtime).Trace("checking fastest upstream")
|
||||
selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime)
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"upstreamIdx": selectedIdx,
|
||||
}).Trace("fastest upstream")
|
||||
if chunks == nil && mtime != zeroTime {
|
||||
logrus.WithFields(logrus.Fields{"upstreamIdx": selectedIdx, "key": key}).Trace("not modified. using local version")
|
||||
if w != nil {
|
||||
http.ServeFile(w, r, key)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warn("failed to select fastest upstream")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if selectedIdx == -1 || response == nil || chunks == nil {
|
||||
logrus.Trace("no upstream is selected")
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if response.StatusCode == http.StatusNotModified {
|
||||
logrus.WithField("upstreamIdx", selectedIdx).Trace("not modified. using local version")
|
||||
http.ServeFile(w, r, key)
|
||||
return
|
||||
}
|
||||
|
||||
buffer := &bytes.Buffer{}
|
||||
ctx, cancel := context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
memoryObject = &StreamObject{
|
||||
Headers: response.Header,
|
||||
Buffer: buffer,
|
||||
|
||||
ctx: ctx,
|
||||
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
|
||||
server.o[r.URL.Path] = memoryObject
|
||||
server.lu.Unlock()
|
||||
locked = false
|
||||
|
||||
if w != nil {
|
||||
memoryObject.wg.Add(1)
|
||||
|
||||
for k := range memoryObject.Headers {
|
||||
v := memoryObject.Headers.Get(k)
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
go memoryObject.StreamTo(w, memoryObject.wg)
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
if chunk.error != nil {
|
||||
logrus.WithError(err).Warn("failed to read from upstream")
|
||||
}
|
||||
if chunk.buffer == nil {
|
||||
break
|
||||
}
|
||||
n, _ := buffer.Write(chunk.buffer)
|
||||
memoryObject.Offset += n
|
||||
|
||||
}
|
||||
cancel()
|
||||
|
||||
memoryObject.wg.Wait()
|
||||
|
||||
go func() {
|
||||
logrus.Trace("preparing to release memory object")
|
||||
mtime := zeroTime
|
||||
lastModifiedHeader := response.Header.Get("Last-Modified")
|
||||
if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"value": lastModifiedHeader,
|
||||
"url": response.Request.URL,
|
||||
}).Trace("failed to parse last modified header value")
|
||||
} else {
|
||||
mtime = lastModified
|
||||
}
|
||||
if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil {
|
||||
logrus.Warn(err)
|
||||
}
|
||||
fp, err := os.CreateTemp(server.Storage.Local.Path, "temp.*")
|
||||
name := fp.Name()
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"key": key,
|
||||
"path": server.Storage.Local.Path,
|
||||
"pattern": "temp.*",
|
||||
}).WithError(err).Warn("ftime.Time{}ailed to create template file")
|
||||
} else if _, err := fp.Write(buffer.Bytes()); err != nil {
|
||||
fp.Close()
|
||||
os.Remove(name)
|
||||
|
||||
logrus.WithError(err).Warn("failed to write into template file")
|
||||
} else if err := fp.Close(); err != nil {
|
||||
os.Remove(name)
|
||||
|
||||
logrus.WithError(err).Warn("failed to close template file")
|
||||
} else {
|
||||
os.Chtimes(name, zeroTime, mtime)
|
||||
dirname := filepath.Dir(key)
|
||||
os.MkdirAll(dirname, 0755)
|
||||
os.Remove(key)
|
||||
os.Rename(name, key)
|
||||
}
|
||||
|
||||
server.lu.Lock()
|
||||
defer server.lu.Unlock()
|
||||
|
||||
delete(server.o, r.URL.Path)
|
||||
logrus.Trace("memory object released")
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) {
|
||||
returnLock := &sync.Mutex{}
|
||||
upstreams := len(server.Upstreams)
|
||||
cancelFuncs := make([]func(), upstreams)
|
||||
selectedCh := make(chan int, 1)
|
||||
selectedOnce := &sync.Once{}
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(len(server.Upstreams))
|
||||
logrus.WithField("size", len(server.Upstreams)).Trace("wg")
|
||||
|
||||
defer close(selectedCh)
|
||||
for idx := range server.Upstreams {
|
||||
idx := idx
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancelFuncs[idx] = cancel
|
||||
|
||||
logger := logrus.WithField("upstreamIdx", idx)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
response, ch, err := server.tryUpstream(ctx, idx, r, lastModified)
|
||||
if err == context.Canceled { // others returned
|
||||
logger.Trace("context canceled")
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("upstream has error")
|
||||
return
|
||||
}
|
||||
if response == nil {
|
||||
return
|
||||
}
|
||||
locked := returnLock.TryLock()
|
||||
if !locked {
|
||||
return
|
||||
}
|
||||
defer returnLock.Unlock()
|
||||
|
||||
selectedOnce.Do(func() {
|
||||
resultResponse, resultCh, resultErr = response, ch, err
|
||||
selectedCh <- idx
|
||||
|
||||
for idx, cancel := range cancelFuncs {
|
||||
if resultIdx == idx {
|
||||
logrus.WithField("upstreamIdx", idx).Trace("selected thus not canceled")
|
||||
continue
|
||||
}
|
||||
logrus.WithField("upstreamIdx", idx).Trace("not selected and thus canceled")
|
||||
cancel()
|
||||
}
|
||||
|
||||
logger.Trace("upstream is selected")
|
||||
})
|
||||
|
||||
logger.Trace("voted")
|
||||
|
||||
return
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
logrus.Trace("all upstream tried")
|
||||
|
||||
resultIdx = -1
|
||||
select {
|
||||
case idx := <-selectedCh:
|
||||
resultIdx = idx
|
||||
default:
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) {
|
||||
upstream := server.Upstreams[upstreamIdx]
|
||||
|
||||
logger := logrus.WithField("upstreamIdx", upstreamIdx)
|
||||
|
||||
newpath, matched, err := upstream.GetPath(r.URL.Path)
|
||||
logger.WithFields(logrus.Fields{
|
||||
"path": newpath,
|
||||
"matched": matched,
|
||||
}).Trace("trying upstream")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if !matched {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
newurl := upstream.Server + newpath
|
||||
method := r.Method
|
||||
if lastModified != zeroTime {
|
||||
method = http.MethodGet
|
||||
}
|
||||
request, err := http.NewRequestWithContext(ctx, method, newurl, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if lastModified != zeroTime {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"mtime": lastModified.Format(time.RFC1123),
|
||||
}).Trace("check modified since")
|
||||
request.Header.Set("If-Modified-Since", lastModified.Format(time.RFC1123))
|
||||
}
|
||||
|
||||
for _, k := range []string{"User-Agent"} {
|
||||
if _, exists := request.Header[k]; exists {
|
||||
request.Header.Set(k, r.Header.Get(k))
|
||||
}
|
||||
}
|
||||
response, err = http.DefaultClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
logrus.WithField("status", response.StatusCode).Trace("responded")
|
||||
if response.StatusCode == http.StatusNotModified {
|
||||
return response, nil, nil
|
||||
}
|
||||
if response.StatusCode >= 400 && response.StatusCode < 500 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if response.StatusCode < 200 || response.StatusCode >= 500 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"url": newurl,
|
||||
"status": response.StatusCode,
|
||||
}).Warn("unexpected status")
|
||||
return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response)
|
||||
}
|
||||
|
||||
ch := make(chan Chunk, 1024)
|
||||
|
||||
buffer := make([]byte, server.Misc.FirstChunkBytes)
|
||||
start := time.Now()
|
||||
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
|
||||
|
||||
if err != nil {
|
||||
if n == 0 {
|
||||
return response, nil, err
|
||||
}
|
||||
}
|
||||
logger.WithField("duration", time.Now().Sub(start)).Tracef("first %v bytes", n)
|
||||
ch <- Chunk{buffer: buffer[:n]}
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
for {
|
||||
buffer := make([]byte, server.Misc.ChunkBytes)
|
||||
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
|
||||
if n > 0 {
|
||||
ch <- Chunk{buffer: buffer[:n]}
|
||||
}
|
||||
if err == io.EOF {
|
||||
logger.Trace("done")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
ch <- Chunk{error: err}
|
||||
logger.WithError(err).Trace("failed")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return response, ch, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
logrus.SetLevel(logrus.TraceLevel)
|
||||
logrus.SetFormatter(&logrus.TextFormatter{
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: "2006-01-02T15:04:05.000",
|
||||
})
|
||||
config, err := configFromFile("config.yaml")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ch := make(chan any, 10)
|
||||
for idx := 0; idx < 10; idx += 1 {
|
||||
ch <- idx
|
||||
}
|
||||
|
||||
server := Server{
|
||||
Config: *config,
|
||||
|
||||
lu: &sync.Mutex{},
|
||||
o: make(map[string]*StreamObject),
|
||||
}
|
||||
|
||||
http.HandleFunc("/{path...}", server.handleRequest)
|
||||
logrus.WithFields(logrus.Fields{"addr": ":8881"}).Info("serving app")
|
||||
http.ListenAndServe(":8881", nil)
|
||||
}
|
Reference in New Issue
Block a user