788 lines
21 KiB
Go
788 lines
21 KiB
Go
package cacheproxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
// reqCtxKey is an unexported context-key type, so keys defined here cannot
// collide with context keys defined by other packages.
type reqCtxKey int

const (
	// reqCtxAllowedRedirect maps to a regular expression (string) that a
	// redirect target URL must match for httpClient to follow the redirect.
	reqCtxAllowedRedirect reqCtxKey = iota
)

// zeroTime is the zero time.Time, used in this package as "no timestamp".
var zeroTime time.Time

// preclosedChan is closed during package initialization; receiving from it
// never blocks. It is handed out when there is nothing to wait for.
var preclosedChan = make(chan struct{})

func init() {
	close(preclosedChan)
}

var (
	// httpClient is the client used for upstream requests.
	//
	// A redirect is followed only when the request context carries an
	// allowed-redirect pattern under reqCtxAllowedRedirect and the target URL
	// matches it, and at most 10 redirects are followed in a row. In every
	// other case the redirect response itself is returned to the caller.
	httpClient = http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			allowedRedirect, ok := req.Context().Value(reqCtxAllowedRedirect).(string)
			if !ok {
				return http.ErrUseLastResponse
			}
			// Setting CheckRedirect replaces net/http's default policy,
			// including its 10-redirect limit, so enforce the limit here to
			// avoid following a matching redirect loop forever.
			const maxRedirects = 10
			if len(via) >= maxRedirects {
				return errors.New("stopped after 10 redirects")
			}
			matched, err := regexp.MatchString(allowedRedirect, req.URL.String())
			if err != nil {
				return err
			}
			if !matched {
				return http.ErrUseLastResponse
			}
			return nil
		},
	}
)
|
|
|
|
// StreamObject tracks one in-flight upstream download that is written to a
// temporary file while consumers read from that file concurrently.
//
// The producer updates Headers, Offset, Done and Error while holding mu and
// broadcasts on cond (which is bound to mu) after each change.
type StreamObject struct {
	// Headers holds the upstream response headers; nil until they are known.
	Headers http.Header
	// TempFile receives the downloaded bytes.
	TempFile *os.File
	// Offset counts the bytes written to TempFile so far.
	Offset int64
	// Done becomes true once the producer has finished, whatever the outcome.
	Done bool
	// Error is the reason the download failed; nil if it did not fail.
	Error error

	mu   *sync.Mutex
	cond *sync.Cond

	// fileWrittenCh is closed as the producer's very last step, after any
	// rename or cleanup, whether or not the download succeeded.
	fileWrittenCh chan struct{}
}
|
|
|
|
type Server struct {
|
|
Config
|
|
|
|
lu *sync.Mutex
|
|
o map[string]*StreamObject
|
|
}
|
|
|
|
func NewServer(config Config) *Server {
|
|
return &Server{
|
|
Config: config,
|
|
|
|
lu: &sync.Mutex{},
|
|
o: make(map[string]*StreamObject),
|
|
}
|
|
}
|
|
|
|
func (server *Server) serveFile(w http.ResponseWriter, r *http.Request, path string) {
|
|
if location := r.Header.Get(server.Storage.Local.Accel.EnableByHeader); server.Storage.Local.Accel.EnableByHeader != "" && location != "" {
|
|
relPath, err := filepath.Rel(server.Storage.Local.Path, path)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
accelPath := filepath.Join(location, relPath)
|
|
|
|
for _, headerKey := range server.Storage.Local.Accel.RespondWithHeaders {
|
|
w.Header().Set(headerKey, accelPath)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
http.ServeFile(w, r, path)
|
|
}
|
|
|
|
func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Request) {
|
|
fullpath := filepath.Join(server.Storage.Local.Path, r.URL.Path)
|
|
fullpath, err := filepath.Abs(fullpath)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
if !strings.HasPrefix(fullpath, server.Storage.Local.Path) {
|
|
http.Error(w, "crossing local directory boundary", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
ranged := r.Header.Get("Range") != ""
|
|
|
|
localStatus, mtime, err := server.checkLocal(w, r, fullpath)
|
|
slog.With("status", localStatus, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked")
|
|
if os.IsPermission(err) {
|
|
http.Error(w, err.Error(), http.StatusForbidden)
|
|
return
|
|
} else if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if localStatus == localExists {
|
|
server.serveFile(w, r, fullpath)
|
|
return
|
|
}
|
|
|
|
// localExistsButNeedHead or localNotExists
|
|
// Both need to go online.
|
|
if ranged {
|
|
server.serveRangedRequest(w, r, fullpath, mtime)
|
|
} else {
|
|
// For full requests, we wait for the download to complete and then serve the file.
|
|
// This maintains the original behavior for now.
|
|
ch := server.startOrJoinStream(r, mtime, fullpath)
|
|
if ch != nil {
|
|
<-ch
|
|
}
|
|
server.serveFile(w, r, fullpath)
|
|
}
|
|
}
|
|
|
|
// localStatus is checkLocal's verdict about the cached copy of a key.
type localStatus int

// The values are spelled out explicitly; they match the original iota order.
const (
	// localNotExists: there is no cached copy.
	localNotExists localStatus = 0
	// localExists: a cached copy exists and may be served as is.
	localExists localStatus = 1
	// localExistsButNeedHead: a cached copy exists but must be revalidated
	// against an upstream before it is served.
	localExistsButNeedHead localStatus = 2
)
|
|
|
|
func (server *Server) checkLocal(w http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
|
|
if stat, err := os.Stat(key); err == nil {
|
|
refreshAfter := server.Cache.RefreshAfter
|
|
refresh := ""
|
|
|
|
for _, policy := range server.Cache.Policies {
|
|
if match, err := regexp.MatchString(policy.Match, key); err != nil {
|
|
return 0, zeroTime, err
|
|
} else if match {
|
|
if dur, err := time.ParseDuration(policy.RefreshAfter); err != nil {
|
|
if slices.Contains([]string{"always", "never"}, policy.RefreshAfter) {
|
|
refresh = policy.RefreshAfter
|
|
} else {
|
|
return 0, zeroTime, err
|
|
}
|
|
} else {
|
|
refreshAfter = dur
|
|
}
|
|
break
|
|
}
|
|
}
|
|
mtime := stat.ModTime()
|
|
slog.With("policy", refresh, "after", refreshAfter, "mtime", mtime, "key", key).Debug("refresh policy checked")
|
|
if (mtime.Add(refreshAfter).Before(time.Now()) || refresh == "always") && refresh != "never" {
|
|
return localExistsButNeedHead, mtime.In(time.UTC), nil
|
|
}
|
|
return localExists, mtime.In(time.UTC), nil
|
|
} else if os.IsPermission(err) {
|
|
http.Error(w, err.Error(), http.StatusForbidden)
|
|
} else if !os.IsNotExist(err) {
|
|
return localNotExists, zeroTime, err
|
|
}
|
|
|
|
return localNotExists, zeroTime, nil
|
|
}
|
|
|
|
func (server *Server) serveRangedRequest(w http.ResponseWriter, r *http.Request, key string, mtime time.Time) {
|
|
server.lu.Lock()
|
|
memoryObject, exists := server.o[r.URL.Path]
|
|
if !exists {
|
|
server.lu.Unlock()
|
|
// This is the first request for this file, so we start the download.
|
|
server.startOrJoinStream(r, mtime, key)
|
|
|
|
// Re-acquire the lock to get the newly created stream object.
|
|
server.lu.Lock()
|
|
memoryObject = server.o[r.URL.Path]
|
|
if memoryObject == nil {
|
|
server.lu.Unlock()
|
|
// This can happen if the upstream fails very quickly.
|
|
http.Error(w, "Failed to start download stream", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
server.lu.Unlock()
|
|
|
|
// At this point, we have a memoryObject, and a producer goroutine is downloading the file.
|
|
// Now we implement the consumer logic.
|
|
|
|
// Duplicate the file descriptor from the producer's temp file. This creates a new,
|
|
// independent file descriptor that points to the same underlying file description.
|
|
// This is more efficient and robust than opening the file by path.
|
|
fd, err := syscall.Dup(int(memoryObject.TempFile.Fd()))
|
|
if err != nil {
|
|
http.Error(w, "Failed to duplicate file descriptor", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
// We create a new *os.File from the duplicated descriptor. The consumer is now
|
|
// responsible for closing this new file.
|
|
consumerFile := os.NewFile(uintptr(fd), memoryObject.TempFile.Name())
|
|
if consumerFile == nil {
|
|
syscall.Close(fd) // Clean up if NewFile fails
|
|
http.Error(w, "Failed to create file from duplicated descriptor", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer consumerFile.Close()
|
|
|
|
rangeHeader := r.Header.Get("Range")
|
|
if rangeHeader == "" {
|
|
// This should not happen if called from HandleRequestWithCache, but as a safeguard:
|
|
http.Error(w, "Range header is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Parse the Range header. We only support a single range like "bytes=start-end".
|
|
var start, end int64
|
|
parts := strings.Split(strings.TrimPrefix(rangeHeader, "bytes="), "-")
|
|
if len(parts) != 2 {
|
|
http.Error(w, "Invalid Range header", http.StatusBadRequest)
|
|
return
|
|
}
|
|
start, err = strconv.ParseInt(parts[0], 10, 64)
|
|
if err != nil {
|
|
http.Error(w, "Invalid Range header", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if parts[1] != "" {
|
|
end, err = strconv.ParseInt(parts[1], 10, 64)
|
|
if err != nil {
|
|
http.Error(w, "Invalid Range header", http.StatusBadRequest)
|
|
return
|
|
}
|
|
} else {
|
|
// An empty end means "to the end of the file". We don't know the full size yet.
|
|
// We'll have to handle this dynamically.
|
|
end = -1 // Sentinel value
|
|
}
|
|
|
|
memoryObject.mu.Lock()
|
|
defer memoryObject.mu.Unlock()
|
|
|
|
// Wait until we have the headers.
|
|
for memoryObject.Headers == nil && !memoryObject.Done {
|
|
memoryObject.cond.Wait()
|
|
}
|
|
|
|
if memoryObject.Error != nil {
|
|
http.Error(w, memoryObject.Error.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
contentLengthStr := memoryObject.Headers.Get("Content-Length")
|
|
totalSize, _ := strconv.ParseInt(contentLengthStr, 10, 64)
|
|
if end == -1 || end >= totalSize {
|
|
end = totalSize - 1
|
|
}
|
|
|
|
if start >= totalSize || start > end {
|
|
http.Error(w, "Range not satisfiable", http.StatusRequestedRangeNotSatisfiable)
|
|
return
|
|
}
|
|
|
|
var bytesSent int64
|
|
bytesToSend := end - start + 1
|
|
headersWritten := false
|
|
|
|
for bytesSent < bytesToSend && memoryObject.Error == nil {
|
|
// Calculate what we need.
|
|
neededStart := start + bytesSent
|
|
|
|
// Wait for the data to be available.
|
|
for memoryObject.Offset <= neededStart && !memoryObject.Done {
|
|
memoryObject.cond.Wait()
|
|
}
|
|
|
|
// Check for error AFTER waiting. This is the critical fix.
|
|
if memoryObject.Error != nil {
|
|
// If headers haven't been written, we can send a 500 error.
|
|
// If they have, it's too late, the connection will just be closed.
|
|
if !headersWritten {
|
|
http.Error(w, memoryObject.Error.Error(), http.StatusInternalServerError)
|
|
}
|
|
return // Use return instead of break to exit immediately.
|
|
}
|
|
|
|
// If we are here, we have some data to send. Write headers if we haven't already.
|
|
if !headersWritten {
|
|
w.Header().Set("Content-Range", "bytes "+strconv.FormatInt(start, 10)+"-"+strconv.FormatInt(end, 10)+"/"+contentLengthStr)
|
|
w.Header().Set("Content-Length", strconv.FormatInt(end-start+1, 10))
|
|
w.Header().Set("Accept-Ranges", "bytes")
|
|
w.WriteHeader(http.StatusPartialContent)
|
|
headersWritten = true
|
|
}
|
|
|
|
// Data is available, read from the temporary file.
|
|
// We calculate how much we can read in this iteration.
|
|
readNow := memoryObject.Offset - neededStart
|
|
if readNow <= 0 {
|
|
// This can happen if we woke up but the data isn't what we need.
|
|
// The loop will continue to wait.
|
|
continue
|
|
}
|
|
|
|
// Don't read more than the client requested in total.
|
|
remainingToSend := bytesToSend - bytesSent
|
|
if readNow > remainingToSend {
|
|
readNow = remainingToSend
|
|
}
|
|
|
|
// Read the chunk from the file at the correct offset.
|
|
buffer := make([]byte, readNow)
|
|
bytesRead, err := consumerFile.ReadAt(buffer, neededStart)
|
|
if err != nil && err != io.EOF {
|
|
if !headersWritten {
|
|
http.Error(w, "Error reading from cache stream", http.StatusInternalServerError)
|
|
}
|
|
return
|
|
}
|
|
|
|
if bytesRead > 0 {
|
|
n, err := w.Write(buffer[:bytesRead])
|
|
if err != nil {
|
|
// Client closed connection, just return.
|
|
return
|
|
}
|
|
bytesSent += int64(n)
|
|
}
|
|
|
|
if memoryObject.Done && bytesSent >= bytesToSend {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// startOrJoinStream ensures a download stream is active for the given key.
|
|
// If a stream already exists, it returns the channel that signals completion.
|
|
// If not, it starts a new download producer and returns its completion channel.
|
|
func (server *Server) startOrJoinStream(r *http.Request, mtime time.Time, key string) <-chan struct{} {
|
|
server.lu.Lock()
|
|
memoryObject, exists := server.o[r.URL.Path]
|
|
if exists {
|
|
// A download is already in progress. Return its completion channel.
|
|
server.lu.Unlock()
|
|
return memoryObject.fileWrittenCh
|
|
}
|
|
|
|
// No active stream, create a new one.
|
|
if err := os.MkdirAll(filepath.Dir(key), 0755); err != nil {
|
|
slog.With("error", err).Warn("failed to create local storage path for temp file")
|
|
server.lu.Unlock()
|
|
return preclosedChan // Return a closed channel to prevent blocking
|
|
}
|
|
|
|
tempFilePattern := server.Storage.Local.TemporaryFilePattern
|
|
if tempFilePattern == "" {
|
|
tempFilePattern = "cache-proxy-*"
|
|
}
|
|
tempFile, err := os.CreateTemp(filepath.Dir(key), tempFilePattern)
|
|
if err != nil {
|
|
slog.With("error", err).Warn("failed to create temporary file")
|
|
server.lu.Unlock()
|
|
return preclosedChan
|
|
}
|
|
|
|
mu := &sync.Mutex{}
|
|
fileWrittenCh := make(chan struct{})
|
|
memoryObject = &StreamObject{
|
|
TempFile: tempFile,
|
|
mu: mu,
|
|
cond: sync.NewCond(mu),
|
|
fileWrittenCh: fileWrittenCh,
|
|
}
|
|
server.o[r.URL.Path] = memoryObject
|
|
server.lu.Unlock()
|
|
|
|
// This is the producer goroutine
|
|
go func(mo *StreamObject) {
|
|
var err error
|
|
downloadSucceeded := false
|
|
tempFileName := mo.TempFile.Name()
|
|
|
|
defer func() {
|
|
// On completion (or error), update the stream object,
|
|
// wake up all consumers, and then clean up.
|
|
mo.mu.Lock()
|
|
mo.Done = true
|
|
mo.Error = err
|
|
mo.cond.Broadcast()
|
|
mo.mu.Unlock()
|
|
|
|
// Close the temp file handle.
|
|
mo.TempFile.Close()
|
|
|
|
// If download failed, remove the temp file.
|
|
if !downloadSucceeded {
|
|
os.Remove(tempFileName)
|
|
}
|
|
|
|
// The producer's job is done. Remove the object from the central map.
|
|
// Any existing consumers still hold a reference to the object,
|
|
// so it won't be garbage collected until they are done.
|
|
server.lu.Lock()
|
|
delete(server.o, r.URL.Path)
|
|
server.lu.Unlock()
|
|
slog.Debug("memory object released by producer")
|
|
close(mo.fileWrittenCh)
|
|
}()
|
|
|
|
slog.With("mtime", mtime).Debug("checking fastest upstream")
|
|
selectedIdx, response, firstChunk, upstreamErr := server.fastestUpstream(r, mtime)
|
|
if upstreamErr != nil {
|
|
if !errors.Is(upstreamErr, io.EOF) && !errors.Is(upstreamErr, io.ErrUnexpectedEOF) {
|
|
slog.With("error", upstreamErr).Warn("failed to select fastest upstream")
|
|
err = upstreamErr
|
|
return
|
|
}
|
|
}
|
|
|
|
if selectedIdx == -1 || response == nil {
|
|
slog.Debug("no upstream is selected")
|
|
err = errors.New("no suitable upstream found")
|
|
return
|
|
}
|
|
|
|
if response.StatusCode == http.StatusNotModified {
|
|
slog.With("upstreamIdx", selectedIdx).Debug("not modified. using local version")
|
|
os.Chtimes(key, zeroTime, time.Now())
|
|
// In this case, we don't have a new file, so we just exit.
|
|
// The temp file will be cleaned up by the defer.
|
|
return
|
|
}
|
|
|
|
defer response.Body.Close()
|
|
|
|
slog.With("upstreamIdx", selectedIdx).Debug("found fastest upstream")
|
|
|
|
mo.mu.Lock()
|
|
mo.Headers = response.Header
|
|
mo.cond.Broadcast() // Broadcast headers availability
|
|
mo.mu.Unlock()
|
|
|
|
// Write the first chunk that we already downloaded.
|
|
if len(firstChunk) > 0 {
|
|
written, writeErr := mo.TempFile.Write(firstChunk)
|
|
if writeErr != nil {
|
|
err = writeErr
|
|
return
|
|
}
|
|
mo.mu.Lock()
|
|
mo.Offset += int64(written)
|
|
mo.cond.Broadcast()
|
|
mo.mu.Unlock()
|
|
}
|
|
|
|
// Download the rest of the file in chunks
|
|
buffer := make([]byte, server.Misc.ChunkBytes)
|
|
for {
|
|
n, readErr := response.Body.Read(buffer)
|
|
if n > 0 {
|
|
written, writeErr := mo.TempFile.Write(buffer[:n])
|
|
if writeErr != nil {
|
|
err = writeErr
|
|
break
|
|
}
|
|
mo.mu.Lock()
|
|
mo.Offset += int64(written)
|
|
mo.cond.Broadcast()
|
|
mo.mu.Unlock()
|
|
}
|
|
if readErr != nil {
|
|
if readErr != io.EOF {
|
|
err = readErr
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
// After download, if no critical error, rename the temp file to its final destination.
|
|
if err == nil {
|
|
// Set modification time
|
|
mtime := zeroTime
|
|
lastModifiedHeader := response.Header.Get("Last-Modified")
|
|
if lastModified, lmErr := time.Parse(time.RFC1123, lastModifiedHeader); lmErr == nil {
|
|
mtime = lastModified
|
|
}
|
|
// Close file before Chtimes and Rename
|
|
mo.TempFile.Close()
|
|
|
|
os.Chtimes(tempFileName, zeroTime, mtime)
|
|
|
|
// Rename the file
|
|
if renameErr := os.Rename(tempFileName, key); renameErr != nil {
|
|
slog.With("error", renameErr, "from", tempFileName, "to", key).Warn("failed to rename temp file")
|
|
err = renameErr
|
|
os.Remove(tempFileName) // Attempt to clean up if rename fails
|
|
} else {
|
|
downloadSucceeded = true
|
|
}
|
|
} else {
|
|
logger := slog.With("upstreamIdx", selectedIdx)
|
|
logger.Error("something happened during download. will not cache this response.", "error", err)
|
|
}
|
|
}(memoryObject)
|
|
|
|
return fileWrittenCh
|
|
}
|
|
|
|
func (server *Server) fastestUpstream(r *http.Request, lastModified time.Time) (resultIdx int, resultResponse *http.Response, firstChunk []byte, resultErr error) {
|
|
returnLock := &sync.Mutex{}
|
|
upstreams := len(server.Upstreams)
|
|
cancelFuncs := make([]func(), upstreams)
|
|
updateCh := make(chan int, 1)
|
|
updateOnce := &sync.Once{}
|
|
notModifiedCh := make(chan int, 1)
|
|
notModifiedOnce := &sync.Once{}
|
|
|
|
resultIdx = -1
|
|
|
|
var resultFirstChunk []byte
|
|
|
|
defer close(updateCh)
|
|
defer close(notModifiedCh)
|
|
defer func() {
|
|
for cancelIdx, cancel := range cancelFuncs {
|
|
if cancelIdx == resultIdx || cancel == nil {
|
|
continue
|
|
}
|
|
cancel()
|
|
}
|
|
}()
|
|
|
|
groups := make(map[int][]int)
|
|
for upstreamIdx, upstream := range server.Upstreams {
|
|
if _, matched, err := upstream.GetPath(r.URL.Path); err != nil {
|
|
resultErr = err
|
|
return
|
|
} else if !matched {
|
|
continue
|
|
}
|
|
|
|
priority := 0
|
|
for _, priorityGroup := range upstream.PriorityGroups {
|
|
if matched, err := regexp.MatchString(priorityGroup.Match, r.URL.Path); err != nil {
|
|
resultErr = err
|
|
return
|
|
} else if matched {
|
|
priority = priorityGroup.Priority
|
|
break
|
|
}
|
|
}
|
|
groups[priority] = append(groups[priority], upstreamIdx)
|
|
}
|
|
|
|
priorities := make([]int, 0, len(groups))
|
|
for priority := range groups {
|
|
priorities = append(priorities, priority)
|
|
}
|
|
slices.Sort(priorities)
|
|
slices.Reverse(priorities)
|
|
|
|
for _, priority := range priorities {
|
|
upstreams := groups[priority]
|
|
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(len(upstreams))
|
|
|
|
logger := slog.With()
|
|
if priority != 0 {
|
|
logger = logger.With("priority", priority)
|
|
}
|
|
|
|
for _, idx := range upstreams {
|
|
idx := idx
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancelFuncs[idx] = cancel
|
|
|
|
logger := logger.With("upstreamIdx", idx)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
response, chunk, err := server.tryUpstream(ctx, idx, priority, r, lastModified)
|
|
if err == context.Canceled { // others returned
|
|
logger.Debug("context canceled")
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
|
logger.With("error", err).Warn("upstream has error")
|
|
}
|
|
return
|
|
}
|
|
if response == nil {
|
|
return
|
|
}
|
|
if response.StatusCode != http.StatusOK && response.StatusCode != http.StatusNotModified {
|
|
return
|
|
}
|
|
locked := returnLock.TryLock()
|
|
if !locked {
|
|
if response != nil {
|
|
response.Body.Close()
|
|
}
|
|
return
|
|
}
|
|
defer returnLock.Unlock()
|
|
|
|
if response.StatusCode == http.StatusNotModified {
|
|
notModifiedOnce.Do(func() {
|
|
resultResponse, resultErr = response, err
|
|
notModifiedCh <- idx
|
|
})
|
|
logger.Debug("voted not modified")
|
|
return
|
|
}
|
|
|
|
updateOnce.Do(func() {
|
|
resultResponse, resultFirstChunk, resultErr = response, chunk, err
|
|
updateCh <- idx
|
|
|
|
for cancelIdx, cancel := range cancelFuncs {
|
|
if cancelIdx == idx || cancel == nil {
|
|
continue
|
|
}
|
|
cancel()
|
|
}
|
|
})
|
|
|
|
logger.Debug("voted update")
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
select {
|
|
case idx := <-updateCh:
|
|
resultIdx = idx
|
|
firstChunk = resultFirstChunk
|
|
logger.With("upstreamIdx", resultIdx).Debug("upstream selected")
|
|
return
|
|
default:
|
|
select {
|
|
case idx := <-notModifiedCh:
|
|
resultIdx = idx
|
|
logger.With("upstreamIdx", resultIdx).Debug("all upstream not modified")
|
|
return
|
|
default:
|
|
logger.Debug("no valid upstream found")
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int, r *http.Request, lastModified time.Time) (response *http.Response, firstChunk []byte, err error) {
|
|
upstream := server.Upstreams[upstreamIdx]
|
|
|
|
newpath, matched, err := upstream.GetPath(r.URL.Path)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if !matched {
|
|
return nil, nil, nil
|
|
}
|
|
logger := slog.With("upstreamIdx", upstreamIdx, "server", upstream.Server, "path", newpath)
|
|
if priority != 0 {
|
|
logger = logger.With("priority", priority)
|
|
}
|
|
|
|
logger.With(
|
|
"matched", matched,
|
|
).Debug("trying upstream")
|
|
|
|
newurl := upstream.Server + newpath
|
|
method := r.Method
|
|
if lastModified != zeroTime {
|
|
method = http.MethodGet
|
|
}
|
|
|
|
if upstream.AllowedRedirect != nil {
|
|
ctx = context.WithValue(ctx, reqCtxAllowedRedirect, *upstream.AllowedRedirect)
|
|
}
|
|
|
|
request, err := http.NewRequestWithContext(ctx, method, newurl, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if lastModified != zeroTime {
|
|
request.Header.Set("If-Modified-Since", lastModified.Format(time.RFC1123))
|
|
}
|
|
|
|
for _, k := range []string{"User-Agent", "Accept"} {
|
|
if _, exists := r.Header[k]; exists {
|
|
request.Header.Set(k, r.Header.Get(k))
|
|
}
|
|
}
|
|
response, err = httpClient.Do(request)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
shouldCloseBody := true
|
|
defer func() {
|
|
if shouldCloseBody && response != nil {
|
|
response.Body.Close()
|
|
}
|
|
}()
|
|
|
|
if response.StatusCode == http.StatusNotModified {
|
|
return response, nil, nil
|
|
}
|
|
|
|
responseCheckers := upstream.Checkers
|
|
if len(responseCheckers) == 0 {
|
|
responseCheckers = append(responseCheckers, Checker{})
|
|
}
|
|
|
|
for _, checker := range responseCheckers {
|
|
if len(checker.StatusCodes) == 0 {
|
|
checker.StatusCodes = append(checker.StatusCodes, http.StatusOK)
|
|
}
|
|
|
|
if !slices.Contains(checker.StatusCodes, response.StatusCode) {
|
|
return nil, nil, err
|
|
}
|
|
|
|
for _, headerChecker := range checker.Headers {
|
|
if headerChecker.Match == nil {
|
|
// check header exists
|
|
if _, ok := response.Header[headerChecker.Name]; !ok {
|
|
logger.Debug("missing header", "header", headerChecker.Name)
|
|
return nil, nil, nil
|
|
}
|
|
} else {
|
|
// check header match
|
|
value := response.Header.Get(headerChecker.Name)
|
|
if matched, err := regexp.MatchString(*headerChecker.Match, value); err != nil {
|
|
return nil, nil, err
|
|
} else if !matched {
|
|
logger.Debug("invalid header value",
|
|
"header", headerChecker.Name,
|
|
"value", value,
|
|
"matcher", *headerChecker.Match,
|
|
)
|
|
|
|
return nil, nil, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
buffer := make([]byte, server.Misc.FirstChunkBytes)
|
|
n, err := io.ReadAtLeast(response.Body, buffer, len(buffer))
|
|
|
|
if err != nil {
|
|
if n == 0 {
|
|
return response, nil, err
|
|
}
|
|
|
|
if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
|
|
err = nil
|
|
}
|
|
}
|
|
|
|
shouldCloseBody = false
|
|
return response, buffer[:n], err
|
|
}
|