From 835045346da2c1bc0ddb258892587dcfbe84c42d Mon Sep 17 00:00:00 2001 From: guochao Date: Tue, 1 Apr 2025 14:09:43 +0800 Subject: [PATCH] add header checker --- config.go | 11 +++++++++++ config.yaml | 4 ++++ server.go | 54 ++++++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 60 insertions(+), 9 deletions(-) diff --git a/config.go b/config.go index dd81579..38bb01d 100644 --- a/config.go +++ b/config.go @@ -15,9 +15,20 @@ type PathTransformation struct { Replace string `yaml:"replace"` } +type HeaderChecker struct { + Name string `yaml:"name"` + Match *string `yaml:"match"` +} + +type Checker struct { + StatusCodes []int `yaml:"status-codes"` + Headers []HeaderChecker `yaml:"headers"` +} + type Upstream struct { Server string `yaml:"server"` Path PathTransformation `yaml:"path"` + Checkers []Checker `yaml:"checkers"` AllowedRedirect *string `yaml:"allowed-redirect"` PriorityGroups []UpstreamPriorityGroup `yaml:"priority-groups"` } diff --git a/config.yaml b/config.yaml index d093f0f..a01b9fa 100644 --- a/config.yaml +++ b/config.yaml @@ -11,6 +11,10 @@ upstream: path: match: /(debian|ubuntu|ubuntu-releases|alpine|archlinux|kali|manjaro|msys2|almalinux|rocky|centos|centos-stream|centos-vault|fedora|epel|elrepo|remi|rpmfusion|tailscale|gnu|rust-static|pypi)/(.*) replace: '/$1/$2' + checkers: + headers: + - name: content-type + match: application/octet-stream - server: https://packages.microsoft.com/repos/code path: match: /microsoft-code(.*) diff --git a/server.go b/server.go index fbf1e17..5796e99 100644 --- a/server.go +++ b/server.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "io" "log/slog" "net" @@ -618,18 +617,54 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int if err != nil { return nil, nil, err } + streaming := false + defer func() { + if !streaming { + response.Body.Close() + } + }() + if response.StatusCode == http.StatusNotModified { return response, nil, nil } - if response.StatusCode >= 400 && response.StatusCode < 500 { - return nil, nil, nil + + responseCheckers := upstream.Checkers + if len(responseCheckers) == 0 { + responseCheckers = append(responseCheckers, Checker{}) } - if response.StatusCode < 200 || response.StatusCode >= 500 { - logger.With( - "url", newurl, - "status", response.StatusCode, - ).Warn("unexpected status") - return response, nil, fmt.Errorf("unexpected status(url=%v): %v: %v", newurl, response.StatusCode, response) + + for _, checker := range responseCheckers { + if len(checker.StatusCodes) == 0 { + checker.StatusCodes = append(checker.StatusCodes, http.StatusOK) + } + + if !slices.Contains(checker.StatusCodes, response.StatusCode) { + return nil, nil, err + } + + for _, headerChecker := range checker.Headers { + if headerChecker.Match == nil { + // check header exists + if _, ok := response.Header[headerChecker.Name]; !ok { + logger.Debug("missing header", "header", headerChecker.Name) + return nil, nil, nil + } + } else { + // check header match + value := response.Header.Get(headerChecker.Name) + if matched, err := regexp.MatchString(*headerChecker.Match, value); err != nil { + return nil, nil, err + } else if !matched { + logger.Debug("invalid header value", + "header", headerChecker.Name, + "value", value, + "matcher", *headerChecker.Match, + ) + + return nil, nil, nil + } + } + } } var currentOffset int64 @@ -646,6 +681,7 @@ func (server *Server) tryUpstream(ctx context.Context, upstreamIdx, priority int } ch <- Chunk{buffer: buffer[:n]} + streaming = true go func() { defer close(ch)