cache-proxy/server.go
guochao 83dfcba4ae
All checks were successful
build container / build-container (push) Successful in 5m45s
run go test / test (push) Successful in 3m15s
move x-accel into local storage configuration
2025-03-03 09:15:37 +08:00

671 lines
16 KiB
Go

package cacheproxy
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
"sync"
"time"
)
// reqCtxKey is the private key type for values this package stores in a
// request context, so the keys cannot collide with keys of other packages.
type reqCtxKey int

const (
	// reqCtxAllowedRedirect is the context key under which the
	// allowed-redirect regular expression (a string) of a request is stored.
	reqCtxAllowedRedirect reqCtxKey = iota
)

// zeroTime is the zero time.Time, used as the "no timestamp" marker.
var zeroTime = time.Time{}

// httpClient is the client used for upstream requests.
//
// Redirects are only followed when the most recent request in the chain
// carries a reqCtxAllowedRedirect pattern in its context and the redirect
// target matches that pattern. In every other case the redirect response
// itself is handed back to the caller (http.ErrUseLastResponse). An invalid
// pattern is reported as an error.
var httpClient = http.Client{
	CheckRedirect: func(req *http.Request, via []*http.Request) error {
		previous := via[len(via)-1]
		pattern, ok := previous.Context().Value(reqCtxAllowedRedirect).(string)
		if !ok {
			return http.ErrUseLastResponse
		}

		matched, err := regexp.MatchString(pattern, req.URL.String())
		switch {
		case err != nil:
			return err
		case !matched:
			return http.ErrUseLastResponse
		default:
			return nil
		}
	},
}
// StreamObject is a download that is still held in memory. Data is appended
// to Buffer by the downloader while StreamTo copies it out to a reader.
type StreamObject struct {
	// Headers are the response headers copied to every client.
	Headers http.Header
	// Buffer holds the bytes received so far.
	Buffer *bytes.Buffer
	// Offset is the number of bytes of Buffer that StreamTo may forward.
	Offset int

	// ctx being done tells StreamTo to flush the remainder and stop.
	ctx context.Context
	// wg is the wait group handed to StreamTo by the callers in this file.
	wg *sync.WaitGroup
}

// StreamTo copies the object's data to w as it becomes available and calls
// wg.Done when it returns. A nil w discards the data.
//
// It polls Offset once per millisecond. When the object's context is done it
// waits one more millisecond, writes everything in Buffer that has not been
// written yet and returns the result of that write. A failed write ends the
// copy immediately with that error.
//
// NOTE(review): Offset and Buffer are read here without synchronisation while
// another goroutine may be appending; confirm this is acceptable.
func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error {
	defer wg.Done()

	dst := w
	if dst == nil {
		dst = io.Discard
	}

	sent := 0
	for {
		select {
		case <-memoryObject.ctx.Done():
			// Give the producer a moment to finish its last append, then
			// flush whatever is left.
			time.Sleep(time.Millisecond)
			_, err := dst.Write(memoryObject.Buffer.Bytes()[sent:])
			return err
		default:
		}

		available := memoryObject.Offset
		if available == sent {
			time.Sleep(time.Millisecond)
			continue
		}

		n, err := dst.Write(memoryObject.Buffer.Bytes()[sent:available])
		if err != nil {
			return err
		}
		sent += n
	}
}
// Server is the caching proxy. It embeds its Config and keeps the downloads
// that are currently held in memory.
type Server struct {
	Config

	// lu is taken around updates of o.
	lu *sync.Mutex
	// o maps a request URL path to the in-memory object being downloaded
	// for it.
	o map[string]*StreamObject
}

