diff --git a/go.mod b/go.mod index 3326100..ffa5131 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,19 @@ module git.jeffthecoder.xyz/public/lazyhandler go 1.24.0 require ( + github.com/go-playground/validator/v10 v10.26.0 github.com/go-session/session/v3 v3.2.1 github.com/gorilla/websocket v1.5.3 ) -require github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 // indirect +require ( + github.com/bytedance/gopkg v0.1.2 // indirect + github.com/gabriel-vasile/mimetype v1.4.9 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + golang.org/x/crypto v0.39.0 // indirect + golang.org/x/net v0.41.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.26.0 // indirect +) diff --git a/go.sum b/go.sum index 6e9d68c..38b3159 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,17 @@ -github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8= -github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/bytedance/gopkg v0.1.2 h1:8o2feYuxknDpN+O7kPwvSXfMEKfYvJYiA2K7aonoMEQ= +github.com/bytedance/gopkg v0.1.2/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= +github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k= +github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY= github.com/go-session/session/v3 v3.2.1/go.mod h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M= github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= @@ -9,14 +20,23 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inject.go b/inject.go index 0799aa4..90b0887 100644 --- a/inject.go +++ b/inject.go @@ -12,10 +12,13 @@ import ( "git.jeffthecoder.xyz/public/lazyhandler/util" ) -type ErrArgumentIsNotExtractable int +type ErrArgumentIsNotExtractable struct { + Index int + Type reflect.Type +} func (err ErrArgumentIsNotExtractable) Error() string { - return fmt.Sprintf("argument %v is not extractable", int(err)) + return fmt.Sprintf("argument %v is not extractable: %s", err.Index, err.Type.String()) } type ErrReturnValueNotConvertableIntoResponsePart int @@ -27,6 +30,7 @@ func (err ErrReturnValueNotConvertableIntoResponsePart) Error() string { var ( ErrNotAFunc = errors.New("not a function") ErrDuplicateResponseWriterExtractor = errors.New("duplicate response writer extractor") + ErrDuplicateRequestBodyExtractor = errors.New("duplicate request body extractor") ErrResponseWriterCannotBeExtracted = errors.New("http.ResponseWriter extractor must not exists if function has return value") ) @@ -50,6 +54,7 @@ func MagicHandler(fn any) (http.Handler, error) { } responseWriterExtracted := -1 + requestBodyExtracted := -1 extractors := make([]func(http.ResponseWriter, *http.Request) (any, error), t.NumIn()) for idx := 0; idx < t.NumIn(); idx++ { @@ -57,7 +62,10 @@ func MagicHandler(fn any) (http.Handler, error) { extractor, isTakeResponseWriter := magic.GetExtractor(in) if extractor == nil { - return nil, ErrArgumentIsNotExtractable(idx) + return nil, ErrArgumentIsNotExtractable{ + Index: idx, + Type: in, + } } if isTakeResponseWriter { if responseWriterExtracted >= 0 { @@ -65,6 +73,14 @@ func MagicHandler(fn any) (http.Handler, error) { } responseWriterExtracted = idx } + + if magic.IsTakeBody(in) { + if requestBodyExtracted >= 0 { + return nil, ErrDuplicateRequestBodyExtractor + } + requestBodyExtracted = idx + } + extractors[idx] = extractor } @@ -98,6 +114,14 @@ func MagicHandler(fn any) (http.Handler, error) { if _, ok := util.Implements[magic.RespondWriter](out); ok { continue } + if _, ok := util.Implements[magic.RespondWriter](reflect.PointerTo(out)); ok { + continue + } + if out.Kind() == reflect.Pointer { + if _, ok := util.Implements[magic.RespondWriter](out.Elem()); ok { + continue + } + } // last is error if _, ok := util.Implements[error](out); ok && idx == t.NumOut()-1 { @@ -108,12 +132,6 @@ func MagicHandler(fn any) (http.Handler, error) { } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var closers []io.Closer - defer func() { - for _, closer := range closers { - closer.Close() - } - }() in := make([]reflect.Value, len(extractors)) for idx, extractor := range extractors { v, err := extractor(w, r) @@ -128,9 +146,6 @@ func MagicHandler(fn any) (http.Handler, error) { } return } - if closer, ok := v.(io.Closer); ok { - defer closer.Close() - } in[idx] = reflect.ValueOf(v) } values := funcValue.Call(in) @@ -169,6 +184,11 @@ func MagicHandler(fn any) (http.Handler, error) { continue } responseWriter.WriteResponse(w) + } else if responseWriter, ok := util.Implements[magic.RespondWriter](&value); ok { + if responseWriter == nil { + continue + } + responseWriter.WriteResponse(w) } else if headers, ok := obj.(http.Header); ok { // not for name, values := range headers { for idx, value := range values { diff --git a/inject_test.go b/inject_test.go index b2f786f..ae67429 100644 --- a/inject_test.go +++ b/inject_test.go @@ -30,10 +30,47 @@ type HelloWorldParams struct { } type MultipleParams struct { - Numbers []int - Strings []string - OptionalNumber *int - OptionalString *string + Numbers []int `query:"numbers"` + Strings []string `query:"strings"` + OptionalNumber *int `query:"optionalNumber"` + OptionalString *string `query:"optionalString"` +} + +type QueryTypes struct { + IntField int `query:"int_field"` + StringField string `query:"string_field"` + BoolField bool `query:"bool_field"` + FloatField32 float32 `query:"float32_field"` + FloatField64 float64 `query:"float64_field"` + IntSlice []int `query:"int_slice"` + StringSlice []string `query:"string_slice"` + BoolSlice []bool `query:"bool_slice"` + OptionalInt *int `query:"optional_int"` +} + +type IntegerPathValues struct { + IntField int `pathvalue:"int_val"` + Int8Field int8 `pathvalue:"int8_val"` + Int16Field int16 `pathvalue:"int16_val"` + Int32Field int32 `pathvalue:"int32_val"` + Int64Field int64 `pathvalue:"int64_val"` + UintField uint `pathvalue:"uint_val"` + Uint8Field uint8 `pathvalue:"uint8_val"` + Uint16Field uint16 `pathvalue:"uint16_val"` + Uint32Field uint32 `pathvalue:"uint32_val"` + Uint64Field uint64 `pathvalue:"uint64_val"` +} + +type ErrorFromRequest struct{} + +func (e ErrorFromRequest) FromRequest(r *http.Request) error { + return fmt.Errorf("custom error from request") +} + +type PtrErrorFromRequest struct{} + +func (e *PtrErrorFromRequest) FromRequest(r *http.Request) error { + return fmt.Errorf("custom error from pointer request") } func TestNotAFunc(t *testing.T) { @@ -60,6 +97,12 @@ func TestMethod(t *testing.T) { MagicHandler(obj.Func) } +func TestJsonResponse(t *testing.T) { + MustMagicHandler(func() magic.Json[map[string]any] { + return magic.Json[map[string]any]{} + }) +} + type testCase struct { Path string Pattern string @@ -278,13 +321,468 @@ var ( // }) // }, // }, + { + Pattern: "/hello-query", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[HelloWorldParams]) magic.Json[HelloWorldResponse] { + name := "world" + if query.Data.Name != nil { + name = *query.Data.Name + } + return magic.Json[HelloWorldResponse]{ + Data: HelloWorldResponse{ + Greeting: "hello, " + name, + }, + } + }) + }, + }, { Pattern: "/test-query", CreateHandler: func() http.Handler { return MustMagicHandler(func(query magic.Query[MultipleParams]) { - log.Println(query.Data) + // Add assertions for MultipleParams + if len(query.Data.Numbers) != 3 || query.Data.Numbers[0] != 1 || query.Data.Numbers[1] != 2 || query.Data.Numbers[2] != 3 { + panic(fmt.Sprintf("expected numbers [1 2 3], got %v", query.Data.Numbers)) + } + if len(query.Data.Strings) != 2 || query.Data.Strings[0] != "a" || query.Data.Strings[1] != "b" { + panic(fmt.Sprintf("expected strings [a b], got %v", query.Data.Strings)) + } + if query.Data.OptionalNumber == nil || *query.Data.OptionalNumber != 10 { + panic(fmt.Sprintf("expected optional number 10, got %v", query.Data.OptionalNumber)) + } + if query.Data.OptionalString == nil || *query.Data.OptionalString != "optional" { + panic(fmt.Sprintf("expected optional string 'optional', got %v", query.Data.OptionalString)) + } }) }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query?numbers=1&numbers=2&numbers=3&strings=a&strings=b&optionalNumber=10&optionalString=optional", nil) + if err != nil { + panic(err) + } + return req + }, + }, + { + Pattern: "/test-query-optional-missing", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[MultipleParams]) { + // Test missing optional fields + if query.Data.OptionalNumber != nil { + panic(fmt.Sprintf("expected optional number to be nil, got %v", query.Data.OptionalNumber)) + } + if query.Data.OptionalString != nil { + panic(fmt.Sprintf("expected optional string to be nil, got %v", query.Data.OptionalString)) + } + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-optional-missing?numbers=1&strings=a", nil) + if err != nil { + panic(err) + } + return req + }, + }, + { + Pattern: "/test-query-types", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[QueryTypes]) { + // Test various data types and slices + if query.Data.IntField != 123 { + panic(fmt.Sprintf("expected int_field 123, got %d", query.Data.IntField)) + } + if query.Data.StringField != "test_string" { + panic(fmt.Sprintf("expected string_field 'test_string', got '%s'", query.Data.StringField)) + } + if !query.Data.BoolField { + panic(fmt.Sprintf("expected bool_field true, got %t", query.Data.BoolField)) + } + if query.Data.FloatField32 != 1.23 { + panic(fmt.Sprintf("expected float32_field 1.23, got %f", query.Data.FloatField32)) + } + if query.Data.FloatField64 != 4.56 { + panic(fmt.Sprintf("expected float64_field 4.56, got %f", query.Data.FloatField64)) + } + if len(query.Data.IntSlice) != 2 || query.Data.IntSlice[0] != 1 || query.Data.IntSlice[1] != 2 { + panic(fmt.Sprintf("expected int_slice [1 2], got %v", query.Data.IntSlice)) + } + if len(query.Data.StringSlice) != 2 || query.Data.StringSlice[0] != "x" || query.Data.StringSlice[1] != "y" { + panic(fmt.Sprintf("expected string_slice [x y], got %v", query.Data.StringSlice)) + } + if len(query.Data.BoolSlice) != 2 || query.Data.BoolSlice[0] != true || query.Data.BoolSlice[1] != false { + panic(fmt.Sprintf("expected bool_slice [true false], got %v", query.Data.BoolSlice)) + } + if query.Data.OptionalInt == nil || *query.Data.OptionalInt != 99 { + panic(fmt.Sprintf("expected optional_int 99, got %v", query.Data.OptionalInt)) + } + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-types?int_field=123&string_field=test_string&bool_field=true&float32_field=1.23&float64_field=4.56&int_slice=1&int_slice=2&string_slice=x&string_slice=y&bool_slice=true&bool_slice=false&optional_int=99", nil) + if err != nil { + panic(err) + } + return req + }, + }, + { + Pattern: "/test-query-types-missing-optional", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[QueryTypes]) { + // Test missing optional int + if query.Data.OptionalInt != nil { + panic(fmt.Sprintf("expected optional_int to be nil, got %v", query.Data.OptionalInt)) + } + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-types-missing-optional?int_field=123&string_field=test_string&bool_field=true&float32_field=1.23&float64_field=4.56&int_slice=1&int_slice=2&string_slice=x&string_slice=y&bool_slice=true&bool_slice=false", nil) + if err != nil { + panic(err) + } + return req + }, + }, + { + Pattern: "/test-query-invalid-int", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[QueryTypes]) { + panic("handler should not be called for invalid input") + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-invalid-int?int_field=abc&string_field=test", nil) + if err != nil { + panic(err) + } + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-query-invalid-bool", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[QueryTypes]) { + panic("handler should not be called for invalid input") + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-invalid-bool?int_field=123&string_field=test&bool_field=not_a_bool", nil) + if err != nil { + panic(err) + } + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-query-invalid-float", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[QueryTypes]) { + panic("handler should not be called for invalid input") + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-invalid-float?int_field=123&string_field=test&bool_field=true&float64_field=not_a_float", nil) + if err != nil { + panic(err) + } + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-form", + CreateHandler: func() http.Handler { + type TestForm struct { + Name string `form:"name"` + Age int `form:"age"` + IsPro bool `form:"is_pro"` + } + return MustMagicHandler(func(form magic.Form[TestForm]) { + if form.Data.Name != "test" { + panic(fmt.Sprintf("expected name 'test', got '%s'", form.Data.Name)) + } + if form.Data.Age != 30 { + panic(fmt.Sprintf("expected age 30, got %d", form.Data.Age)) + } + if !form.Data.IsPro { + panic(fmt.Sprintf("expected is_pro true, got %t", form.Data.IsPro)) + } + }) + }, + MakeRequest: func(base string) *http.Request { + body := strings.NewReader("name=test&age=30&is_pro=true") + req, err := http.NewRequest("POST", base+"/test-form", body) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + }, + { + Pattern: "/test-form-advanced", + CreateHandler: func() http.Handler { + type TestFormAdvanced struct { + Names []string `form:"names"` + Ages [2]int `form:"ages"` + ID string `form:"id,match=^[a-f0-9]{8}$"` + Optional *string `form:"optional"` + Numbers []int `form:"numbers"` + Required string `form:"required"` + FixedSize [3]bool `form:"fixed_size"` + } + return MustMagicHandler(func(form magic.Form[TestFormAdvanced]) { + // Test successful extraction + if len(form.Data.Names) != 2 || form.Data.Names[0] != "alice" || form.Data.Names[1] != "bob" { + panic(fmt.Sprintf("expected names [alice bob], got %v", form.Data.Names)) + } + if form.Data.Ages[0] != 25 || form.Data.Ages[1] != 30 { + panic(fmt.Sprintf("expected ages [25 30], got %v", form.Data.Ages)) + } + if form.Data.ID != "abcdef01" { + panic(fmt.Sprintf("expected id 'abcdef01', got '%s'", form.Data.ID)) + } + if form.Data.Optional == nil || *form.Data.Optional != "present" { + panic(fmt.Sprintf("expected optional 'present', got %v", form.Data.Optional)) + } + if len(form.Data.Numbers) != 3 || form.Data.Numbers[0] != 1 || form.Data.Numbers[1] != 2 || form.Data.Numbers[2] != 3 { + panic(fmt.Sprintf("expected numbers [1 2 3], got %v", form.Data.Numbers)) + } + if form.Data.Required != "required_value" { + panic(fmt.Sprintf("expected required 'required_value', got '%s'", form.Data.Required)) + } + if form.Data.FixedSize[0] != true || form.Data.FixedSize[1] != false || form.Data.FixedSize[2] != true { + panic(fmt.Sprintf("expected fixed_size [true false true], got %v", form.Data.FixedSize)) + } + }) + }, + MakeRequest: func(base string) *http.Request { + body := strings.NewReader("names=alice&names=bob&ages=25&ages=30&id=abcdef01&optional=present&numbers=1&numbers=2&numbers=3&required=required_value&fixed_size=true&fixed_size=false&fixed_size=true") + req, err := http.NewRequest("POST", base+"/test-form-advanced", body) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + }, + { + Pattern: "/test-form-advanced-errors", + CreateHandler: func() http.Handler { + type TestFormAdvancedErrors struct { + Required string `form:"required"` + } + return MustMagicHandler(func(form magic.Form[TestFormAdvancedErrors]) { + // This handler should not be reached in case of errors, + // the magic handler should return an error before calling this. + panic("handler should not be called for error test cases") + }) + }, + MakeRequest: func(base string) *http.Request { + // This request is intentionally malformed to trigger errors + body := strings.NewReader("names=alice&ages=25&ages=30&id=invalid-id&numbers=1&required=required_value&fixed_size=true&fixed_size=false") // Missing one fixed_size value, invalid ID + req, err := http.NewRequest("POST", base+"/test-form-advanced-errors", body) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + CheckResponse: func(resp *http.Response) error { + // Expecting an error response due to invalid form data + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + // Further checks on the error message body could be added here + return nil + }, + }, + { + Pattern: "/test-json-non-struct", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(body magic.Json[string]) string { + return "Received: " + body.Data + }) + }, + MakeRequest: func(base string) *http.Request { + body := strings.NewReader(`"a simple string"`) + req, err := http.NewRequest("POST", base+"/test-json-non-struct", body) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", "application/json") + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode) + } + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + bodyString := string(bodyBytes) + // The response is the string "Received: " + the JSON string, + // which includes the quotes from the original JSON body. + expected := `"Received: \"a simple string\""` + if !strings.Contains(bodyString, expected) { + return fmt.Errorf("expected response body to contain '%s', got '%s'", expected, bodyString) + } + return nil + }, + }, + { + Pattern: "/test-from-request-error", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(e ErrorFromRequest) { + panic("handler should not be called for error test cases") + }) + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusInternalServerError { + return fmt.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-ptr-from-request-error", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(e *PtrErrorFromRequest) { + panic("handler should not be called for error test cases") + }) + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusInternalServerError { + return fmt.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-form-invalid-int", + CreateHandler: func() http.Handler { + type TestForm struct { + Age int `form:"age"` + } + return MustMagicHandler(func(form magic.Form[TestForm]) { + panic("handler should not be called for invalid input") + }) + }, + MakeRequest: func(base string) *http.Request { + body := strings.NewReader("age=abc") + req, err := http.NewRequest("POST", base+"/test-form-invalid-int", body) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-form-array-too-few", + CreateHandler: func() http.Handler { + type TestForm struct { + Ages [2]int `form:"ages"` + } + return MustMagicHandler(func(form magic.Form[TestForm]) { + panic("handler should not be called for invalid input") + }) + }, + MakeRequest: func(base string) *http.Request { + body := strings.NewReader("ages=25") + req, err := http.NewRequest("POST", base+"/test-form-array-too-few", body) + if err != nil { + panic(err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-query-slice-invalid-entry", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[QueryTypes]) { + panic("handler should not be called for invalid input") + }) + }, + MakeRequest: func(base string) *http.Request { + req, err := http.NewRequest("GET", base+"/test-query-slice-invalid-entry?int_slice=1&int_slice=abc&int_slice=3", nil) + if err != nil { + panic(err) + } + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, + }, + { + Pattern: "/test-double-body-read", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(body io.Reader, jsonBody magic.Json[HelloWorldRequest]) { + // This handler should panic because we are trying to read the body twice. + panic("handler should not be called") + }) + }, + ExpectPanicOnCreateHandler: true, + }, + { + Pattern: "/websocket-fail", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(ws *websocket.Conn) { + // This handler should not be called if websocket upgrade fails. + panic("handler should not be called") + }) + }, + MakeRequest: func(base string) *http.Request { + // A regular GET request without the upgrade headers will fail the websocket upgrade. + req, err := http.NewRequest("GET", base+"/websocket-fail", nil) + if err != nil { + panic(err) + } + return req + }, + CheckResponse: func(resp *http.Response) error { + if resp.StatusCode != http.StatusBadRequest { + return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode) + } + return nil + }, }, } ) diff --git a/magic/autodecode.go b/magic/autodecode.go new file mode 100644 index 0000000..049075e --- /dev/null +++ b/magic/autodecode.go @@ -0,0 +1,37 @@ +package magic + +import ( + "net/http" + "strings" +) + +type ErrUknownContentType string + +func (err ErrUknownContentType) Error() string { + return "unknown content-type: " + string(err) +} + +type AutoDecode[T any] struct { + Data T +} + +func (d *AutoDecode[T]) FromRequest(r *http.Request) error { + contentType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";") + switch contentType { + case "application/json": + decoder := &Json[T]{} + if err := decoder.FromRequest(r); err != nil { + return err + } + d.Data = decoder.Data + case "multipart/form-data", "application/x-www-form-urlencoded": + decoder := &Form[T]{} + if err := decoder.FromRequest(r); err != nil { + return err + } + d.Data = decoder.Data + } + return ErrUknownContentType(contentType) +} + +func (d AutoDecode[T]) TakeRequestBody() {} diff --git a/magic/path_query_values.go b/magic/common.go similarity index 82% rename from magic/path_query_values.go rename to magic/common.go index 49e10c7..fbe4b06 100644 --- a/magic/path_query_values.go +++ b/magic/common.go @@ -4,6 +4,8 @@ import ( "reflect" "strconv" "strings" + + "github.com/go-playground/validator/v10" ) var ( @@ -54,21 +56,21 @@ var ( if v, err := strconv.ParseInt(s, 10, 16); err != nil { return nil, err } else { - return int8(v), nil + return int16(v), nil } }, reflect.Int32: func(s string) (any, error) { if v, err := strconv.ParseInt(s, 10, 32); err != nil { return nil, err } else { - return int8(v), nil + return int32(v), nil } }, reflect.Int64: func(s string) (any, error) { if v, err := strconv.ParseInt(s, 10, 64); err != nil { return nil, err } else { - return int8(v), nil + return v, nil } }, @@ -108,4 +110,24 @@ var ( } }, } + + defaultValidator = NewValidator("") ) + +func NewValidator(nameFromTag string) *validator.Validate { + validate := validator.New( + validator.WithRequiredStructEnabled(), + ) + + if nameFromTag != "" { + validate.RegisterTagNameFunc(func(fld reflect.StructField) string { + name := strings.SplitN(fld.Tag.Get(nameFromTag), ",", 2)[0] + if name == "-" { + return "" + } + return name + }) + } + + return validate +} diff --git a/magic/extractors.go b/magic/extractors.go index a07dd18..c2f9e5e 100644 --- a/magic/extractors.go +++ b/magic/extractors.go @@ -11,6 +11,8 @@ import ( var ( extractors = make(map[reflect.Type]func(*http.Request) (any, error)) + extractorsTakeBody = make(map[reflect.Type]bool) + extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error)) ) @@ -18,6 +20,12 @@ type FromRequest interface { FromRequest(*http.Request) error } +// Marker interface +type TakeRequestBody interface { + TakeRequestBody() +} + +// Marker interface type TakeResponseWriter interface { TakeResponseWriter(http.ResponseWriter) } @@ -43,6 +51,42 @@ func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error)) RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor) } +func RegisterExtractorThatTakesBody(o any, extractor func(*http.Request) (any, error)) { + if t, ok := o.(reflect.Type); ok { + slog.With("type", t.String()).Debug("extractor type registered with reflect.Type") + extractors[t] = extractor + extractorsTakeBody[t] = true + } else if v, ok := o.(reflect.Value); ok { + slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value") + + extractors[v.Type()] = extractor + extractorsTakeBody[v.Type()] = true + } else { + t := reflect.TypeOf(o) + slog.With("type", t.String(), "value", o).Debug("extractor type registered with object") + + extractors[t] = extractor + extractorsTakeBody[t] = true + } +} + +func RegisterExtractorThatTakesBodyGeneric[T any](extractor func(*http.Request) (any, error)) { + var pointerToT *T + RegisterExtractorThatTakesBody(reflect.TypeOf(pointerToT).Elem(), extractor) +} + +func IsTakeBody(t reflect.Type) bool { + if isTakeBody := extractorsTakeBody[t]; isTakeBody { + return isTakeBody + } + + if _, ok := util.Implements[TakeRequestBody](t); ok { + return ok + } + + return false +} + func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) { if t, ok := o.(reflect.Type); ok { slog.With("type", t.String()).Debug("extractor type registered with reflect.Type") @@ -70,62 +114,57 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any _, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem()) isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter } - if _, ok := util.Implements[FromRequest](t); ok { - if t.Kind() == reflect.Pointer { - // T.Implement(FromRequest) and T is Pointer - // create new actual T and call T.FromRequest on it - // if error happens, no error is returned. just return nil - return func(w http.ResponseWriter, r *http.Request) (any, error) { - // var t T - // return t, t.FromRequest(r) - z := reflect.New(t.Elem()) - if isTakeResponseWriter { - z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)}) + fromRequestType := reflect.TypeOf((*FromRequest)(nil)).Elem() + if t.Implements(fromRequestType) { + return func(w http.ResponseWriter, r *http.Request) (any, error) { + val := reflect.New(t).Elem() + if t.Kind() == reflect.Pointer { + val = reflect.New(t.Elem()) + } + + if isTakeResponseWriter { + if taker, ok := val.Addr().Interface().(TakeResponseWriter); ok { + taker.TakeResponseWriter(w) } + } - results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)}) - if err, ok := results[0].Interface().(error); ok && err != nil { - if errResponse, ok := err.(ErrorResponse); ok { - return reflect.Zero(t).Interface(), errResponse + if fromR, ok := val.Interface().(FromRequest); ok { + if err := fromR.FromRequest(r); err != nil { + return nil, err + } + } else if val.CanAddr() { + if fromR, ok := val.Addr().Interface().(FromRequest); ok { + if err := fromR.FromRequest(r); err != nil { + return nil, err } - return reflect.Zero(t).Interface(), nil } - return z.Interface(), nil - }, isTakeResponseWriter - } - // T.Implement(FromRequest) and T is not Pointer - // create zero T and call T.FromRequest directly on it - return func(w http.ResponseWriter, r *http.Request) (any, error) { - // var t T - // return t, t.FromRequest(r) - z := reflect.Zero(t) - - if isTakeResponseWriter { - z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)}) } + if t.Kind() == reflect.Pointer { - results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)}) - if err, ok := results[0].Interface().(error); ok { - return z.Interface(), err + return val.Interface(), nil } - return z.Interface(), nil + return val.Addr().Interface(), nil }, isTakeResponseWriter - } else if v, ok := util.PointerImplements[FromRequest](t); ok { + } + + if reflect.PointerTo(t).Implements(fromRequestType) { return func(w http.ResponseWriter, r *http.Request) (any, error) { - // var t T - // return t, t.FromRequest(r) - z := reflect.New(v.Type().Elem()) + val := reflect.New(t) if isTakeResponseWriter { - z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)}) + if taker, ok := val.Interface().(TakeResponseWriter); ok { + taker.TakeResponseWriter(w) + } } - results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)}) - if err := results[0].Interface(); err == nil { - return z.Elem().Interface(), nil + if fromR, ok := val.Interface().(FromRequest); ok { + if err := fromR.FromRequest(r); err != nil { + return nil, err + } } - return z.Elem().Interface(), results[0].Interface().(error) + + return val.Elem().Interface(), nil }, isTakeResponseWriter } @@ -142,7 +181,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any if err != nil { return nil, err } - return &v, nil + // Create a pointer to the value + vp := reflect.New(reflect.TypeOf(v)) + vp.Elem().Set(reflect.ValueOf(v)) + return vp.Interface(), nil }, false } } @@ -160,7 +202,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any if err != nil { return nil, err } - return &v, nil + // Create a pointer to the value + vp := reflect.New(reflect.TypeOf(v)) + vp.Elem().Set(reflect.ValueOf(v)) + return vp.Interface(), nil }, true } } diff --git a/magic/form.go b/magic/form.go new file mode 100644 index 0000000..8f190e4 --- /dev/null +++ b/magic/form.go @@ -0,0 +1,206 @@ +package magic + +import ( + "fmt" + "net/http" + "reflect" + "regexp" + "strconv" + "strings" +) + +var ( + cachedStructFormFields = make(map[reflect.Type][]structFormFields) + cachedUnsupportedFormFields = make(map[reflect.Type]structFormFields) +) + +type Form[T any] struct { + Data T +} + +type structFormFields struct { + PathKey string + FieldKey string + + Type reflect.Type + MinValues, MaxValues int + + Match *regexp.Regexp +} + +type UnsupportedFormType struct { + inner reflect.Type +} + +func (err UnsupportedFormType) Error() string { + return "unsupported type for Form: " + err.inner.String() +} + +type FormValueNotFitIn struct { + Count int + + field structFormFields +} + +func (err FormValueNotFitIn) Error() string { + return fmt.Sprintf("Form %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinValues, err.field.MaxValues, err.Count) +} + +type InvalidFormValue struct { + Match *regexp.Regexp + Value string +} + +func (err InvalidFormValue) Error() string { + return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value) +} + +func findReflectFormFields(v reflect.Value) ([]structFormFields, error) { + if cached, ok := cachedStructFormFields[v.Type()]; ok { + return cached, nil + } + if cached, ok := cachedUnsupportedFormFields[v.Type()]; ok { + return nil, UnsupportedPathValueType{inner: cached.Type} + } + + var fields []structFormFields + if v.Kind() == reflect.Struct { + t := v.Type() + for idx := 0; idx < t.NumField(); idx++ { + field := t.Field(idx) + + pathKey := field.Name + var match *regexp.Regexp + + if tags, ok := field.Tag.Lookup("form"); ok { + parts := strings.Split(strings.TrimSpace(tags), ",") + if len(parts) == 0 { + // do nothing + } else { + if parts[0] != "" { + pathKey = parts[0] + } + + for _, part := range parts[1:] { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "match=") { + part := strings.TrimPrefix(part, "match=") + if unquoted, err := strconv.Unquote(part); err != nil { + return nil, err + } else { + part = unquoted + } + exp, err := regexp.Compile(part) + if err != nil { + return nil, err + } + match = exp + } + } + } + } + + fieldType := field.Type + maxValues, minValues := 1, 1 + + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + minValues = 0 + } else if fieldType.Kind() == reflect.Array { + maxValues, minValues = fieldType.Len(), fieldType.Len() + fieldType = fieldType.Elem() + } else if fieldType.Kind() == reflect.Slice { + fieldType = fieldType.Elem() + maxValues, minValues = -1, 0 + } + + if pathValueConvertTable[fieldType.Kind()] == nil { + if _, ok := fieldType.MethodByName("FromString"); !ok { + cachedUnsupportedPathValueFields[t] = field + return nil, UnsupportedPathValueType{inner: field.Type} + } + } + fields = append(fields, structFormFields{ + PathKey: pathKey, + FieldKey: field.Name, + + Type: fieldType, + MaxValues: maxValues, + MinValues: minValues, + + Match: match, + }) + + cachedStructFormFields[t] = fields + } + } else { + return nil, UnsupportedFormType{inner: v.Type()} + } + + return fields, nil +} + +func (form *Form[T]) FromRequest(r *http.Request) error { + v := reflect.ValueOf(form).Elem().FieldByName("Data") + fields, err := findReflectFormFields(v) + if err != nil { + return err + } + + if err := r.ParseMultipartForm(32 << 20); err != nil { + return err + } + q := r.PostForm + + for _, field := range fields { + values := q[field.PathKey] + if len(values) < field.MinValues || (field.MaxValues > 0 && len(values) > field.MaxValues) { + return FormValueNotFitIn{Count: len(values), field: field} + } + + structField := v.FieldByName(field.FieldKey) + + if structField.Kind() == reflect.Slice && structField.Len() < len(values) { + if structField.Cap() < len(values) { + structField.SetCap(len(values)) + } + structField.SetLen(len(values)) + } + + for idx, value := range values { + if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice { + if idx != len(values)-1 { + // only the last takes effect + continue + } + } + // check match regexp + if field.Match != nil && !field.Match.MatchString(value) { + return InvalidFormValue{Match: field.Match, Value: value} + } + + // convert to elem value + convert := pathValueConvertTable[field.Type.Kind()] + result, err := convert(value) + if err != nil { + return err + } + reflectValue := reflect.ValueOf(result) + + if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice { + if field.MinValues == 0 { + newReflectValue := reflect.New(field.Type) + newReflectValue.Elem().Set(reflectValue) + reflectValue = newReflectValue + } + structField.Set(reflectValue) + } else { + structField.Index(idx).Set(reflectValue) + } + } + } + + return nil +} + +func (form Form[T]) TakeRequestBody() {} diff --git a/magic/json.go b/magic/json.go index e3f3ccf..42ef91c 100644 --- a/magic/json.go +++ b/magic/json.go @@ -2,8 +2,15 @@ package magic import ( "encoding/json" - "io" "net/http" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" +) + +var ( + jsonValidator = NewValidator("json") ) type JsonDecodeError struct { @@ -18,10 +25,25 @@ func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) { rw.WriteHeader(http.StatusBadRequest) } -type Json[T any] struct { - Data T +type ValidateError []string + +func (err ValidateError) Error() string { + return "invalid or missing fields: [" + strings.Join(err, ", ") + "]" } +func (v ValidateError) WriteResponse(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(v.Error())) +} + +type Json[T any] struct { + Data T `json:"data"` +} + +// TakeRequestBody is a marker method to indicate that this extractor takes the request body. +func (j Json[T]) TakeRequestBody() {} + func NewJson[T any](data T) Json[T] { return Json[T]{ Data: data, @@ -29,35 +51,41 @@ func NewJson[T any](data T) Json[T] { } func (data *Json[T]) FromRequest(request *http.Request) error { + if request.Body == nil { + panic("body is already taken") + } + bodyReader := request.Body defer bodyReader.Close() - request.Body = http.NoBody + request.Body = nil if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil { return JsonDecodeError{inner: err} } + value := reflect.ValueOf(data.Data) + if value.Kind() == reflect.Struct || (value.Kind() == reflect.Ptr && value.Elem().Kind() == reflect.Struct) { + if err := jsonValidator.StructCtx(request.Context(), data.Data); err == nil { + return nil + } else if validateErr, isValidateErr := err.(validator.ValidationErrors); isValidateErr { + fields := make([]string, len(validateErr)) + for idx, err := range validateErr { + fields[idx] = err.Field() + } + return ValidateError(fields) + } else { + return err + } + } + return nil } -type counter struct { - counter int64 -} - -func (c *counter) Write(b []byte) (int, error) { - n := len(b) - c.counter += int64(n) - - return n, nil -} - -func (container Json[T]) Write(w io.Writer) (int64, error) { - counter := &counter{} - err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data) - return counter.counter, err -} - -func (json Json[T]) WriteResponse(w http.ResponseWriter) { - json.Write(w) +func (data *Json[T]) WriteResponse(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + if data == nil { + return + } + json.NewEncoder(w).Encode(data) } diff --git a/magic/map.go b/magic/map.go index 8393682..756a435 100644 --- a/magic/map.go +++ b/magic/map.go @@ -8,5 +8,6 @@ import ( type Map map[string]any func (m Map) RespondWriter(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(m) } diff --git a/magic/query.go b/magic/query.go index 7eb5669..a6676e7 100644 --- a/magic/query.go +++ b/magic/query.go @@ -159,7 +159,7 @@ func (query *Query[T]) FromRequest(r *http.Request) error { if structField.Kind() == reflect.Slice && structField.Len() < len(values) { if structField.Cap() < len(values) { - structField.SetCap(len(values)) + structField.Grow(len(values) - structField.Cap()) } structField.SetLen(len(values)) } diff --git a/magic/request.go b/magic/request.go index c7ff700..8f7761f 100644 --- a/magic/request.go +++ b/magic/request.go @@ -2,7 +2,10 @@ package magic import ( "context" + "io" "net/http" + + "git.jeffthecoder.xyz/public/lazyhandler/middleware/cleanup" ) func init() { @@ -12,4 +15,49 @@ func init() { RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) { return r.Context(), nil }) + // RegisterExtractorThatTakesBodyGeneric for io.ReadCloser extracts the request body. + // + // IMPORTANT: This extractor consumes the request body (r.Body). After this extractor + // is used, r.Body will be set to nil to prevent it from being read multiple times. + // Any subsequent attempt to read the body will fail. + // + // A cleanup function is automatically registered to close the body once the request + // is finished. This requires the cleanup middleware to be present in the handler chain. + RegisterExtractorThatTakesBodyGeneric[io.ReadCloser](func(r *http.Request) (any, error) { + if r.Body == nil { + panic("body is already taken") + } + + body := r.Body + r.Body = nil + + cleanup.Register(r.Context(), cleanup.CleanupFunc(func() { + body.Close() + })) + + return body, nil + }) + + // RegisterExtractorThatTakesBodyGeneric for io.Reader extracts the request body. + // + // IMPORTANT: This extractor consumes the request body (r.Body). After this extractor + // is used, r.Body will be set to nil to prevent it from being read multiple times. + // Any subsequent attempt to read the body will fail. + // + // A cleanup function is automatically registered to close the body once the request + // is finished. This requires the cleanup middleware to be present in the handler chain. + RegisterExtractorThatTakesBodyGeneric[io.Reader](func(r *http.Request) (any, error) { + if r.Body == nil { + panic("body is already taken") + } + + body := r.Body + r.Body = nil + + cleanup.Register(r.Context(), cleanup.CleanupFunc(func() { + body.Close() + })) + + return body, nil + }) } diff --git a/magic/response.go b/magic/response.go index 5cff84f..523639c 100644 --- a/magic/response.go +++ b/magic/response.go @@ -22,6 +22,26 @@ type ErrorResponse interface { WriteResponse(http.ResponseWriter) } +func WrapSimpleErrorWithStatus(message string, status int) ErrorResponse { + return simpleErrorResponseWithoutBody{ + message: message, + status: status, + } +} + +type simpleErrorResponseWithoutBody struct { + message string + status int +} + +func (err simpleErrorResponseWithoutBody) Error() string { + return err.message +} + +func (err simpleErrorResponseWithoutBody) WriteResponse(w http.ResponseWriter) { + w.WriteHeader(err.status) +} + func init() { RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) { return w, nil diff --git a/magic/state.go b/magic/state.go index 443217a..73d55e5 100644 --- a/magic/state.go +++ b/magic/state.go @@ -2,12 +2,8 @@ package magic import "net/http" -type State[T any] struct { - Data T -} - func RegisterState[T any](data T) { - RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) { - return State[T]{Data: data}, nil + RegisterExtractor(data, func(r *http.Request) (any, error) { + return data, nil }) } diff --git a/middleware/cleanup/cleanup.go b/middleware/cleanup/cleanup.go new file mode 100644 index 0000000..188bdf8 --- /dev/null +++ b/middleware/cleanup/cleanup.go @@ -0,0 +1,61 @@ +package cleanup + +import ( + "context" + "log/slog" + "net/http" + "reflect" + + "git.jeffthecoder.xyz/public/lazyhandler/middleware" +) + +type ctxKey int + +const ( + cleanupCtxKey ctxKey = iota +) + +type CleanupContext struct { + funcs []Cleanup +} + +type Cleanup interface { + Name() string + Cleanup() +} + +type CleanupFunc func() + +func (fn CleanupFunc) Name() string { + return reflect.TypeOf(fn).String() +} + +func (fn CleanupFunc) Cleanup() { + defer func() { + if v := recover(); v != nil { + slog.With("v", v).Warn("cleanup panicked") + } + }() + + fn() +} + +func Collect() middleware.Middleware { + return middleware.WrapFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := &CleanupContext{} + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cleanupCtxKey, ctx))) + for _, cleanupFunc := range ctx.funcs { + defer cleanupFunc.Cleanup() + } + }) + }) +} + +// Register adds a Cleanup function to the CleanupContext in the provided context. +// If the CleanupContext is not found in the context, the Cleanup function is not registered. +func Register(ctx context.Context, c Cleanup) { + if ctx, ok := ctx.Value(cleanupCtxKey).(*CleanupContext); ok { + ctx.funcs = append(ctx.funcs, c) + } +} diff --git a/middleware/cleanup/cleanup_test.go b/middleware/cleanup/cleanup_test.go new file mode 100644 index 0000000..c1f938d --- /dev/null +++ b/middleware/cleanup/cleanup_test.go @@ -0,0 +1,54 @@ +package cleanup + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestCleanupMiddleware(t *testing.T) { + var executionOrder []string + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Register(r.Context(), CleanupFunc(func() { + executionOrder = append(executionOrder, "first") + })) + Register(r.Context(), CleanupFunc(func() { + executionOrder = append(executionOrder, "second") + })) + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/", nil) + rr := httptest.NewRecorder() + + middleware := Collect().WrapHandler(handler) + middleware.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("handler returned wrong status code: got %v want %v", + rr.Code, http.StatusOK) + } + + expectedOrder := "second,first" + actualOrder := strings.Join(executionOrder, ",") + if actualOrder != expectedOrder { + t.Errorf("cleanup functions executed in wrong order: got %s want %s", + actualOrder, expectedOrder) + } +} + +func TestCleanupMissingContext(t *testing.T) { + // This test ensures that Register does not panic when the context is missing. + // The function should fail silently. + defer func() { + if r := recover(); r != nil { + t.Errorf("The code panicked when it should not have") + } + }() + + req := httptest.NewRequest("GET", "/", nil) + // No middleware, so no context + Register(req.Context(), CleanupFunc(func() {})) +} diff --git a/middleware/httplog/log.go b/middleware/httplog/log.go index fa7aa74..a57b9cc 100644 --- a/middleware/httplog/log.go +++ b/middleware/httplog/log.go @@ -75,6 +75,10 @@ func (log Log) WrapHandler(next http.Handler) http.Handler { }) } +// Logger retrieves the slog.Logger from the request context. +// If the logger is not found in the context (e.g., the httplog middleware is not used), +// it creates and returns a new logger with basic request information. +// Using the httplog middleware is recommended to ensure the configured logger is available. func Logger(r *http.Request) *slog.Logger { if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok { return logger.With("time", time.Now()) diff --git a/middleware/httplog/log_test.go b/middleware/httplog/log_test.go new file mode 100644 index 0000000..18c7fb7 --- /dev/null +++ b/middleware/httplog/log_test.go @@ -0,0 +1,85 @@ +package httplog + +import ( + "bufio" + "bytes" + "log/slog" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +type mockHijacker struct { + *httptest.ResponseRecorder + hijacked bool +} + +func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + m.hijacked = true + // Return dummy values, not used in this test + return nil, nil, nil +} + +func TestResponseRecorder_Hijack(t *testing.T) { + recorder := &responseRecorder{ + ResponseWriter: &mockHijacker{ResponseRecorder: httptest.NewRecorder()}, + } + + _, _, err := recorder.Hijack() + if err != nil { + t.Fatalf("Hijack failed: %v", err) + } + + if recorder.ResponseWriter != nil { + t.Error("ResponseWriter should be nil after Hijack") + } +} + +func TestLogger(t *testing.T) { + var buf bytes.Buffer + slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil))) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger := Logger(r) + logger.Info("test message") + w.WriteHeader(http.StatusOK) + }) + + t.Run("with middleware", func(t *testing.T) { + buf.Reset() + logMiddleware := Log{LogStart: true, LogFinish: true} + wrappedHandler := logMiddleware.WrapHandler(handler) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) { + t.Error("expected start log message, but not found") + } + if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) { + t.Error("expected handler log message, but not found") + } + if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) { + t.Error("expected finish log message, but not found") + } + }) + + t.Run("without middleware", func(t *testing.T) { + buf.Reset() + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) { + t.Error("expected handler log message, but not found") + } + if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) { + t.Error("unexpected start log message found") + } + if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) { + t.Error("unexpected finish log message found") + } + }) +} diff --git a/middleware/recover/recover.go b/middleware/recover/recover.go index 9103654..40f6d6f 100644 --- a/middleware/recover/recover.go +++ b/middleware/recover/recover.go @@ -42,3 +42,20 @@ func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) { func Debug() middleware.Middleware { return Recover(DebugPanicHandler) } + +// ProductionPanicHandler is a PanicResponseFunc that provides a generic error message +// in production environments and logs the detailed panic error. +func ProductionPanicHandler(w http.ResponseWriter, r *http.Request, err any) { + httplog.Logger(r).With("panic", err).Error("request panicked") + + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]any{ + "error": "Internal Server Error", + }) +} + +// Production returns a middleware that recovers from panics and handles them +// with ProductionPanicHandler. +func Production() middleware.Middleware { + return Recover(ProductionPanicHandler) +} diff --git a/middleware/recover/recover_test.go b/middleware/recover/recover_test.go new file mode 100644 index 0000000..6141fb5 --- /dev/null +++ b/middleware/recover/recover_test.go @@ -0,0 +1,56 @@ +package recover + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func panicHandler(w http.ResponseWriter, r *http.Request) { + panic("test panic") +} + +func TestDebugRecover(t *testing.T) { + handler := http.HandlerFunc(panicHandler) + wrappedHandler := Debug().WrapHandler(handler) + + req := httptest.NewRequest("GET", "/panic", nil) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } + + if resp["error"] != "test panic" { + t.Errorf("expected error 'test panic', got '%v'", resp["error"]) + } +} + +func TestProductionRecover(t *testing.T) { + handler := http.HandlerFunc(panicHandler) + wrappedHandler := Production().WrapHandler(handler) + + req := httptest.NewRequest("GET", "/panic", nil) + rr := httptest.NewRecorder() + wrappedHandler.ServeHTTP(rr, req) + + if rr.Code != http.StatusInternalServerError { + t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response body: %v", err) + } + + if resp["error"] != "Internal Server Error" { + t.Errorf("expected error 'Internal Server Error', got '%v'", resp["error"]) + } +}