2025-01-09 23:30:42 +08:00
package cacheproxy
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
2025-02-22 23:27:41 +08:00
"net"
2025-01-09 23:30:42 +08:00
"net/http"
"os"
"path/filepath"
"regexp"
"slices"
"strings"
"sync"
"time"
)
2025-01-21 15:18:07 +08:00
// reqCtxKey is the private key type for values this package stores in a
// request context, so the keys cannot collide with keys from other packages.
type reqCtxKey int

const (
	// reqCtxAllowedRedirect is the context key under which the regular
	// expression (as a string) for permitted redirect targets is stored.
	reqCtxAllowedRedirect reqCtxKey = iota
)

// zeroTime is the zero time.Time, used to mean "no timestamp available".
var zeroTime time.Time

var (
	// httpClient is the shared client used for upstream requests.
	//
	// Redirects are not followed unless the originating request's context
	// carries a reqCtxAllowedRedirect pattern and the redirect target matches
	// it. In every other case the redirect response itself is handed back to
	// the caller (http.ErrUseLastResponse).
	httpClient = http.Client{
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			// Installing a custom policy disables net/http's default limit,
			// so the same cap is enforced here to stop redirect loops.
			const maxRedirects = 10

			// net/http always passes at least one prior request; the guard
			// only protects direct callers of this exported field.
			if len(via) == 0 {
				return http.ErrUseLastResponse
			}

			// Redirected requests inherit the context of the original
			// request, so the most recent request carries the pattern.
			lastRequest := via[len(via)-1]
			allowedRedirect, ok := lastRequest.Context().Value(reqCtxAllowedRedirect).(string)
			if !ok {
				return http.ErrUseLastResponse
			}

			matched, err := regexp.MatchString(allowedRedirect, req.URL.String())
			if err != nil {
				return err
			}
			if !matched {
				return http.ErrUseLastResponse
			}
			if len(via) >= maxRedirects {
				return errors.New("stopped after 10 redirects")
			}
			return nil
		},
	}
)
2025-01-09 23:30:42 +08:00
// StreamObject is an in-memory response that is still being filled by a
// producer while any number of readers copy it out through StreamTo.
type StreamObject struct {
	// Headers are the response headers replayed to every reader.
	Headers http.Header
	// Buffer accumulates the response body.
	Buffer *bytes.Buffer
	// Offset is the number of bytes of Buffer that readers may consume.
	Offset int

	// ctx is done once the producer has stopped appending to Buffer.
	ctx context.Context
	// wg tracks the readers currently inside StreamTo.
	wg *sync.WaitGroup
}

// StreamTo copies the object's body to w as it becomes available and returns
// once the producer has finished and the remaining bytes have been written.
// A nil w discards the data. wg.Done is called exactly once before returning.
// The first write error aborts the copy and is returned.
//
// NOTE(review): Offset and Buffer are read here without synchronisation while
// the producer may be writing them; confirm this is acceptable for callers.
func (memoryObject *StreamObject) StreamTo(w io.Writer, wg *sync.WaitGroup) error {
	defer wg.Done()

	if w == nil {
		w = io.Discard
	}

	// producerDone reports, without blocking, whether the producer has stopped.
	producerDone := func() bool {
		select {
		case <-memoryObject.ctx.Done():
			return true
		default:
			return false
		}
	}

	sent := 0
	for !producerDone() {
		available := memoryObject.Offset
		if available != sent {
			n, err := w.Write(memoryObject.Buffer.Bytes()[sent:available])
			if err != nil {
				return err
			}
			sent += n
			continue
		}
		// Nothing new yet: poll again shortly.
		time.Sleep(time.Millisecond)
	}

	// Give the producer a moment to finish its last append, then flush
	// whatever has not been sent yet.
	time.Sleep(time.Millisecond)
	_, err := w.Write(memoryObject.Buffer.Bytes()[sent:])
	return err
}
// Server is the caching proxy: it embeds its Config and keeps the responses
// that are currently being downloaded.
type Server struct {
	Config

	// lu guards o.
	lu *sync.Mutex
	// o holds the in-flight downloads, keyed by request URL path.
	o map[string]*StreamObject
}

