sync from project
This commit is contained in:
13
go.mod
13
go.mod
@ -3,8 +3,19 @@ module git.jeffthecoder.xyz/public/lazyhandler
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/go-playground/validator/v10 v10.26.0
|
||||
github.com/go-session/session/v3 v3.2.1
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
)
|
||||
|
||||
require github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 // indirect
|
||||
require (
|
||||
github.com/bytedance/gopkg v0.1.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
)
|
||||
|
38
go.sum
38
go.sum
@ -1,6 +1,17 @@
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
|
||||
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/bytedance/gopkg v0.1.2 h1:8o2feYuxknDpN+O7kPwvSXfMEKfYvJYiA2K7aonoMEQ=
|
||||
github.com/bytedance/gopkg v0.1.2/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
|
||||
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY=
|
||||
github.com/go-session/session/v3 v3.2.1/go.mod h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M=
|
||||
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
|
||||
@ -9,14 +20,23 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
|
||||
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
||||
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
44
inject.go
44
inject.go
@ -12,10 +12,13 @@ import (
|
||||
"git.jeffthecoder.xyz/public/lazyhandler/util"
|
||||
)
|
||||
|
||||
type ErrArgumentIsNotExtractable int
|
||||
type ErrArgumentIsNotExtractable struct {
|
||||
Index int
|
||||
Type reflect.Type
|
||||
}
|
||||
|
||||
func (err ErrArgumentIsNotExtractable) Error() string {
|
||||
return fmt.Sprintf("argument %v is not extractable", int(err))
|
||||
return fmt.Sprintf("argument %v is not extractable: %s", err.Index, err.Type.String())
|
||||
}
|
||||
|
||||
type ErrReturnValueNotConvertableIntoResponsePart int
|
||||
@ -27,6 +30,7 @@ func (err ErrReturnValueNotConvertableIntoResponsePart) Error() string {
|
||||
var (
|
||||
ErrNotAFunc = errors.New("not a function")
|
||||
ErrDuplicateResponseWriterExtractor = errors.New("duplicate response writer extractor")
|
||||
ErrDuplicateRequestBodyExtractor = errors.New("duplicate request body extractor")
|
||||
ErrResponseWriterCannotBeExtracted = errors.New("http.ResponseWriter extractor must not exists if function has return value")
|
||||
)
|
||||
|
||||
@ -50,6 +54,7 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
}
|
||||
|
||||
responseWriterExtracted := -1
|
||||
requestBodyExtracted := -1
|
||||
|
||||
extractors := make([]func(http.ResponseWriter, *http.Request) (any, error), t.NumIn())
|
||||
for idx := 0; idx < t.NumIn(); idx++ {
|
||||
@ -57,7 +62,10 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
|
||||
extractor, isTakeResponseWriter := magic.GetExtractor(in)
|
||||
if extractor == nil {
|
||||
return nil, ErrArgumentIsNotExtractable(idx)
|
||||
return nil, ErrArgumentIsNotExtractable{
|
||||
Index: idx,
|
||||
Type: in,
|
||||
}
|
||||
}
|
||||
if isTakeResponseWriter {
|
||||
if responseWriterExtracted >= 0 {
|
||||
@ -65,6 +73,14 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
}
|
||||
responseWriterExtracted = idx
|
||||
}
|
||||
|
||||
if magic.IsTakeBody(in) {
|
||||
if requestBodyExtracted >= 0 {
|
||||
return nil, ErrDuplicateRequestBodyExtractor
|
||||
}
|
||||
requestBodyExtracted = idx
|
||||
}
|
||||
|
||||
extractors[idx] = extractor
|
||||
}
|
||||
|
||||
@ -98,6 +114,14 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
if _, ok := util.Implements[magic.RespondWriter](out); ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := util.Implements[magic.RespondWriter](reflect.PointerTo(out)); ok {
|
||||
continue
|
||||
}
|
||||
if out.Kind() == reflect.Pointer {
|
||||
if _, ok := util.Implements[magic.RespondWriter](out.Elem()); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// last is error
|
||||
if _, ok := util.Implements[error](out); ok && idx == t.NumOut()-1 {
|
||||
@ -108,12 +132,6 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var closers []io.Closer
|
||||
defer func() {
|
||||
for _, closer := range closers {
|
||||
closer.Close()
|
||||
}
|
||||
}()
|
||||
in := make([]reflect.Value, len(extractors))
|
||||
for idx, extractor := range extractors {
|
||||
v, err := extractor(w, r)
|
||||
@ -128,9 +146,6 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if closer, ok := v.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
in[idx] = reflect.ValueOf(v)
|
||||
}
|
||||
values := funcValue.Call(in)
|
||||
@ -169,6 +184,11 @@ func MagicHandler(fn any) (http.Handler, error) {
|
||||
continue
|
||||
}
|
||||
responseWriter.WriteResponse(w)
|
||||
} else if responseWriter, ok := util.Implements[magic.RespondWriter](&value); ok {
|
||||
if responseWriter == nil {
|
||||
continue
|
||||
}
|
||||
responseWriter.WriteResponse(w)
|
||||
} else if headers, ok := obj.(http.Header); ok { // not
|
||||
for name, values := range headers {
|
||||
for idx, value := range values {
|
||||
|
508
inject_test.go
508
inject_test.go
@ -30,10 +30,47 @@ type HelloWorldParams struct {
|
||||
}
|
||||
|
||||
type MultipleParams struct {
|
||||
Numbers []int
|
||||
Strings []string
|
||||
OptionalNumber *int
|
||||
OptionalString *string
|
||||
Numbers []int `query:"numbers"`
|
||||
Strings []string `query:"strings"`
|
||||
OptionalNumber *int `query:"optionalNumber"`
|
||||
OptionalString *string `query:"optionalString"`
|
||||
}
|
||||
|
||||
type QueryTypes struct {
|
||||
IntField int `query:"int_field"`
|
||||
StringField string `query:"string_field"`
|
||||
BoolField bool `query:"bool_field"`
|
||||
FloatField32 float32 `query:"float32_field"`
|
||||
FloatField64 float64 `query:"float64_field"`
|
||||
IntSlice []int `query:"int_slice"`
|
||||
StringSlice []string `query:"string_slice"`
|
||||
BoolSlice []bool `query:"bool_slice"`
|
||||
OptionalInt *int `query:"optional_int"`
|
||||
}
|
||||
|
||||
type IntegerPathValues struct {
|
||||
IntField int `pathvalue:"int_val"`
|
||||
Int8Field int8 `pathvalue:"int8_val"`
|
||||
Int16Field int16 `pathvalue:"int16_val"`
|
||||
Int32Field int32 `pathvalue:"int32_val"`
|
||||
Int64Field int64 `pathvalue:"int64_val"`
|
||||
UintField uint `pathvalue:"uint_val"`
|
||||
Uint8Field uint8 `pathvalue:"uint8_val"`
|
||||
Uint16Field uint16 `pathvalue:"uint16_val"`
|
||||
Uint32Field uint32 `pathvalue:"uint32_val"`
|
||||
Uint64Field uint64 `pathvalue:"uint64_val"`
|
||||
}
|
||||
|
||||
type ErrorFromRequest struct{}
|
||||
|
||||
func (e ErrorFromRequest) FromRequest(r *http.Request) error {
|
||||
return fmt.Errorf("custom error from request")
|
||||
}
|
||||
|
||||
type PtrErrorFromRequest struct{}
|
||||
|
||||
func (e *PtrErrorFromRequest) FromRequest(r *http.Request) error {
|
||||
return fmt.Errorf("custom error from pointer request")
|
||||
}
|
||||
|
||||
func TestNotAFunc(t *testing.T) {
|
||||
@ -60,6 +97,12 @@ func TestMethod(t *testing.T) {
|
||||
MagicHandler(obj.Func)
|
||||
}
|
||||
|
||||
func TestJsonResponse(t *testing.T) {
|
||||
MustMagicHandler(func() magic.Json[map[string]any] {
|
||||
return magic.Json[map[string]any]{}
|
||||
})
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
Path string
|
||||
Pattern string
|
||||
@ -278,13 +321,468 @@ var (
|
||||
// })
|
||||
// },
|
||||
// },
|
||||
{
|
||||
Pattern: "/hello-query",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[HelloWorldParams]) magic.Json[HelloWorldResponse] {
|
||||
name := "world"
|
||||
if query.Data.Name != nil {
|
||||
name = *query.Data.Name
|
||||
}
|
||||
return magic.Json[HelloWorldResponse]{
|
||||
Data: HelloWorldResponse{
|
||||
Greeting: "hello, " + name,
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[MultipleParams]) {
|
||||
log.Println(query.Data)
|
||||
// Add assertions for MultipleParams
|
||||
if len(query.Data.Numbers) != 3 || query.Data.Numbers[0] != 1 || query.Data.Numbers[1] != 2 || query.Data.Numbers[2] != 3 {
|
||||
panic(fmt.Sprintf("expected numbers [1 2 3], got %v", query.Data.Numbers))
|
||||
}
|
||||
if len(query.Data.Strings) != 2 || query.Data.Strings[0] != "a" || query.Data.Strings[1] != "b" {
|
||||
panic(fmt.Sprintf("expected strings [a b], got %v", query.Data.Strings))
|
||||
}
|
||||
if query.Data.OptionalNumber == nil || *query.Data.OptionalNumber != 10 {
|
||||
panic(fmt.Sprintf("expected optional number 10, got %v", query.Data.OptionalNumber))
|
||||
}
|
||||
if query.Data.OptionalString == nil || *query.Data.OptionalString != "optional" {
|
||||
panic(fmt.Sprintf("expected optional string 'optional', got %v", query.Data.OptionalString))
|
||||
}
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query?numbers=1&numbers=2&numbers=3&strings=a&strings=b&optionalNumber=10&optionalString=optional", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-optional-missing",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[MultipleParams]) {
|
||||
// Test missing optional fields
|
||||
if query.Data.OptionalNumber != nil {
|
||||
panic(fmt.Sprintf("expected optional number to be nil, got %v", query.Data.OptionalNumber))
|
||||
}
|
||||
if query.Data.OptionalString != nil {
|
||||
panic(fmt.Sprintf("expected optional string to be nil, got %v", query.Data.OptionalString))
|
||||
}
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-optional-missing?numbers=1&strings=a", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-types",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
|
||||
// Test various data types and slices
|
||||
if query.Data.IntField != 123 {
|
||||
panic(fmt.Sprintf("expected int_field 123, got %d", query.Data.IntField))
|
||||
}
|
||||
if query.Data.StringField != "test_string" {
|
||||
panic(fmt.Sprintf("expected string_field 'test_string', got '%s'", query.Data.StringField))
|
||||
}
|
||||
if !query.Data.BoolField {
|
||||
panic(fmt.Sprintf("expected bool_field true, got %t", query.Data.BoolField))
|
||||
}
|
||||
if query.Data.FloatField32 != 1.23 {
|
||||
panic(fmt.Sprintf("expected float32_field 1.23, got %f", query.Data.FloatField32))
|
||||
}
|
||||
if query.Data.FloatField64 != 4.56 {
|
||||
panic(fmt.Sprintf("expected float64_field 4.56, got %f", query.Data.FloatField64))
|
||||
}
|
||||
if len(query.Data.IntSlice) != 2 || query.Data.IntSlice[0] != 1 || query.Data.IntSlice[1] != 2 {
|
||||
panic(fmt.Sprintf("expected int_slice [1 2], got %v", query.Data.IntSlice))
|
||||
}
|
||||
if len(query.Data.StringSlice) != 2 || query.Data.StringSlice[0] != "x" || query.Data.StringSlice[1] != "y" {
|
||||
panic(fmt.Sprintf("expected string_slice [x y], got %v", query.Data.StringSlice))
|
||||
}
|
||||
if len(query.Data.BoolSlice) != 2 || query.Data.BoolSlice[0] != true || query.Data.BoolSlice[1] != false {
|
||||
panic(fmt.Sprintf("expected bool_slice [true false], got %v", query.Data.BoolSlice))
|
||||
}
|
||||
if query.Data.OptionalInt == nil || *query.Data.OptionalInt != 99 {
|
||||
panic(fmt.Sprintf("expected optional_int 99, got %v", query.Data.OptionalInt))
|
||||
}
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-types?int_field=123&string_field=test_string&bool_field=true&float32_field=1.23&float64_field=4.56&int_slice=1&int_slice=2&string_slice=x&string_slice=y&bool_slice=true&bool_slice=false&optional_int=99", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-types-missing-optional",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
|
||||
// Test missing optional int
|
||||
if query.Data.OptionalInt != nil {
|
||||
panic(fmt.Sprintf("expected optional_int to be nil, got %v", query.Data.OptionalInt))
|
||||
}
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-types-missing-optional?int_field=123&string_field=test_string&bool_field=true&float32_field=1.23&float64_field=4.56&int_slice=1&int_slice=2&string_slice=x&string_slice=y&bool_slice=true&bool_slice=false", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-invalid-int",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
|
||||
panic("handler should not be called for invalid input")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-invalid-int?int_field=abc&string_field=test", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-invalid-bool",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
|
||||
panic("handler should not be called for invalid input")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-invalid-bool?int_field=123&string_field=test&bool_field=not_a_bool", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-invalid-float",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
|
||||
panic("handler should not be called for invalid input")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-invalid-float?int_field=123&string_field=test&bool_field=true&float64_field=not_a_float", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-form",
|
||||
CreateHandler: func() http.Handler {
|
||||
type TestForm struct {
|
||||
Name string `form:"name"`
|
||||
Age int `form:"age"`
|
||||
IsPro bool `form:"is_pro"`
|
||||
}
|
||||
return MustMagicHandler(func(form magic.Form[TestForm]) {
|
||||
if form.Data.Name != "test" {
|
||||
panic(fmt.Sprintf("expected name 'test', got '%s'", form.Data.Name))
|
||||
}
|
||||
if form.Data.Age != 30 {
|
||||
panic(fmt.Sprintf("expected age 30, got %d", form.Data.Age))
|
||||
}
|
||||
if !form.Data.IsPro {
|
||||
panic(fmt.Sprintf("expected is_pro true, got %t", form.Data.IsPro))
|
||||
}
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
body := strings.NewReader("name=test&age=30&is_pro=true")
|
||||
req, err := http.NewRequest("POST", base+"/test-form", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-form-advanced",
|
||||
CreateHandler: func() http.Handler {
|
||||
type TestFormAdvanced struct {
|
||||
Names []string `form:"names"`
|
||||
Ages [2]int `form:"ages"`
|
||||
ID string `form:"id,match=^[a-f0-9]{8}$"`
|
||||
Optional *string `form:"optional"`
|
||||
Numbers []int `form:"numbers"`
|
||||
Required string `form:"required"`
|
||||
FixedSize [3]bool `form:"fixed_size"`
|
||||
}
|
||||
return MustMagicHandler(func(form magic.Form[TestFormAdvanced]) {
|
||||
// Test successful extraction
|
||||
if len(form.Data.Names) != 2 || form.Data.Names[0] != "alice" || form.Data.Names[1] != "bob" {
|
||||
panic(fmt.Sprintf("expected names [alice bob], got %v", form.Data.Names))
|
||||
}
|
||||
if form.Data.Ages[0] != 25 || form.Data.Ages[1] != 30 {
|
||||
panic(fmt.Sprintf("expected ages [25 30], got %v", form.Data.Ages))
|
||||
}
|
||||
if form.Data.ID != "abcdef01" {
|
||||
panic(fmt.Sprintf("expected id 'abcdef01', got '%s'", form.Data.ID))
|
||||
}
|
||||
if form.Data.Optional == nil || *form.Data.Optional != "present" {
|
||||
panic(fmt.Sprintf("expected optional 'present', got %v", form.Data.Optional))
|
||||
}
|
||||
if len(form.Data.Numbers) != 3 || form.Data.Numbers[0] != 1 || form.Data.Numbers[1] != 2 || form.Data.Numbers[2] != 3 {
|
||||
panic(fmt.Sprintf("expected numbers [1 2 3], got %v", form.Data.Numbers))
|
||||
}
|
||||
if form.Data.Required != "required_value" {
|
||||
panic(fmt.Sprintf("expected required 'required_value', got '%s'", form.Data.Required))
|
||||
}
|
||||
if form.Data.FixedSize[0] != true || form.Data.FixedSize[1] != false || form.Data.FixedSize[2] != true {
|
||||
panic(fmt.Sprintf("expected fixed_size [true false true], got %v", form.Data.FixedSize))
|
||||
}
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
body := strings.NewReader("names=alice&names=bob&ages=25&ages=30&id=abcdef01&optional=present&numbers=1&numbers=2&numbers=3&required=required_value&fixed_size=true&fixed_size=false&fixed_size=true")
|
||||
req, err := http.NewRequest("POST", base+"/test-form-advanced", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-form-advanced-errors",
|
||||
CreateHandler: func() http.Handler {
|
||||
type TestFormAdvancedErrors struct {
|
||||
Required string `form:"required"`
|
||||
}
|
||||
return MustMagicHandler(func(form magic.Form[TestFormAdvancedErrors]) {
|
||||
// This handler should not be reached in case of errors,
|
||||
// the magic handler should return an error before calling this.
|
||||
panic("handler should not be called for error test cases")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
// This request is intentionally malformed to trigger errors
|
||||
body := strings.NewReader("names=alice&ages=25&ages=30&id=invalid-id&numbers=1&required=required_value&fixed_size=true&fixed_size=false") // Missing one fixed_size value, invalid ID
|
||||
req, err := http.NewRequest("POST", base+"/test-form-advanced-errors", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
// Expecting an error response due to invalid form data
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
// Further checks on the error message body could be added here
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-json-non-struct",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(body magic.Json[string]) string {
|
||||
return "Received: " + body.Data
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
body := strings.NewReader(`"a simple string"`)
|
||||
req, err := http.NewRequest("POST", base+"/test-json-non-struct", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
bodyString := string(bodyBytes)
|
||||
// The response is the string "Received: " + the JSON string,
|
||||
// which includes the quotes from the original JSON body.
|
||||
expected := `"Received: \"a simple string\""`
|
||||
if !strings.Contains(bodyString, expected) {
|
||||
return fmt.Errorf("expected response body to contain '%s', got '%s'", expected, bodyString)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-from-request-error",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(e ErrorFromRequest) {
|
||||
panic("handler should not be called for error test cases")
|
||||
})
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-ptr-from-request-error",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(e *PtrErrorFromRequest) {
|
||||
panic("handler should not be called for error test cases")
|
||||
})
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-form-invalid-int",
|
||||
CreateHandler: func() http.Handler {
|
||||
type TestForm struct {
|
||||
Age int `form:"age"`
|
||||
}
|
||||
return MustMagicHandler(func(form magic.Form[TestForm]) {
|
||||
panic("handler should not be called for invalid input")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
body := strings.NewReader("age=abc")
|
||||
req, err := http.NewRequest("POST", base+"/test-form-invalid-int", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-form-array-too-few",
|
||||
CreateHandler: func() http.Handler {
|
||||
type TestForm struct {
|
||||
Ages [2]int `form:"ages"`
|
||||
}
|
||||
return MustMagicHandler(func(form magic.Form[TestForm]) {
|
||||
panic("handler should not be called for invalid input")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
body := strings.NewReader("ages=25")
|
||||
req, err := http.NewRequest("POST", base+"/test-form-array-too-few", body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-query-slice-invalid-entry",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
|
||||
panic("handler should not be called for invalid input")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
req, err := http.NewRequest("GET", base+"/test-query-slice-invalid-entry?int_slice=1&int_slice=abc&int_slice=3", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
Pattern: "/test-double-body-read",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(body io.Reader, jsonBody magic.Json[HelloWorldRequest]) {
|
||||
// This handler should panic because we are trying to read the body twice.
|
||||
panic("handler should not be called")
|
||||
})
|
||||
},
|
||||
ExpectPanicOnCreateHandler: true,
|
||||
},
|
||||
{
|
||||
Pattern: "/websocket-fail",
|
||||
CreateHandler: func() http.Handler {
|
||||
return MustMagicHandler(func(ws *websocket.Conn) {
|
||||
// This handler should not be called if websocket upgrade fails.
|
||||
panic("handler should not be called")
|
||||
})
|
||||
},
|
||||
MakeRequest: func(base string) *http.Request {
|
||||
// A regular GET request without the upgrade headers will fail the websocket upgrade.
|
||||
req, err := http.NewRequest("GET", base+"/websocket-fail", nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req
|
||||
},
|
||||
CheckResponse: func(resp *http.Response) error {
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
37
magic/autodecode.go
Normal file
37
magic/autodecode.go
Normal file
@ -0,0 +1,37 @@
|
||||
package magic
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ErrUknownContentType string
|
||||
|
||||
func (err ErrUknownContentType) Error() string {
|
||||
return "unknown content-type: " + string(err)
|
||||
}
|
||||
|
||||
type AutoDecode[T any] struct {
|
||||
Data T
|
||||
}
|
||||
|
||||
func (d *AutoDecode[T]) FromRequest(r *http.Request) error {
|
||||
contentType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
|
||||
switch contentType {
|
||||
case "application/json":
|
||||
decoder := &Json[T]{}
|
||||
if err := decoder.FromRequest(r); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Data = decoder.Data
|
||||
case "multipart/form-data", "application/x-www-form-urlencoded":
|
||||
decoder := &Form[T]{}
|
||||
if err := decoder.FromRequest(r); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Data = decoder.Data
|
||||
}
|
||||
return ErrUknownContentType(contentType)
|
||||
}
|
||||
|
||||
func (d AutoDecode[T]) TakeRequestBody() {}
|
@ -4,6 +4,8 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -54,21 +56,21 @@ var (
|
||||
if v, err := strconv.ParseInt(s, 10, 16); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return int8(v), nil
|
||||
return int16(v), nil
|
||||
}
|
||||
},
|
||||
reflect.Int32: func(s string) (any, error) {
|
||||
if v, err := strconv.ParseInt(s, 10, 32); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return int8(v), nil
|
||||
return int32(v), nil
|
||||
}
|
||||
},
|
||||
reflect.Int64: func(s string) (any, error) {
|
||||
if v, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return int8(v), nil
|
||||
return v, nil
|
||||
}
|
||||
},
|
||||
|
||||
@ -108,4 +110,24 @@ var (
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
defaultValidator = NewValidator("")
|
||||
)
|
||||
|
||||
func NewValidator(nameFromTag string) *validator.Validate {
|
||||
validate := validator.New(
|
||||
validator.WithRequiredStructEnabled(),
|
||||
)
|
||||
|
||||
if nameFromTag != "" {
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get(nameFromTag), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
return validate
|
||||
}
|
@ -11,6 +11,8 @@ import (
|
||||
var (
|
||||
extractors = make(map[reflect.Type]func(*http.Request) (any, error))
|
||||
|
||||
extractorsTakeBody = make(map[reflect.Type]bool)
|
||||
|
||||
extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error))
|
||||
)
|
||||
|
||||
@ -18,6 +20,12 @@ type FromRequest interface {
|
||||
FromRequest(*http.Request) error
|
||||
}
|
||||
|
||||
// Marker interface
|
||||
type TakeRequestBody interface {
|
||||
TakeRequestBody()
|
||||
}
|
||||
|
||||
// Marker interface
|
||||
type TakeResponseWriter interface {
|
||||
TakeResponseWriter(http.ResponseWriter)
|
||||
}
|
||||
@ -43,6 +51,42 @@ func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error))
|
||||
RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor)
|
||||
}
|
||||
|
||||
func RegisterExtractorThatTakesBody(o any, extractor func(*http.Request) (any, error)) {
|
||||
if t, ok := o.(reflect.Type); ok {
|
||||
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
|
||||
extractors[t] = extractor
|
||||
extractorsTakeBody[t] = true
|
||||
} else if v, ok := o.(reflect.Value); ok {
|
||||
slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value")
|
||||
|
||||
extractors[v.Type()] = extractor
|
||||
extractorsTakeBody[v.Type()] = true
|
||||
} else {
|
||||
t := reflect.TypeOf(o)
|
||||
slog.With("type", t.String(), "value", o).Debug("extractor type registered with object")
|
||||
|
||||
extractors[t] = extractor
|
||||
extractorsTakeBody[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterExtractorThatTakesBodyGeneric[T any](extractor func(*http.Request) (any, error)) {
|
||||
var pointerToT *T
|
||||
RegisterExtractorThatTakesBody(reflect.TypeOf(pointerToT).Elem(), extractor)
|
||||
}
|
||||
|
||||
func IsTakeBody(t reflect.Type) bool {
|
||||
if isTakeBody := extractorsTakeBody[t]; isTakeBody {
|
||||
return isTakeBody
|
||||
}
|
||||
|
||||
if _, ok := util.Implements[TakeRequestBody](t); ok {
|
||||
return ok
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) {
|
||||
if t, ok := o.(reflect.Type); ok {
|
||||
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
|
||||
@ -70,62 +114,57 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
|
||||
_, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem())
|
||||
isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter
|
||||
}
|
||||
if _, ok := util.Implements[FromRequest](t); ok {
|
||||
|
||||
fromRequestType := reflect.TypeOf((*FromRequest)(nil)).Elem()
|
||||
if t.Implements(fromRequestType) {
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
val := reflect.New(t).Elem()
|
||||
if t.Kind() == reflect.Pointer {
|
||||
// T.Implement(FromRequest) and T is Pointer
|
||||
// create new actual T and call T.FromRequest on it
|
||||
// if error happens, no error is returned. just return nil
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
// var t T
|
||||
// return t, t.FromRequest(r)
|
||||
z := reflect.New(t.Elem())
|
||||
val = reflect.New(t.Elem())
|
||||
}
|
||||
|
||||
if isTakeResponseWriter {
|
||||
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||
if taker, ok := val.Addr().Interface().(TakeResponseWriter); ok {
|
||||
taker.TakeResponseWriter(w)
|
||||
}
|
||||
}
|
||||
|
||||
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||
if errResponse, ok := err.(ErrorResponse); ok {
|
||||
return reflect.Zero(t).Interface(), errResponse
|
||||
if fromR, ok := val.Interface().(FromRequest); ok {
|
||||
if err := fromR.FromRequest(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reflect.Zero(t).Interface(), nil
|
||||
} else if val.CanAddr() {
|
||||
if fromR, ok := val.Addr().Interface().(FromRequest); ok {
|
||||
if err := fromR.FromRequest(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return z.Interface(), nil
|
||||
}
|
||||
}
|
||||
if t.Kind() == reflect.Pointer {
|
||||
|
||||
return val.Interface(), nil
|
||||
}
|
||||
return val.Addr().Interface(), nil
|
||||
}, isTakeResponseWriter
|
||||
}
|
||||
// T.Implement(FromRequest) and T is not Pointer
|
||||
// create zero T and call T.FromRequest directly on it
|
||||
|
||||
if reflect.PointerTo(t).Implements(fromRequestType) {
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
// var t T
|
||||
// return t, t.FromRequest(r)
|
||||
z := reflect.Zero(t)
|
||||
val := reflect.New(t)
|
||||
|
||||
if isTakeResponseWriter {
|
||||
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||
if taker, ok := val.Interface().(TakeResponseWriter); ok {
|
||||
taker.TakeResponseWriter(w)
|
||||
}
|
||||
}
|
||||
|
||||
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err, ok := results[0].Interface().(error); ok {
|
||||
return z.Interface(), err
|
||||
if fromR, ok := val.Interface().(FromRequest); ok {
|
||||
if err := fromR.FromRequest(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return z.Interface(), nil
|
||||
}, isTakeResponseWriter
|
||||
} else if v, ok := util.PointerImplements[FromRequest](t); ok {
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
// var t T
|
||||
// return t, t.FromRequest(r)
|
||||
z := reflect.New(v.Type().Elem())
|
||||
|
||||
if isTakeResponseWriter {
|
||||
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||
}
|
||||
|
||||
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err := results[0].Interface(); err == nil {
|
||||
return z.Elem().Interface(), nil
|
||||
}
|
||||
return z.Elem().Interface(), results[0].Interface().(error)
|
||||
return val.Elem().Interface(), nil
|
||||
}, isTakeResponseWriter
|
||||
}
|
||||
|
||||
@ -142,7 +181,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
// Create a pointer to the value
|
||||
vp := reflect.New(reflect.TypeOf(v))
|
||||
vp.Elem().Set(reflect.ValueOf(v))
|
||||
return vp.Interface(), nil
|
||||
}, false
|
||||
}
|
||||
}
|
||||
@ -160,7 +202,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
// Create a pointer to the value
|
||||
vp := reflect.New(reflect.TypeOf(v))
|
||||
vp.Elem().Set(reflect.ValueOf(v))
|
||||
return vp.Interface(), nil
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
206
magic/form.go
Normal file
206
magic/form.go
Normal file
@ -0,0 +1,206 @@
|
||||
package magic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
cachedStructFormFields = make(map[reflect.Type][]structFormFields)
|
||||
cachedUnsupportedFormFields = make(map[reflect.Type]structFormFields)
|
||||
)
|
||||
|
||||
type Form[T any] struct {
|
||||
Data T
|
||||
}
|
||||
|
||||
type structFormFields struct {
|
||||
PathKey string
|
||||
FieldKey string
|
||||
|
||||
Type reflect.Type
|
||||
MinValues, MaxValues int
|
||||
|
||||
Match *regexp.Regexp
|
||||
}
|
||||
|
||||
type UnsupportedFormType struct {
|
||||
inner reflect.Type
|
||||
}
|
||||
|
||||
func (err UnsupportedFormType) Error() string {
|
||||
return "unsupported type for Form: " + err.inner.String()
|
||||
}
|
||||
|
||||
type FormValueNotFitIn struct {
|
||||
Count int
|
||||
|
||||
field structFormFields
|
||||
}
|
||||
|
||||
func (err FormValueNotFitIn) Error() string {
|
||||
return fmt.Sprintf("Form %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinValues, err.field.MaxValues, err.Count)
|
||||
}
|
||||
|
||||
type InvalidFormValue struct {
|
||||
Match *regexp.Regexp
|
||||
Value string
|
||||
}
|
||||
|
||||
func (err InvalidFormValue) Error() string {
|
||||
return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value)
|
||||
}
|
||||
|
||||
func findReflectFormFields(v reflect.Value) ([]structFormFields, error) {
|
||||
if cached, ok := cachedStructFormFields[v.Type()]; ok {
|
||||
return cached, nil
|
||||
}
|
||||
if cached, ok := cachedUnsupportedFormFields[v.Type()]; ok {
|
||||
return nil, UnsupportedPathValueType{inner: cached.Type}
|
||||
}
|
||||
|
||||
var fields []structFormFields
|
||||
if v.Kind() == reflect.Struct {
|
||||
t := v.Type()
|
||||
for idx := 0; idx < t.NumField(); idx++ {
|
||||
field := t.Field(idx)
|
||||
|
||||
pathKey := field.Name
|
||||
var match *regexp.Regexp
|
||||
|
||||
if tags, ok := field.Tag.Lookup("form"); ok {
|
||||
parts := strings.Split(strings.TrimSpace(tags), ",")
|
||||
if len(parts) == 0 {
|
||||
// do nothing
|
||||
} else {
|
||||
if parts[0] != "" {
|
||||
pathKey = parts[0]
|
||||
}
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "match=") {
|
||||
part := strings.TrimPrefix(part, "match=")
|
||||
if unquoted, err := strconv.Unquote(part); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
part = unquoted
|
||||
}
|
||||
exp, err := regexp.Compile(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
match = exp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fieldType := field.Type
|
||||
maxValues, minValues := 1, 1
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
minValues = 0
|
||||
} else if fieldType.Kind() == reflect.Array {
|
||||
maxValues, minValues = fieldType.Len(), fieldType.Len()
|
||||
fieldType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
fieldType = fieldType.Elem()
|
||||
maxValues, minValues = -1, 0
|
||||
}
|
||||
|
||||
if pathValueConvertTable[fieldType.Kind()] == nil {
|
||||
if _, ok := fieldType.MethodByName("FromString"); !ok {
|
||||
cachedUnsupportedPathValueFields[t] = field
|
||||
return nil, UnsupportedPathValueType{inner: field.Type}
|
||||
}
|
||||
}
|
||||
fields = append(fields, structFormFields{
|
||||
PathKey: pathKey,
|
||||
FieldKey: field.Name,
|
||||
|
||||
Type: fieldType,
|
||||
MaxValues: maxValues,
|
||||
MinValues: minValues,
|
||||
|
||||
Match: match,
|
||||
})
|
||||
|
||||
cachedStructFormFields[t] = fields
|
||||
}
|
||||
} else {
|
||||
return nil, UnsupportedFormType{inner: v.Type()}
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func (form *Form[T]) FromRequest(r *http.Request) error {
|
||||
v := reflect.ValueOf(form).Elem().FieldByName("Data")
|
||||
fields, err := findReflectFormFields(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return err
|
||||
}
|
||||
q := r.PostForm
|
||||
|
||||
for _, field := range fields {
|
||||
values := q[field.PathKey]
|
||||
if len(values) < field.MinValues || (field.MaxValues > 0 && len(values) > field.MaxValues) {
|
||||
return FormValueNotFitIn{Count: len(values), field: field}
|
||||
}
|
||||
|
||||
structField := v.FieldByName(field.FieldKey)
|
||||
|
||||
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
|
||||
if structField.Cap() < len(values) {
|
||||
structField.SetCap(len(values))
|
||||
}
|
||||
structField.SetLen(len(values))
|
||||
}
|
||||
|
||||
for idx, value := range values {
|
||||
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
|
||||
if idx != len(values)-1 {
|
||||
// only the last takes effect
|
||||
continue
|
||||
}
|
||||
}
|
||||
// check match regexp
|
||||
if field.Match != nil && !field.Match.MatchString(value) {
|
||||
return InvalidFormValue{Match: field.Match, Value: value}
|
||||
}
|
||||
|
||||
// convert to elem value
|
||||
convert := pathValueConvertTable[field.Type.Kind()]
|
||||
result, err := convert(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reflectValue := reflect.ValueOf(result)
|
||||
|
||||
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
|
||||
if field.MinValues == 0 {
|
||||
newReflectValue := reflect.New(field.Type)
|
||||
newReflectValue.Elem().Set(reflectValue)
|
||||
reflectValue = newReflectValue
|
||||
}
|
||||
structField.Set(reflectValue)
|
||||
} else {
|
||||
structField.Index(idx).Set(reflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (form Form[T]) TakeRequestBody() {}
|
@ -2,8 +2,15 @@ package magic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
|
||||
jsonValidator = NewValidator("json")
|
||||
)
|
||||
|
||||
type JsonDecodeError struct {
|
||||
@ -18,10 +25,25 @@ func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
type Json[T any] struct {
|
||||
Data T
|
||||
type ValidateError []string
|
||||
|
||||
func (err ValidateError) Error() string {
|
||||
return "invalid or missing fields: [" + strings.Join(err, ", ") + "]"
|
||||
}
|
||||
|
||||
func (v ValidateError) WriteResponse(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(v.Error()))
|
||||
}
|
||||
|
||||
type Json[T any] struct {
|
||||
Data T `json:"data"`
|
||||
}
|
||||
|
||||
// TakeRequestBody is a marker method to indicate that this extractor takes the request body.
|
||||
func (j Json[T]) TakeRequestBody() {}
|
||||
|
||||
func NewJson[T any](data T) Json[T] {
|
||||
return Json[T]{
|
||||
Data: data,
|
||||
@ -29,35 +51,41 @@ func NewJson[T any](data T) Json[T] {
|
||||
}
|
||||
|
||||
func (data *Json[T]) FromRequest(request *http.Request) error {
|
||||
if request.Body == nil {
|
||||
panic("body is already taken")
|
||||
}
|
||||
|
||||
bodyReader := request.Body
|
||||
defer bodyReader.Close()
|
||||
|
||||
request.Body = http.NoBody
|
||||
request.Body = nil
|
||||
|
||||
if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil {
|
||||
return JsonDecodeError{inner: err}
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(data.Data)
|
||||
if value.Kind() == reflect.Struct || (value.Kind() == reflect.Ptr && value.Elem().Kind() == reflect.Struct) {
|
||||
if err := jsonValidator.StructCtx(request.Context(), data.Data); err == nil {
|
||||
return nil
|
||||
} else if validateErr, isValidateErr := err.(validator.ValidationErrors); isValidateErr {
|
||||
fields := make([]string, len(validateErr))
|
||||
for idx, err := range validateErr {
|
||||
fields[idx] = err.Field()
|
||||
}
|
||||
return ValidateError(fields)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type counter struct {
|
||||
counter int64
|
||||
}
|
||||
|
||||
func (c *counter) Write(b []byte) (int, error) {
|
||||
n := len(b)
|
||||
c.counter += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (container Json[T]) Write(w io.Writer) (int64, error) {
|
||||
counter := &counter{}
|
||||
err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data)
|
||||
return counter.counter, err
|
||||
}
|
||||
|
||||
func (json Json[T]) WriteResponse(w http.ResponseWriter) {
|
||||
json.Write(w)
|
||||
func (data *Json[T]) WriteResponse(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if data == nil {
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
@ -8,5 +8,6 @@ import (
|
||||
type Map map[string]any
|
||||
|
||||
func (m Map) RespondWriter(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(m)
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ func (query *Query[T]) FromRequest(r *http.Request) error {
|
||||
|
||||
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
|
||||
if structField.Cap() < len(values) {
|
||||
structField.SetCap(len(values))
|
||||
structField.Grow(len(values) - structField.Cap())
|
||||
}
|
||||
structField.SetLen(len(values))
|
||||
}
|
||||
|
@ -2,7 +2,10 @@ package magic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"git.jeffthecoder.xyz/public/lazyhandler/middleware/cleanup"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -12,4 +15,49 @@ func init() {
|
||||
RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) {
|
||||
return r.Context(), nil
|
||||
})
|
||||
// RegisterExtractorThatTakesBodyGeneric for io.ReadCloser extracts the request body.
|
||||
//
|
||||
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
|
||||
// is used, r.Body will be set to nil to prevent it from being read multiple times.
|
||||
// Any subsequent attempt to read the body will fail.
|
||||
//
|
||||
// A cleanup function is automatically registered to close the body once the request
|
||||
// is finished. This requires the cleanup middleware to be present in the handler chain.
|
||||
RegisterExtractorThatTakesBodyGeneric[io.ReadCloser](func(r *http.Request) (any, error) {
|
||||
if r.Body == nil {
|
||||
panic("body is already taken")
|
||||
}
|
||||
|
||||
body := r.Body
|
||||
r.Body = nil
|
||||
|
||||
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
|
||||
body.Close()
|
||||
}))
|
||||
|
||||
return body, nil
|
||||
})
|
||||
|
||||
// RegisterExtractorThatTakesBodyGeneric for io.Reader extracts the request body.
|
||||
//
|
||||
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
|
||||
// is used, r.Body will be set to nil to prevent it from being read multiple times.
|
||||
// Any subsequent attempt to read the body will fail.
|
||||
//
|
||||
// A cleanup function is automatically registered to close the body once the request
|
||||
// is finished. This requires the cleanup middleware to be present in the handler chain.
|
||||
RegisterExtractorThatTakesBodyGeneric[io.Reader](func(r *http.Request) (any, error) {
|
||||
if r.Body == nil {
|
||||
panic("body is already taken")
|
||||
}
|
||||
|
||||
body := r.Body
|
||||
r.Body = nil
|
||||
|
||||
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
|
||||
body.Close()
|
||||
}))
|
||||
|
||||
return body, nil
|
||||
})
|
||||
}
|
||||
|
@ -22,6 +22,26 @@ type ErrorResponse interface {
|
||||
WriteResponse(http.ResponseWriter)
|
||||
}
|
||||
|
||||
func WrapSimpleErrorWithStatus(message string, status int) ErrorResponse {
|
||||
return simpleErrorResponseWithoutBody{
|
||||
message: message,
|
||||
status: status,
|
||||
}
|
||||
}
|
||||
|
||||
type simpleErrorResponseWithoutBody struct {
|
||||
message string
|
||||
status int
|
||||
}
|
||||
|
||||
func (err simpleErrorResponseWithoutBody) Error() string {
|
||||
return err.message
|
||||
}
|
||||
|
||||
func (err simpleErrorResponseWithoutBody) WriteResponse(w http.ResponseWriter) {
|
||||
w.WriteHeader(err.status)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
return w, nil
|
||||
|
@ -2,12 +2,8 @@ package magic
|
||||
|
||||
import "net/http"
|
||||
|
||||
type State[T any] struct {
|
||||
Data T
|
||||
}
|
||||
|
||||
func RegisterState[T any](data T) {
|
||||
RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) {
|
||||
return State[T]{Data: data}, nil
|
||||
RegisterExtractor(data, func(r *http.Request) (any, error) {
|
||||
return data, nil
|
||||
})
|
||||
}
|
||||
|
61
middleware/cleanup/cleanup.go
Normal file
61
middleware/cleanup/cleanup.go
Normal file
@ -0,0 +1,61 @@
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
|
||||
)
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
cleanupCtxKey ctxKey = iota
|
||||
)
|
||||
|
||||
type CleanupContext struct {
|
||||
funcs []Cleanup
|
||||
}
|
||||
|
||||
type Cleanup interface {
|
||||
Name() string
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
type CleanupFunc func()
|
||||
|
||||
func (fn CleanupFunc) Name() string {
|
||||
return reflect.TypeOf(fn).String()
|
||||
}
|
||||
|
||||
func (fn CleanupFunc) Cleanup() {
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
slog.With("v", v).Warn("cleanup panicked")
|
||||
}
|
||||
}()
|
||||
|
||||
fn()
|
||||
}
|
||||
|
||||
func Collect() middleware.Middleware {
|
||||
return middleware.WrapFunc(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := &CleanupContext{}
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cleanupCtxKey, ctx)))
|
||||
for _, cleanupFunc := range ctx.funcs {
|
||||
defer cleanupFunc.Cleanup()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Register adds a Cleanup function to the CleanupContext in the provided context.
|
||||
// If the CleanupContext is not found in the context, the Cleanup function is not registered.
|
||||
func Register(ctx context.Context, c Cleanup) {
|
||||
if ctx, ok := ctx.Value(cleanupCtxKey).(*CleanupContext); ok {
|
||||
ctx.funcs = append(ctx.funcs, c)
|
||||
}
|
||||
}
|
54
middleware/cleanup/cleanup_test.go
Normal file
54
middleware/cleanup/cleanup_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCleanupMiddleware(t *testing.T) {
|
||||
var executionOrder []string
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Register(r.Context(), CleanupFunc(func() {
|
||||
executionOrder = append(executionOrder, "first")
|
||||
}))
|
||||
Register(r.Context(), CleanupFunc(func() {
|
||||
executionOrder = append(executionOrder, "second")
|
||||
}))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware := Collect().WrapHandler(handler)
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
expectedOrder := "second,first"
|
||||
actualOrder := strings.Join(executionOrder, ",")
|
||||
if actualOrder != expectedOrder {
|
||||
t.Errorf("cleanup functions executed in wrong order: got %s want %s",
|
||||
actualOrder, expectedOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupMissingContext(t *testing.T) {
|
||||
// This test ensures that Register does not panic when the context is missing.
|
||||
// The function should fail silently.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("The code panicked when it should not have")
|
||||
}
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
// No middleware, so no context
|
||||
Register(req.Context(), CleanupFunc(func() {}))
|
||||
}
|
@ -75,6 +75,10 @@ func (log Log) WrapHandler(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// Logger retrieves the slog.Logger from the request context.
|
||||
// If the logger is not found in the context (e.g., the httplog middleware is not used),
|
||||
// it creates and returns a new logger with basic request information.
|
||||
// Using the httplog middleware is recommended to ensure the configured logger is available.
|
||||
func Logger(r *http.Request) *slog.Logger {
|
||||
if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok {
|
||||
return logger.With("time", time.Now())
|
||||
|
85
middleware/httplog/log_test.go
Normal file
85
middleware/httplog/log_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package httplog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockHijacker struct {
|
||||
*httptest.ResponseRecorder
|
||||
hijacked bool
|
||||
}
|
||||
|
||||
func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
m.hijacked = true
|
||||
// Return dummy values, not used in this test
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func TestResponseRecorder_Hijack(t *testing.T) {
|
||||
recorder := &responseRecorder{
|
||||
ResponseWriter: &mockHijacker{ResponseRecorder: httptest.NewRecorder()},
|
||||
}
|
||||
|
||||
_, _, err := recorder.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("Hijack failed: %v", err)
|
||||
}
|
||||
|
||||
if recorder.ResponseWriter != nil {
|
||||
t.Error("ResponseWriter should be nil after Hijack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil)))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
logger := Logger(r)
|
||||
logger.Info("test message")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("with middleware", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
logMiddleware := Log{LogStart: true, LogFinish: true}
|
||||
wrappedHandler := logMiddleware.WrapHandler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
|
||||
t.Error("expected start log message, but not found")
|
||||
}
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
|
||||
t.Error("expected handler log message, but not found")
|
||||
}
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
|
||||
t.Error("expected finish log message, but not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without middleware", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
|
||||
t.Error("expected handler log message, but not found")
|
||||
}
|
||||
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
|
||||
t.Error("unexpected start log message found")
|
||||
}
|
||||
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
|
||||
t.Error("unexpected finish log message found")
|
||||
}
|
||||
})
|
||||
}
|
@ -42,3 +42,20 @@ func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
|
||||
func Debug() middleware.Middleware {
|
||||
return Recover(DebugPanicHandler)
|
||||
}
|
||||
|
||||
// ProductionPanicHandler is a PanicResponseFunc that provides a generic error message
|
||||
// in production environments and logs the detailed panic error.
|
||||
func ProductionPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
|
||||
httplog.Logger(r).With("panic", err).Error("request panicked")
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": "Internal Server Error",
|
||||
})
|
||||
}
|
||||
|
||||
// Production returns a middleware that recovers from panics and handles them
|
||||
// with ProductionPanicHandler.
|
||||
func Production() middleware.Middleware {
|
||||
return Recover(ProductionPanicHandler)
|
||||
}
|
||||
|
56
middleware/recover/recover_test.go
Normal file
56
middleware/recover/recover_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package recover
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func panicHandler(w http.ResponseWriter, r *http.Request) {
|
||||
panic("test panic")
|
||||
}
|
||||
|
||||
func TestDebugRecover(t *testing.T) {
|
||||
handler := http.HandlerFunc(panicHandler)
|
||||
wrappedHandler := Debug().WrapHandler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response body: %v", err)
|
||||
}
|
||||
|
||||
if resp["error"] != "test panic" {
|
||||
t.Errorf("expected error 'test panic', got '%v'", resp["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProductionRecover(t *testing.T) {
|
||||
handler := http.HandlerFunc(panicHandler)
|
||||
wrappedHandler := Production().WrapHandler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response body: %v", err)
|
||||
}
|
||||
|
||||
if resp["error"] != "Internal Server Error" {
|
||||
t.Errorf("expected error 'Internal Server Error', got '%v'", resp["error"])
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user