// NewServer returns a Server for config with an empty in-memory object table.
func NewServer(config Config) *Server {
	server := new(Server)
	server.Config = config
	server.lu = new(sync.Mutex)
	server.o = map[string]*StreamObject{}
	return server
}
// Chunk is one message on the channel between an upstream reader and its
// consumer: either a piece of the response body or a read error.
type Chunk struct {
	// buffer is the body data carried by this chunk; nil for an error chunk.
	buffer []byte
	// error is the read error carried by this chunk, if any.
	error error
}
// serveFile answers the request with the local file at path.
//
// Normally the file is sent with http.ServeFile. When an accel header name is
// configured (Storage.Local.Accel.EnableByHeader) and the request carries a
// non-empty value for it, no body is written: the path of the file relative
// to Storage.Local.Path is joined onto the header value and the result is set
// on every header listed in Accel.RespondWithHeaders. A path that cannot be
// made relative to the storage root is answered with 400.
func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) {
	enableHeader := server.Storage.Local.Accel.EnableByHeader
	location := r.Header.Get(enableHeader)
	if enableHeader == "" || location == "" {
		http.ServeFile(w, r, path)
		return
	}

	relPath, err := filepath.Rel(server.Storage.Local.Path, path)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	accelPath := filepath.Join(location, relPath)
	for _, headerKey := range server.Storage.Local.Accel.RespondWithHeaders {
		w.Header().Set(headerKey, accelPath)
	}
}
// HandleRequestWithCache serves r.URL.Path from local storage when a fresh
// copy exists there and otherwise goes through streamOnline to revalidate or
// download it from the upstreams.
//
// The request path is resolved below Storage.Local.Path. A path that resolves
// outside that directory is rejected with 400. The comparison is made on
// whole path elements, so a sibling such as "<root>-other" is not mistaken
// for a child of "<root>", and a relative or unclean root is normalised
// first.
//
// Requests with a Range header are never streamed directly: the object is
// first fetched without a client writer and then served from disk, so that
// range handling is left to serveFile.
// NOTE(review): streamOnline appears to persist the file asynchronously, so
// the file may not be on disk yet when serveFile runs — confirm.
func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Request) {
	root, err := filepath.Abs(server.Storage.Local.Path)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Join cleans the result, so any ".." in the URL path is resolved before
	// the boundary check below.
	fullpath := filepath.Join(root, r.URL.Path)
	rel, err := filepath.Rel(root, fullpath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		http.Error(w, "crossing local directory boundary", http.StatusBadRequest)
		return
	}

	ranged := r.Header.Get("Range") != ""

	status, mtime, err := server.checkLocal(w, r, fullpath)
	slog.With("status", status, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked")

	switch {
	case os.IsPermission(err):
		http.Error(w, err.Error(), http.StatusForbidden)
	case err != nil:
		http.Error(w, err.Error(), http.StatusInternalServerError)
	case status == localExists:
		server.serveFile(w, r, fullpath)
	case ranged:
		// Missing or stale copy: fetch without a client writer, then serve
		// the requested range from disk.
		server.streamOnline(nil, r, mtime, fullpath)
		server.serveFile(w, r, fullpath)
	default:
		// Missing or stale copy: stream the upstream response to the client.
		server.streamOnline(w, r, mtime, fullpath)
	}
}
// localStatus is the result of looking for an object in local storage.
type localStatus int

const (
	// localNotExists: there is no local copy.
	localNotExists localStatus = iota
	// localExists: there is a local copy that can be served as is.
	localExists
	// localExistsButNeedHead: there is a local copy, but the refresh policy
	// asks for it to be revalidated against the upstreams first.
	localExistsButNeedHead
)
// checkLocal reports whether key exists in local storage and whether the
// copy must be revalidated.
//
// Returns:
//   - localNotExists with a zero time when the file does not exist;
//   - localExists or localExistsButNeedHead with the file's modification
//     time (in UTC) when it does;
//   - a non-nil error for any other stat failure (including permission
//     errors), for an invalid policy pattern, or for a policy RefreshAfter
//     that is neither "always", "never" nor a valid duration.
//
// The refresh interval defaults to Cache.RefreshAfter and is overridden by
// the first entry of Cache.Policies whose Match pattern matches key. A copy
// needs revalidation when its modification time plus the interval lies in
// the past or the policy is "always"; the policy "never" suppresses it.
//
// Nothing is written to w. Errors are only returned, so that the caller
// produces exactly one response.
func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
	stat, err := os.Stat(key)
	if err != nil {
		if os.IsNotExist(err) {
			return localNotExists, zeroTime, nil
		}
		return localNotExists, zeroTime, err
	}

	refreshAfter := server.Cache.RefreshAfter
	refresh := ""
	for _, policy := range server.Cache.Policies {
		match, err := regexp.MatchString(policy.Match, key)
		if err != nil {
			return localNotExists, zeroTime, err
		}
		if !match {
			continue
		}

		switch policy.RefreshAfter {
		case "always", "never":
			refresh = policy.RefreshAfter
		default:
			dur, err := time.ParseDuration(policy.RefreshAfter)
			if err != nil {
				return localNotExists, zeroTime, err
			}
			refreshAfter = dur
		}
		// Only the first matching policy applies.
		break
	}

	modified := stat.ModTime()
	slog.With("policy", refresh, "after", refreshAfter, "mtime", modified, "key", key).Debug("refresh policy checked")

	expired := modified.Add(refreshAfter).Before(time.Now())
	if (expired || refresh == "always") && refresh != "never" {
		return localExistsButNeedHead, modified.In(time.UTC), nil
	}
	return localExists, modified.In(time.UTC), nil
}
// streamOnline obtains the object for r.URL.Path from the upstreams, streams
// it to w while it is being downloaded and stores it at key afterwards.
//
// Parameters:
//   - w: the client to stream to; nil means "fetch only", in which case
//     nothing is written to a client.
//   - mtime: modification time of the local copy, or the zero time when
//     there is none.
//   - key: absolute path of the object in local storage.
//
// Behaviour:
//   - If a download for the same URL path is already held in memory, the
//     client is attached to it and the call returns when that stream ends.
//   - If the selected upstream answers 304, the local copy's modification
//     time is set to now and the local copy is served.
//   - If no upstream delivers a body but a local copy exists, the local copy
//     is served. Otherwise an upstream error gives 500 and "nothing
//     selected" gives 404.
//   - Otherwise the body is buffered in memory and streamed to the client.
//     A complete download is written to local storage in the background. An
//     incomplete one is not cached and the client connection is reset.
//
// server.lu is held from the lookup until the new object is registered
// (including upstream selection), so that only one download is started per
// path. It is released before anything is written to the client.
func (server *Server) streamOnline(w http.ResponseWriter, r *http.Request, mtime time.Time, key string) {
	objectKey := r.URL.Path

	// Other handlers insert into and delete from the map concurrently, so
	// even the first lookup has to happen under the lock.
	server.lu.Lock()
	locked := true
	unlock := func() {
		if locked {
			server.lu.Unlock()
			locked = false
		}
	}
	defer unlock()

	if memoryObject, exists := server.o[objectKey]; exists {
		unlock()
		if w == nil {
			return
		}
		memoryObject.wg.Add(1)
		for k := range memoryObject.Headers {
			w.Header().Set(k, memoryObject.Headers.Get(k))
		}
		if err := memoryObject.StreamTo(w, memoryObject.wg); err != nil {
			slog.With("error", err).Warn("failed to stream response with existing memory object")
		}
		return
	}

	slog.With("mtime", mtime).Debug("checking fastest upstream")
	selectedIdx, response, chunks, err := server.fastesUpstream(r, mtime)

	// A 304 carries no chunk channel, so it has to be recognised before the
	// generic "no body, but a local copy exists" fallback below; otherwise
	// the modification time would never be refreshed.
	if response != nil && response.StatusCode == http.StatusNotModified {
		unlock()
		response.Body.Close()
		slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version")
		if err := os.Chtimes(key, zeroTime, time.Now()); err != nil {
			slog.With("error", err, "key", key).Warn("failed to refresh modified time")
		}
		if w != nil {
			server.serveFile(w, r, key)
		}
		return
	}
	if chunks == nil && !mtime.IsZero() {
		unlock()
		slog.With("upstreamIdx", selectedIdx, "key", key).Debug("not modified. using local version")
		if w != nil {
			server.serveFile(w, r, key)
		}
		return
	}
	if err != nil {
		unlock()
		slog.With("error", err).Warn("failed to select fastest upstream")
		if w != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
		return
	}
	if selectedIdx == -1 || response == nil || chunks == nil {
		unlock()
		slog.Debug("no upstream is selected")
		if w != nil {
			http.NotFound(w, r)
		}
		return
	}

	slog.With(
		"upstreamIdx", selectedIdx,
	).Debug("found fastest upstream")

	buffer := &bytes.Buffer{}
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	memoryObject := &StreamObject{
		Headers: response.Header,
		Buffer:  buffer,
		ctx:     ctx,
		wg:      &sync.WaitGroup{},
	}
	server.o[objectKey] = memoryObject
	unlock()

	if w != nil {
		memoryObject.wg.Add(1)
		for k := range memoryObject.Headers {
			w.Header().Set(k, memoryObject.Headers.Get(k))
		}
		go memoryObject.StreamTo(w, memoryObject.wg)
	}

	var downloadErr error
	for chunk := range chunks {
		if chunk.error != nil {
			downloadErr = chunk.error
			if !errors.Is(downloadErr, io.EOF) && !errors.Is(downloadErr, io.ErrUnexpectedEOF) {
				slog.With("error", downloadErr).Warn("failed to read from upstream")
			}
		}
		if chunk.buffer == nil {
			break
		}
		n, _ := buffer.Write(chunk.buffer) // bytes.Buffer.Write never returns an error
		memoryObject.Offset += n
	}

	// Tell the streamers to flush what is left, then wait for them.
	cancel()
	memoryObject.wg.Wait()

	if response.ContentLength > 0 {
		switch {
		case int64(memoryObject.Offset) == response.ContentLength:
			if downloadErr != nil && !(errors.Is(downloadErr, io.EOF) || errors.Is(downloadErr, io.ErrUnexpectedEOF)) {
				slog.With("read-length", memoryObject.Offset, "content-length", response.ContentLength, "error", downloadErr, "upstreamIdx", selectedIdx).Debug("something happened during download. but response body is read as whole. so error is reset to nil")
			}
			downloadErr = nil
		case downloadErr == nil:
			// The stream ended without an error but short of the announced
			// length: treat it as a failed download instead of caching a
			// truncated file.
			slog.With("read-length", memoryObject.Offset, "content-length", response.ContentLength, "upstreamIdx", selectedIdx).Warn("response body is shorter than content length")
			downloadErr = io.ErrUnexpectedEOF
		}
	} else if downloadErr == io.EOF {
		downloadErr = nil
	}

	if downloadErr != nil {
		logger := slog.With("upstreamIdx", selectedIdx)
		logger.Error("something happened during download. will not cache this response. setting lingering to reset the connection.")
		if w != nil {
			resetClientConnection(w, logger)
		}
	}

	// The object stays registered until it has been written to disk (or
	// dropped after a failure); it must be removed on every path.
	go func() {
		defer func() {
			server.lu.Lock()
			defer server.lu.Unlock()
			delete(server.o, objectKey)
			slog.Debug("memory object released")
		}()
		if downloadErr != nil {
			return
		}
		server.persistDownload(key, buffer.Bytes(), response)
	}()
}