// NewServer returns a Server that uses config and has no in-flight downloads.
func NewServer(config Config) *Server {
	server := new(Server)
	server.Config = config
	server.lu = new(sync.Mutex)
	server.o = map[string]*StreamObject{}
	return server
}
// Chunk is one item on an upstream body channel. It carries a piece of the
// body, a read error, or both.
type Chunk struct {
	// buffer holds the bytes read; it is nil on a pure error chunk.
	buffer []byte
	// error is the read error, if any.
	error error
}
func ( server * Server ) serveFile ( w http . ResponseWriter , r * http . Request , path string ) {
2025-03-03 09:15:37 +08:00
if location := r . Header . Get ( server . Storage . Local . Accel . EnableByHeader ) ; server . Storage . Local . Accel . EnableByHeader != "" && location != "" {
2025-01-09 23:30:42 +08:00
relPath , err := filepath . Rel ( server . Storage . Local . Path , path )
if err != nil {
http . Error ( w , err . Error ( ) , http . StatusBadRequest )
return
}
accelPath := filepath . Join ( location , relPath )
2025-03-03 09:15:37 +08:00
for _ , headerKey := range server . Storage . Local . Accel . RespondWithHeaders {
2025-01-09 23:30:42 +08:00
w . Header ( ) . Set ( headerKey , accelPath )
}
return
}
http . ServeFile ( w , r , path )
}
// HandleRequestWithCache serves r from the local cache when a fresh copy
// exists and otherwise fetches it from an upstream.
//
// The URL path is mapped below the local storage root; a path that escapes
// the root is rejected with 400. Requests carrying a Range header are never
// streamed directly: the object is fetched with no client attached and the
// local file is served afterwards.
func (server *Server) HandleRequestWithCache(w http.ResponseWriter, r *http.Request) {
	root, err := filepath.Abs(server.Storage.Local.Path)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Join cleans the result, so any ".." in the URL path is resolved here.
	fullpath := filepath.Join(root, r.URL.Path)

	// A plain prefix comparison would accept sibling directories such as
	// "<root>-other"; compare path components instead.
	rel, err := filepath.Rel(root, fullpath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		http.Error(w, "crossing local directory boundary", http.StatusBadRequest)
		return
	}

	ranged := r.Header.Get("Range") != ""

	status, mtime, err := server.checkLocal(w, r, fullpath)
	slog.With("status", status, "mtime", mtime, "error", err, "key", fullpath).Debug("local status checked")

	switch {
	case os.IsPermission(err):
		http.Error(w, err.Error(), http.StatusForbidden)
	case err != nil:
		http.Error(w, err.Error(), http.StatusInternalServerError)
	case status != localNotExists && status != localExistsButNeedHead:
		// A fresh local copy exists.
		server.serveFile(w, r, fullpath)
	case ranged:
		// Missing or stale, but a ranged response cannot be streamed.
		server.streamOnline(nil, r, mtime, fullpath)
		server.serveFile(w, r, fullpath)
	default:
		server.streamOnline(w, r, mtime, fullpath)
	}
}
// localStatus describes the state of the locally cached copy of an object.
type localStatus int