// resetClientConnection hijacks the client connection behind w and closes it
// with SO_LINGER set to 0. Failures are logged only.
func resetClientConnection(w http.ResponseWriter, logger *slog.Logger) {
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		logger.Warn("response writer is not a hijacker. failed to set lingering")
		return
	}
	conn, _, err := hijacker.Hijack()
	if err != nil {
		logger.With("error", err).Warn("hijack failed. failed to set lingering")
		return
	}
	defer conn.Close()

	tcpConn, ok := conn.(*net.TCPConn)
	if !ok {
		logger.Warn("connection is not a *net.TCPConn. failed to set lingering")
		return
	}
	if err := tcpConn.SetLinger(0); err != nil {
		logger.With("error", err).Warn("failed to set lingering")
		return
	}
	logger.Debug("connection set to linger. it will be reset once the conn.Close is called")
}

// persistDownload writes data to key in local storage. Without a configured
// temporary file pattern the file is written directly; otherwise it is
// written to a temporary file in the storage root, stamped with the
// Last-Modified time of response (when that header parses) and renamed to
// key. Failures are logged only.
func (server *Server) persistDownload(key string, data []byte, response *http.Response) {
	slog.Debug("preparing to release memory object")

	modified := zeroTime
	lastModifiedHeader := response.Header.Get("Last-Modified")
	if lastModified, err := time.Parse(time.RFC1123, lastModifiedHeader); err != nil {
		slog.With(
			"error", err,
			"value", lastModifiedHeader,
			"url", response.Request.URL,
		).Debug("failed to parse last modified header value. set modified time to now")
	} else {
		slog.With(
			"header", lastModifiedHeader,
			"value", lastModified,
			"url", response.Request.URL,
		).Debug("found modified time")
		modified = lastModified
	}

	if err := os.MkdirAll(server.Storage.Local.Path, 0755); err != nil {
		slog.With("error", err).Warn("failed to create local storage path")
	}
	dirname := filepath.Dir(key)

	if server.Config.Storage.Local.TemporaryFilePattern == "" {
		// key may live in a sub-directory that does not exist yet.
		if err := os.MkdirAll(dirname, 0755); err != nil {
			slog.With("error", err, "path", dirname).Warn("failed to create directory for file")
		}
		if err := os.WriteFile(key, data, 0644); err != nil {
			slog.With("error", err).Warn("failed to write file")
			os.Remove(key)
		}
		return
	}

	fp, err := os.CreateTemp(server.Storage.Local.Path, server.Storage.Local.TemporaryFilePattern)
	if err != nil {
		slog.With(
			"key", key,
			"path", server.Storage.Local.Path,
			"pattern", server.Storage.Local.TemporaryFilePattern,
			"error", err,
		).Warn("failed to create template file")
		return
	}
	name := fp.Name()

	if _, err := fp.Write(data); err != nil {
		fp.Close()
		os.Remove(name)
		slog.With("error", err).Warn("failed to write into template file")
		return
	}
	if err := fp.Close(); err != nil {
		os.Remove(name)
		slog.With("error", err).Warn("failed to close template file")
		return
	}

	// A zero time leaves the corresponding timestamp unchanged.
	if err := os.Chtimes(name, zeroTime, modified); err != nil {
		slog.With("error", err, "path", name).Warn("failed to set modified time")
	}
	if err := os.MkdirAll(dirname, 0755); err != nil {
		slog.With("error", err, "path", dirname).Warn("failed to create directory for file")
	}
	// Best effort: the old copy usually exists only when refreshing.
	os.Remove(key)
	if err := os.Rename(name, key); err != nil {
		slog.With("error", err, "key", key).Warn("failed to move template file into place")
		os.Remove(name)
	}
}
// fastesUpstream asks the upstreams that match r.URL.Path for the object and
// returns the first usable answer.
//
// Upstreams are grouped by the priority of their first matching priority
// group (0 when none matches) and the groups are tried from the highest
// priority down. Within a group all upstreams are queried concurrently and
// the group is awaited as a whole:
//   - the first upstream answering 200 wins, and the requests of all other
//     upstreams are cancelled;
//   - if none answers 200 but at least one answers 304, the first 304 is
//     returned with a nil channel;
//   - otherwise the next group is tried.
//
// Returns the index of the selected upstream, its response and, for a 200,
// the channel delivering the body. When nothing is selected the result is
// (-1, nil, nil, nil). An error is only returned for an invalid path or
// priority pattern. The request context of the selected upstream is left
// uncancelled; all others are cancelled on return.
func (server *Server) fastesUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, resultCh chan Chunk, resultErr error) {
	cancelFuncs := make([]context.CancelFunc, len(server.Upstreams))
	resultIdx = -1
	// resultIdx is the named result, so this sees the index actually
	// returned and spares only the selected upstream.
	defer func() {
		for cancelIdx, cancel := range cancelFuncs {
			if cancelIdx == resultIdx || cancel == nil {
				continue
			}
			cancel()
		}
	}()

	groups := make(map[int][]int)
	for upstreamIdx, upstream := range server.Upstreams {
		if _, matched, err := upstream.GetPath(r.URL.Path); err != nil {
			return -1, nil, nil, err
		} else if !matched {
			continue
		}

		priority := 0
		for _, priorityGroup := range upstream.PriorityGroups {
			if matched, err := regexp.MatchString(priorityGroup.Match, r.URL.Path); err != nil {
				return -1, nil, nil, err
			} else if matched {
				priority = priorityGroup.Priority
				break
			}
		}
		groups[priority] = append(groups[priority], upstreamIdx)
	}

	priorities := make([]int, 0, len(groups))
	for priority := range groups {
		priorities = append(priorities, priority)
	}
	slices.Sort(priorities)
	slices.Reverse(priorities)

	for _, priority := range priorities {
		candidates := groups[priority]

		logger := slog.With()
		if priority != 0 {
			logger = logger.With("priority", priority)
		}

		// Votes are kept apart per kind, so a late 304 can never replace the
		// response and channel of an upstream that already won with a 200.
		var (
			mu                  sync.Mutex
			wg                  sync.WaitGroup
			updateIdx           = -1
			updateResponse      *http.Response
			updateChunks        chan Chunk
			notModifiedIdx      = -1
			notModifiedResponse *http.Response
		)

		// Create every context before the first goroutine starts, so the
		// winner can cancel all competitors without racing on cancelFuncs.
		contexts := make([]context.Context, len(candidates))
		for i, idx := range candidates {
			contexts[i], cancelFuncs[idx] = context.WithCancel(context.Background())
		}

		wg.Add(len(candidates))
		for i, idx := range candidates {
			ctx, idx := contexts[i], idx
			logger := logger.With("upstreamIdx", idx)
			go func() {
				defer wg.Done()

				response, ch, err := server.tryUpstream(ctx, idx, priority, r, lastModified)
				if err != nil {
					switch {
					case errors.Is(err, context.Canceled): // others returned
						logger.Debug("context canceled")
					case errors.Is(err, context.DeadlineExceeded):
					default:
						logger.With("error", err).Warn("upstream has error")
					}
					return
				}
				if response == nil {
					return
				}
				if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusNotModified {
					return
				}

				mu.Lock()
				defer mu.Unlock()

				if response.StatusCode == http.StatusNotModified {
					if notModifiedIdx == -1 {
						notModifiedIdx, notModifiedResponse = idx, response
					}
					logger.Debug("voted not modified")
					return
				}

				if updateIdx == -1 {
					updateIdx, updateResponse, updateChunks = idx, response, ch
					for cancelIdx, cancel := range cancelFuncs {
						if cancelIdx == idx || cancel == nil {
							continue
						}
						cancel()
					}
				}
				logger.Debug("voted update")
			}()
		}
		wg.Wait()

		switch {
		case updateIdx != -1:
			logger.With("upstreamIdx", updateIdx).Debug("upstream selected")
			return updateIdx, updateResponse, updateChunks, nil
		case notModifiedIdx != -1:
			logger.With("upstreamIdx", notModifiedIdx).Debug("all upstream not modified")
			return notModifiedIdx, notModifiedResponse, nil, nil
		default:
			logger.Debug("no valid upstream found")
		}
	}

	return -1, nil, nil, nil
}
// tryUpstream requests r.URL.Path from the upstream with index upstreamIdx.
//
// When lastModified is non-zero the request is a GET carrying
// If-Modified-Since. Otherwise the client's method is used. Only the
// User-Agent and Accept headers of the client request are forwarded.
//
// Returns:
//   - (nil, nil, nil) when the upstream does not match the path or answers
//     with a 4xx status;
//   - (response, nil, nil) for a 304;
//   - (response, nil, err) for a status below 200 or from 500 up, or when
//     not a single byte of the body could be read;
//   - (response, ch, nil) otherwise. The first chunk (up to
//     Misc.FirstChunkBytes) is already queued on ch. The rest is read in the
//     background in pieces of Misc.ChunkBytes. ch is closed at the end of
//     the body. A read failure, or a body shorter than the announced
//     Content-Length, is delivered as a final Chunk carrying only an error.
//
// The response body is closed on every path that does not hand it to the
// background reader, and by the background reader when it finishes. The
// reader stops as soon as ctx is cancelled.
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, chunks chan Chunk, err error) {
	upstream := server.Upstreams[upstreamIdx]

	newpath, matched, err := upstream.GetPath(r.URL.Path)
	if err != nil {
		return nil, nil, err
	}
	if !matched {
		return nil, nil, nil
	}

	logger := slog.With("upstreamIdx", upstreamIdx, "server", upstream.Server, "path", newpath)
	if priority != 0 {
		logger = logger.With("priority", priority)
	}
	logger.With(
		"matched", matched,
	).Debug("trying upstream")

	newurl := upstream.Server + newpath
	conditional := !lastModified.IsZero()

	method := r.Method
	if conditional {
		method = http.MethodGet
	}
	if upstream.AllowedRedirect != nil {
		ctx = context.WithValue(ctx, reqCtxAllowedRedirect, *upstream.AllowedRedirect)
	}

	request, err := http.NewRequestWithContext(ctx, method, newurl, nil)
	if err != nil {
		return nil, nil, err
	}
	if conditional {
		// HTTP dates are always expressed in GMT. time.RFC1123 would print
		// the zone name of the value (e.g. "UTC") instead.
		request.Header.Set("If-Modified-Since", lastModified.UTC().Format(http.TimeFormat))
	}
	for _, k := range []string{"User-Agent", "Accept"} {
		if _, exists := r.Header[k]; exists {
			request.Header.Set(k, r.Header.Get(k))
		}
	}

	response, err = httpClient.Do(request)
	if err != nil {
		return nil, nil, err
	}

	switch {
	case response.StatusCode == http.StatusNotModified:
		response.Body.Close()
		return response, nil, nil
	case response.StatusCode >= 400 && response.StatusCode < 500:
		response.Body.Close()
		return nil, nil, nil
	case response.StatusCode < 200 || response.StatusCode >= 500:
		response.Body.Close()
		logger.With(
			"url", newurl,
			"status", response.StatusCode,
		).Warn("unexpected status")
		return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response)
	}

	ch := make(chan Chunk, 1024)

	buffer := make([]byte, server.Misc.FirstChunkBytes)
	n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
	if err != nil && n == 0 {
		response.Body.Close()
		return response, nil, err
	}
	// A short first read with data is not an error here: the body is simply
	// smaller than the first chunk, and the reader below will see its end.
	ch <- Chunk{buffer: buffer[:n]}
	received := int64(n)

	go func() {
		defer close(ch)
		defer response.Body.Close()

		// send gives up when the request is cancelled, so an abandoned
		// channel cannot block this goroutine forever.
		send := func(chunk Chunk) bool {
			select {
			case ch <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		for {
			buffer := make([]byte, server.Misc.ChunkBytes)
			n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
			if n > 0 {
				if !send(Chunk{buffer: buffer[:n]}) {
					return
				}
				received += int64(n)
			}
			if err == nil {
				continue
			}

			// io.ReadAtLeast reports the end of the body as io.EOF (nothing
			// read) or io.ErrUnexpectedEOF (short last read). That is only
			// a clean end when it agrees with the announced length.
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
				if response.ContentLength > 0 && received != response.ContentLength {
					send(Chunk{error: io.ErrUnexpectedEOF})
				}
				return
			}
			send(Chunk{error: err})
			return
		}
	}()

	return response, ch, nil
}