const (
	// localNotExists means there is no local copy.
	localNotExists localStatus = iota
	// localExists means a local copy exists and needs no refresh.
	localExists
	// localExistsButNeedHead means a local copy exists but is due for a
	// refresh check against the upstream.
	localExistsButNeedHead
)
// checkLocal reports whether key exists in local storage and whether it is
// due for a refresh.
//
// The refresh interval is the configured default unless the first cache
// policy whose pattern matches key overrides it. A policy's RefreshAfter is
// either a duration or one of the keywords "always" and "never".
//
// It returns the status, the file's modification time in UTC (zero when there
// is no file) and any error. A missing file is not an error. Every other stat
// failure, permission errors included, is returned for the caller to report;
// nothing is written to the response here.
func (server *Server) checkLocal(_ http.ResponseWriter, _ *http.Request, key string) (exists localStatus, mtime time.Time, err error) {
	stat, err := os.Stat(key)
	if err != nil {
		if os.IsNotExist(err) {
			return localNotExists, time.Time{}, nil
		}
		return localNotExists, time.Time{}, err
	}

	refreshAfter := server.Cache.RefreshAfter
	refresh := ""
	for _, policy := range server.Cache.Policies {
		match, err := regexp.MatchString(policy.Match, key)
		if err != nil {
			return localNotExists, time.Time{}, err
		}
		if !match {
			continue
		}

		switch policy.RefreshAfter {
		case "always", "never":
			refresh = policy.RefreshAfter
		default:
			dur, err := time.ParseDuration(policy.RefreshAfter)
			if err != nil {
				return localNotExists, time.Time{}, err
			}
			refreshAfter = dur
		}
		// Only the first matching policy applies.
		break
	}

	mtime = stat.ModTime()
	slog.With("policy", refresh, "after", refreshAfter, "mtime", mtime, "key", key).Debug("refresh policy checked")

	expired := mtime.Add(refreshAfter).Before(time.Now())
	if (expired || refresh == "always") && refresh != "never" {
		return localExistsButNeedHead, mtime.In(time.UTC), nil
	}
	return localExists, mtime.In(time.UTC), nil
}
func ( server * Server ) streamOnline ( w http . ResponseWriter , r * http . Request , mtime time . Time , key string ) {
memoryObject , exists := server . o [ r . URL . Path ]
locked := false
defer func ( ) {
if locked {
server . lu . Unlock ( )
locked = false
}
} ( )
if ! exists {
server . lu . Lock ( )
locked = true
memoryObject , exists = server . o [ r . URL . Path ]
}
if exists {
if locked {
server . lu . Unlock ( )
locked = false
}
if w != nil {
memoryObject . wg . Add ( 1 )
for k := range memoryObject . Headers {
v := memoryObject . Headers . Get ( k )
w . Header ( ) . Set ( k , v )
}
if err := memoryObject . StreamTo ( w , memoryObject . wg ) ; err != nil {
slog . With ( "error" , err ) . Warn ( "failed to stream response with existing memory object" )
}
}
} else {
slog . With ( "mtime" , mtime ) . Debug ( "checking fastest upstream" )
selectedIdx , response , chunks , err := server . fastesUpstream ( r , mtime )
if chunks == nil && mtime != zeroTime {
slog . With ( "upstreamIdx" , selectedIdx , "key" , key ) . Debug ( "not modified. using local version" )
if w != nil {
server . serveFile ( w , r , key )
}
return
}
if err != nil {
slog . With ( "error" , err ) . Warn ( "failed to select fastest upstream" )
http . Error ( w , err . Error ( ) , http . StatusInternalServerError )
return
}
if selectedIdx == - 1 || response == nil || chunks == nil {
slog . Debug ( "no upstream is selected" )
http . NotFound ( w , r )
return
}
if response . StatusCode == http . StatusNotModified {
slog . With ( "upstreamIdx" , selectedIdx ) . Debug ( "not modified. using local version" )
os . Chtimes ( key , zeroTime , time . Now ( ) )
server . serveFile ( w , r , key )
return
}
slog . With (
"upstreamIdx" , selectedIdx ,
) . Debug ( "found fastest upstream" )
buffer := & bytes . Buffer { }
ctx , cancel := context . WithCancel ( r . Context ( ) )
defer cancel ( )
memoryObject = & StreamObject {
Headers : response . Header ,
Buffer : buffer ,
ctx : ctx ,
wg : & sync . WaitGroup { } ,
}
server . o [ r . URL . Path ] = memoryObject
server . lu . Unlock ( )
locked = false
err = nil
if w != nil {
memoryObject . wg . Add ( 1 )
for k := range memoryObject . Headers {
v := memoryObject . Headers . Get ( k )
w . Header ( ) . Set ( k , v )
}
go memoryObject . StreamTo ( w , memoryObject . wg )
}
for chunk := range chunks {
if chunk . error != nil {
err = chunk . error
2025-02-06 11:05:18 +08:00
if ! errors . Is ( err , io . EOF ) && ! errors . Is ( err , io . ErrUnexpectedEOF ) {
slog . With ( "error" , err ) . Warn ( "failed to read from upstream" )
}
2025-01-09 23:30:42 +08:00
}
if chunk . buffer == nil {
break
}
n , _ := buffer . Write ( chunk . buffer )
memoryObject . Offset += n
}
cancel ( )
memoryObject . wg . Wait ( )
if response . ContentLength > 0 {
if memoryObject . Offset == int ( response . ContentLength ) && err != nil {
2025-01-18 01:27:37 +08:00
if ! ( errors . Is ( err , io . EOF ) || errors . Is ( err , io . ErrUnexpectedEOF ) ) {
slog . With ( "read-length" , memoryObject . Offset , "content-length" , response . ContentLength , "error" , err , "upstreamIdx" , selectedIdx ) . Debug ( "something happened during download. but response body is read as whole. so error is reset to nil" )
2025-01-09 23:30:42 +08:00
}
err = nil
}
} else if err == io . EOF {
err = nil
}
if err != nil {
2025-02-22 23:27:41 +08:00
logger := slog . With ( "upstreamIdx" , selectedIdx )
logger . Error ( "something happened during download. will not cache this response. setting lingering to reset the connection." )
hijacker , ok := w . ( http . Hijacker )
if ! ok {
logger . Warn ( "response writer is not a hijacker. failed to set lingering" )
return
}
conn , _ , err := hijacker . Hijack ( )
if err != nil {
logger . With ( "error" , err ) . Warn ( "hijack failed. failed to set lingering" )
return
}
defer conn . Close ( )
tcpConn , ok := conn . ( * net . TCPConn )
if ! ok {
logger . With ( "error" , err ) . Warn ( "connection is not a *net.TCPConn. failed to set lingering" )
return
}
if err := tcpConn . SetLinger ( 0 ) ; err != nil {
logger . With ( "error" , err ) . Warn ( "failed to set lingering" )
return
}
logger . Debug ( "connection set to linger. it will be reset once the conn.Close is called" )
2025-01-09 23:30:42 +08:00
}
go func ( ) {
defer func ( ) {
server . lu . Lock ( )
defer server . lu . Unlock ( )
delete ( server . o , r . URL . Path )
slog . Debug ( "memory object released" )
} ( )
if err == nil {
slog . Debug ( "preparing to release memory object" )
mtime := zeroTime
lastModifiedHeader := response . Header . Get ( "Last-Modified" )
if lastModified , err := time . Parse ( time . RFC1123 , lastModifiedHeader ) ; err != nil {
slog . With (
"error" , err ,
"value" , lastModifiedHeader ,
"url" , response . Request . URL ,
) . Debug ( "failed to parse last modified header value. set modified time to now" )
} else {
slog . With (
"header" , lastModifiedHeader ,
"value" , lastModified ,
"url" , response . Request . URL ,
) . Debug ( "found modified time" )
mtime = lastModified
}
if err := os . MkdirAll ( server . Storage . Local . Path , 0755 ) ; err != nil {
slog . With ( "error" , err ) . Warn ( "failed to create local storage path" )
}
if server . Config . Storage . Local . TemporaryFilePattern == "" {
if err := os . WriteFile ( key , buffer . Bytes ( ) , 0644 ) ; err != nil {
slog . With ( "error" , err ) . Warn ( "failed to write file" )
os . Remove ( key )
}
return
}
fp , err := os . CreateTemp ( server . Storage . Local . Path , server . Storage . Local . TemporaryFilePattern )
if err != nil {
slog . With (
"key" , key ,
"path" , server . Storage . Local . Path ,
"pattern" , server . Storage . Local . TemporaryFilePattern ,
"error" , err ,
) . Warn ( "failed to create template file" )
return
}
name := fp . Name ( )
if _ , err := fp . Write ( buffer . Bytes ( ) ) ; err != nil {
fp . Close ( )
os . Remove ( name )
slog . With ( "error" , err ) . Warn ( "failed to write into template file" )
} else if err := fp . Close ( ) ; err != nil {
os . Remove ( name )
slog . With ( "error" , err ) . Warn ( "failed to close template file" )
} else {
os . Chtimes ( name , zeroTime , mtime )
dirname := filepath . Dir ( key )
os . MkdirAll ( dirname , 0755 )
os . Remove ( key )
os . Rename ( name , key )
}
}
} ( )
}
}
func ( server * Server ) fastesUpstream ( r * http . Request , lastModified time . Time ) ( resultIdx int , resultResponse * http . Response , resultCh chan Chunk , resultErr error ) {
returnLock := & sync . Mutex { }
upstreams := len ( server . Upstreams )
cancelFuncs := make ( [ ] func ( ) , upstreams )
2025-02-06 00:33:45 +08:00
updateCh := make ( chan int , 1 )
updateOnce := & sync . Once { }
notModifiedCh := make ( chan int , 1 )
notModifiedOnce := & sync . Once { }
2025-02-06 09:50:03 +08:00
resultIdx = - 1
2025-02-06 00:33:45 +08:00
defer close ( updateCh )
2025-02-06 09:50:03 +08:00
defer close ( notModifiedCh )
defer func ( ) {
for cancelIdx , cancel := range cancelFuncs {
if cancelIdx == resultIdx || cancel == nil {
continue
}
cancel ( )
}
} ( )
2025-01-09 23:30:42 +08:00
2025-02-06 11:05:18 +08:00
groups := make ( map [ int ] [ ] int )
for upstreamIdx , upstream := range server . Upstreams {
if _ , matched , err := upstream . GetPath ( r . URL . Path ) ; err != nil {
return - 1 , nil , nil , err
} else if ! matched {
continue
}
2025-01-09 23:30:42 +08:00
2025-02-06 11:05:18 +08:00
priority := 0
for _ , priorityGroup := range upstream . PriorityGroups {
if matched , err := regexp . MatchString ( priorityGroup . Match , r . URL . Path ) ; err != nil {
return - 1 , nil , nil , err
} else if matched {
priority = priorityGroup . Priority
break
2025-01-09 23:30:42 +08:00
}
2025-02-06 11:05:18 +08:00
}
groups [ priority ] = append ( groups [ priority ] , upstreamIdx )
}
priorities := make ( [ ] int , 0 , len ( groups ) )
for priority := range groups {
priorities = append ( priorities , priority )
}
slices . Sort ( priorities )
slices . Reverse ( priorities )
for _ , priority := range priorities {
upstreams := groups [ priority ]
wg := & sync . WaitGroup { }
wg . Add ( len ( upstreams ) )
2025-01-09 23:30:42 +08:00
2025-02-06 11:05:18 +08:00
logger := slog . With ( )
if priority != 0 {
logger = logger . With ( "priority" , priority )
}
for _ , idx := range upstreams {
idx := idx
ctx , cancel := context . WithCancel ( context . Background ( ) )
cancelFuncs [ idx ] = cancel
logger := logger . With ( "upstreamIdx" , idx )
go func ( ) {
defer wg . Done ( )
response , ch , err := server . tryUpstream ( ctx , idx , priority , r , lastModified )
if err == context . Canceled { // others returned
logger . Debug ( "context canceled" )
return
2025-01-09 23:30:42 +08:00
}
2025-02-06 11:05:18 +08:00
if err != nil {
if ! errors . Is ( err , context . Canceled ) && ! errors . Is ( err , context . DeadlineExceeded ) {
logger . With ( "error" , err ) . Warn ( "upstream has error" )
}
return
}
if response == nil {
return
}
if response . StatusCode != http . StatusOK && response . StatusCode != http . StatusNotModified {
return
}
locked := returnLock . TryLock ( )
if ! locked {
return
}
defer returnLock . Unlock ( )
if response . StatusCode == http . StatusNotModified {
notModifiedOnce . Do ( func ( ) {
resultResponse , resultCh , resultErr = response , ch , err
notModifiedCh <- idx
} )
logger . Debug ( "voted not modified" )
return
}
updateOnce . Do ( func ( ) {
2025-02-06 00:33:45 +08:00
resultResponse , resultCh , resultErr = response , ch , err
2025-02-06 11:05:18 +08:00
updateCh <- idx
2025-02-06 00:33:45 +08:00
2025-02-06 11:05:18 +08:00
for cancelIdx , cancel := range cancelFuncs {
2025-02-12 16:27:38 +08:00
if cancelIdx == idx || cancel == nil {
2025-02-06 11:05:18 +08:00
continue
}
cancel ( )
}
} )
2025-01-09 23:30:42 +08:00
2025-02-06 11:05:18 +08:00
logger . Debug ( "voted update" )
} ( )
}
2025-01-09 23:30:42 +08:00
2025-02-06 11:05:18 +08:00
wg . Wait ( )
2025-01-09 23:30:42 +08:00
2025-02-06 00:55:29 +08:00
select {
2025-02-06 11:05:18 +08:00
case idx := <- updateCh :
2025-02-06 00:55:29 +08:00
resultIdx = idx
2025-02-06 11:05:18 +08:00
logger . With ( "upstreamIdx" , resultIdx ) . Debug ( "upstream selected" )
return
2025-02-06 00:55:29 +08:00
default :
2025-02-06 11:05:18 +08:00
select {
case idx := <- notModifiedCh :
resultIdx = idx
logger . With ( "upstreamIdx" , resultIdx ) . Debug ( "all upstream not modified" )
return
default :
logger . Debug ( "no valid upstream found" )
}
2025-02-06 00:55:29 +08:00
}
2025-01-09 23:30:42 +08:00
}
2025-02-06 11:05:18 +08:00
return - 1 , nil , nil , nil
2025-01-09 23:30:42 +08:00
}
2025-02-06 11:05:18 +08:00
func ( server * Server ) tryUpstream ( ctx context . Context , upstreamIdx , priority int , r * http . Request , lastModified time . Time ) ( response * http . Response , chunks chan Chunk , err error ) {
2025-01-09 23:30:42 +08:00
upstream := server . Upstreams [ upstreamIdx ]
newpath , matched , err := upstream . GetPath ( r . URL . Path )
if err != nil {
return nil , nil , err
}
if ! matched {
return nil , nil , nil
}
logger := slog . With ( "upstreamIdx" , upstreamIdx , "server" , upstream . Server , "path" , newpath )
2025-02-06 11:05:18 +08:00
if priority != 0 {
logger = logger . With ( "priority" , priority )
}
2025-01-09 23:30:42 +08:00
logger . With (
"matched" , matched ,
) . Debug ( "trying upstream" )
newurl := upstream . Server + newpath
method := r . Method
if lastModified != zeroTime {
method = http . MethodGet
}
2025-01-21 15:18:07 +08:00
if upstream . AllowedRedirect != nil {
ctx = context . WithValue ( ctx , reqCtxAllowedRedirect , * upstream . AllowedRedirect )
}
2025-01-09 23:30:42 +08:00
request , err := http . NewRequestWithContext ( ctx , method , newurl , nil )
if err != nil {
return nil , nil , err
}
if lastModified != zeroTime {
request . Header . Set ( "If-Modified-Since" , lastModified . Format ( time . RFC1123 ) )
}
2025-01-18 01:27:37 +08:00
for _ , k := range [ ] string { "User-Agent" , "Accept" } {
if _ , exists := r . Header [ k ] ; exists {
2025-01-09 23:30:42 +08:00
request . Header . Set ( k , r . Header . Get ( k ) )
}
}
2025-01-18 01:27:37 +08:00
response , err = httpClient . Do ( request )
2025-01-09 23:30:42 +08:00
if err != nil {
return nil , nil , err
}
if response . StatusCode == http . StatusNotModified {
return response , nil , nil
}
if response . StatusCode >= 400 && response . StatusCode < 500 {
return nil , nil , nil
}
if response . StatusCode < 200 || response . StatusCode >= 500 {
logger . With (
"url" , newurl ,
"status" , response . StatusCode ,
) . Warn ( "unexpected status" )
return response , nil , fmt . Errorf ( "unexpected status(url=%v): %v: %v" , newurl , response . StatusCode , response )
}
var currentOffset int64
ch := make ( chan Chunk , 1024 )
buffer := make ( [ ] byte , server . Misc . FirstChunkBytes )
n , err := io . ReadAtLeast ( response . Body , buffer , len ( buffer ) )
if err != nil {
if n == 0 {
return response , nil , err
}
}
ch <- Chunk { buffer : buffer [ : n ] }
go func ( ) {
defer close ( ch )
for {
buffer := make ( [ ] byte , server . Misc . ChunkBytes )
n , err := io . ReadAtLeast ( response . Body , buffer , len ( buffer ) )
if n > 0 {
ch <- Chunk { buffer : buffer [ : n ] }
currentOffset += int64 ( n )
}
if response . ContentLength > 0 && currentOffset == response . ContentLength && err == io . EOF || err == io . ErrUnexpectedEOF {
return
}
if err != nil {
ch <- Chunk { error : err }
return
}
}
} ( )
return response , ch , nil
}