sync from project

This commit is contained in:
2025-06-18 10:12:19 +08:00
parent 61ffeeb3b8
commit fb579e8689
20 changed files with 1332 additions and 103 deletions

13
go.mod
View File

@ -3,8 +3,19 @@ module git.jeffthecoder.xyz/public/lazyhandler
go 1.24.0
require (
github.com/go-playground/validator/v10 v10.26.0
github.com/go-session/session/v3 v3.2.1
github.com/gorilla/websocket v1.5.3
)
require github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 // indirect
require (
github.com/bytedance/gopkg v0.1.2 // indirect
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
)

38
go.sum
View File

@ -1,6 +1,17 @@
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/bytedance/gopkg v0.1.2 h1:8o2feYuxknDpN+O7kPwvSXfMEKfYvJYiA2K7aonoMEQ=
github.com/bytedance/gopkg v0.1.2/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY=
github.com/go-session/session/v3 v3.2.1/go.mod h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M=
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
@ -9,14 +20,23 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -12,10 +12,13 @@ import (
"git.jeffthecoder.xyz/public/lazyhandler/util"
)
type ErrArgumentIsNotExtractable int
type ErrArgumentIsNotExtractable struct {
Index int
Type reflect.Type
}
func (err ErrArgumentIsNotExtractable) Error() string {
return fmt.Sprintf("argument %v is not extractable", int(err))
return fmt.Sprintf("argument %v is not extractable: %s", err.Index, err.Type.String())
}
type ErrReturnValueNotConvertableIntoResponsePart int
@ -27,6 +30,7 @@ func (err ErrReturnValueNotConvertableIntoResponsePart) Error() string {
var (
ErrNotAFunc = errors.New("not a function")
ErrDuplicateResponseWriterExtractor = errors.New("duplicate response writer extractor")
ErrDuplicateRequestBodyExtractor = errors.New("duplicate request body extractor")
ErrResponseWriterCannotBeExtracted = errors.New("http.ResponseWriter extractor must not exists if function has return value")
)
@ -50,6 +54,7 @@ func MagicHandler(fn any) (http.Handler, error) {
}
responseWriterExtracted := -1
requestBodyExtracted := -1
extractors := make([]func(http.ResponseWriter, *http.Request) (any, error), t.NumIn())
for idx := 0; idx < t.NumIn(); idx++ {
@ -57,7 +62,10 @@ func MagicHandler(fn any) (http.Handler, error) {
extractor, isTakeResponseWriter := magic.GetExtractor(in)
if extractor == nil {
return nil, ErrArgumentIsNotExtractable(idx)
return nil, ErrArgumentIsNotExtractable{
Index: idx,
Type: in,
}
}
if isTakeResponseWriter {
if responseWriterExtracted >= 0 {
@ -65,6 +73,14 @@ func MagicHandler(fn any) (http.Handler, error) {
}
responseWriterExtracted = idx
}
if magic.IsTakeBody(in) {
if requestBodyExtracted >= 0 {
return nil, ErrDuplicateRequestBodyExtractor
}
requestBodyExtracted = idx
}
extractors[idx] = extractor
}
@ -98,6 +114,14 @@ func MagicHandler(fn any) (http.Handler, error) {
if _, ok := util.Implements[magic.RespondWriter](out); ok {
continue
}
if _, ok := util.Implements[magic.RespondWriter](reflect.PointerTo(out)); ok {
continue
}
if out.Kind() == reflect.Pointer {
if _, ok := util.Implements[magic.RespondWriter](out.Elem()); ok {
continue
}
}
// last is error
if _, ok := util.Implements[error](out); ok && idx == t.NumOut()-1 {
@ -108,12 +132,6 @@ func MagicHandler(fn any) (http.Handler, error) {
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var closers []io.Closer
defer func() {
for _, closer := range closers {
closer.Close()
}
}()
in := make([]reflect.Value, len(extractors))
for idx, extractor := range extractors {
v, err := extractor(w, r)
@ -128,9 +146,6 @@ func MagicHandler(fn any) (http.Handler, error) {
}
return
}
if closer, ok := v.(io.Closer); ok {
defer closer.Close()
}
in[idx] = reflect.ValueOf(v)
}
values := funcValue.Call(in)
@ -169,6 +184,11 @@ func MagicHandler(fn any) (http.Handler, error) {
continue
}
responseWriter.WriteResponse(w)
} else if responseWriter, ok := util.Implements[magic.RespondWriter](&value); ok {
if responseWriter == nil {
continue
}
responseWriter.WriteResponse(w)
} else if headers, ok := obj.(http.Header); ok { // not
for name, values := range headers {
for idx, value := range values {

View File

@ -30,10 +30,47 @@ type HelloWorldParams struct {
}
type MultipleParams struct {
Numbers []int
Strings []string
OptionalNumber *int
OptionalString *string
Numbers []int `query:"numbers"`
Strings []string `query:"strings"`
OptionalNumber *int `query:"optionalNumber"`
OptionalString *string `query:"optionalString"`
}
type QueryTypes struct {
IntField int `query:"int_field"`
StringField string `query:"string_field"`
BoolField bool `query:"bool_field"`
FloatField32 float32 `query:"float32_field"`
FloatField64 float64 `query:"float64_field"`
IntSlice []int `query:"int_slice"`
StringSlice []string `query:"string_slice"`
BoolSlice []bool `query:"bool_slice"`
OptionalInt *int `query:"optional_int"`
}
type IntegerPathValues struct {
IntField int `pathvalue:"int_val"`
Int8Field int8 `pathvalue:"int8_val"`
Int16Field int16 `pathvalue:"int16_val"`
Int32Field int32 `pathvalue:"int32_val"`
Int64Field int64 `pathvalue:"int64_val"`
UintField uint `pathvalue:"uint_val"`
Uint8Field uint8 `pathvalue:"uint8_val"`
Uint16Field uint16 `pathvalue:"uint16_val"`
Uint32Field uint32 `pathvalue:"uint32_val"`
Uint64Field uint64 `pathvalue:"uint64_val"`
}
type ErrorFromRequest struct{}
func (e ErrorFromRequest) FromRequest(r *http.Request) error {
return fmt.Errorf("custom error from request")
}
type PtrErrorFromRequest struct{}
func (e *PtrErrorFromRequest) FromRequest(r *http.Request) error {
return fmt.Errorf("custom error from pointer request")
}
func TestNotAFunc(t *testing.T) {
@ -60,6 +97,12 @@ func TestMethod(t *testing.T) {
MagicHandler(obj.Func)
}
func TestJsonResponse(t *testing.T) {
MustMagicHandler(func() magic.Json[map[string]any] {
return magic.Json[map[string]any]{}
})
}
type testCase struct {
Path string
Pattern string
@ -278,13 +321,468 @@ var (
// })
// },
// },
{
Pattern: "/hello-query",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[HelloWorldParams]) magic.Json[HelloWorldResponse] {
name := "world"
if query.Data.Name != nil {
name = *query.Data.Name
}
return magic.Json[HelloWorldResponse]{
Data: HelloWorldResponse{
Greeting: "hello, " + name,
},
}
})
},
},
{
Pattern: "/test-query",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[MultipleParams]) {
log.Println(query.Data)
// Add assertions for MultipleParams
if len(query.Data.Numbers) != 3 || query.Data.Numbers[0] != 1 || query.Data.Numbers[1] != 2 || query.Data.Numbers[2] != 3 {
panic(fmt.Sprintf("expected numbers [1 2 3], got %v", query.Data.Numbers))
}
if len(query.Data.Strings) != 2 || query.Data.Strings[0] != "a" || query.Data.Strings[1] != "b" {
panic(fmt.Sprintf("expected strings [a b], got %v", query.Data.Strings))
}
if query.Data.OptionalNumber == nil || *query.Data.OptionalNumber != 10 {
panic(fmt.Sprintf("expected optional number 10, got %v", query.Data.OptionalNumber))
}
if query.Data.OptionalString == nil || *query.Data.OptionalString != "optional" {
panic(fmt.Sprintf("expected optional string 'optional', got %v", query.Data.OptionalString))
}
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query?numbers=1&numbers=2&numbers=3&strings=a&strings=b&optionalNumber=10&optionalString=optional", nil)
if err != nil {
panic(err)
}
return req
},
},
{
Pattern: "/test-query-optional-missing",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[MultipleParams]) {
// Test missing optional fields
if query.Data.OptionalNumber != nil {
panic(fmt.Sprintf("expected optional number to be nil, got %v", query.Data.OptionalNumber))
}
if query.Data.OptionalString != nil {
panic(fmt.Sprintf("expected optional string to be nil, got %v", query.Data.OptionalString))
}
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-optional-missing?numbers=1&strings=a", nil)
if err != nil {
panic(err)
}
return req
},
},
{
Pattern: "/test-query-types",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
// Test various data types and slices
if query.Data.IntField != 123 {
panic(fmt.Sprintf("expected int_field 123, got %d", query.Data.IntField))
}
if query.Data.StringField != "test_string" {
panic(fmt.Sprintf("expected string_field 'test_string', got '%s'", query.Data.StringField))
}
if !query.Data.BoolField {
panic(fmt.Sprintf("expected bool_field true, got %t", query.Data.BoolField))
}
if query.Data.FloatField32 != 1.23 {
panic(fmt.Sprintf("expected float32_field 1.23, got %f", query.Data.FloatField32))
}
if query.Data.FloatField64 != 4.56 {
panic(fmt.Sprintf("expected float64_field 4.56, got %f", query.Data.FloatField64))
}
if len(query.Data.IntSlice) != 2 || query.Data.IntSlice[0] != 1 || query.Data.IntSlice[1] != 2 {
panic(fmt.Sprintf("expected int_slice [1 2], got %v", query.Data.IntSlice))
}
if len(query.Data.StringSlice) != 2 || query.Data.StringSlice[0] != "x" || query.Data.StringSlice[1] != "y" {
panic(fmt.Sprintf("expected string_slice [x y], got %v", query.Data.StringSlice))
}
if len(query.Data.BoolSlice) != 2 || query.Data.BoolSlice[0] != true || query.Data.BoolSlice[1] != false {
panic(fmt.Sprintf("expected bool_slice [true false], got %v", query.Data.BoolSlice))
}
if query.Data.OptionalInt == nil || *query.Data.OptionalInt != 99 {
panic(fmt.Sprintf("expected optional_int 99, got %v", query.Data.OptionalInt))
}
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-types?int_field=123&string_field=test_string&bool_field=true&float32_field=1.23&float64_field=4.56&int_slice=1&int_slice=2&string_slice=x&string_slice=y&bool_slice=true&bool_slice=false&optional_int=99", nil)
if err != nil {
panic(err)
}
return req
},
},
{
Pattern: "/test-query-types-missing-optional",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
// Test missing optional int
if query.Data.OptionalInt != nil {
panic(fmt.Sprintf("expected optional_int to be nil, got %v", query.Data.OptionalInt))
}
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-types-missing-optional?int_field=123&string_field=test_string&bool_field=true&float32_field=1.23&float64_field=4.56&int_slice=1&int_slice=2&string_slice=x&string_slice=y&bool_slice=true&bool_slice=false", nil)
if err != nil {
panic(err)
}
return req
},
},
{
Pattern: "/test-query-invalid-int",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
panic("handler should not be called for invalid input")
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-invalid-int?int_field=abc&string_field=test", nil)
if err != nil {
panic(err)
}
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-query-invalid-bool",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
panic("handler should not be called for invalid input")
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-invalid-bool?int_field=123&string_field=test&bool_field=not_a_bool", nil)
if err != nil {
panic(err)
}
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-query-invalid-float",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
panic("handler should not be called for invalid input")
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-invalid-float?int_field=123&string_field=test&bool_field=true&float64_field=not_a_float", nil)
if err != nil {
panic(err)
}
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-form",
CreateHandler: func() http.Handler {
type TestForm struct {
Name string `form:"name"`
Age int `form:"age"`
IsPro bool `form:"is_pro"`
}
return MustMagicHandler(func(form magic.Form[TestForm]) {
if form.Data.Name != "test" {
panic(fmt.Sprintf("expected name 'test', got '%s'", form.Data.Name))
}
if form.Data.Age != 30 {
panic(fmt.Sprintf("expected age 30, got %d", form.Data.Age))
}
if !form.Data.IsPro {
panic(fmt.Sprintf("expected is_pro true, got %t", form.Data.IsPro))
}
})
},
MakeRequest: func(base string) *http.Request {
body := strings.NewReader("name=test&age=30&is_pro=true")
req, err := http.NewRequest("POST", base+"/test-form", body)
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
},
{
Pattern: "/test-form-advanced",
CreateHandler: func() http.Handler {
type TestFormAdvanced struct {
Names []string `form:"names"`
Ages [2]int `form:"ages"`
ID string `form:"id,match=^[a-f0-9]{8}$"`
Optional *string `form:"optional"`
Numbers []int `form:"numbers"`
Required string `form:"required"`
FixedSize [3]bool `form:"fixed_size"`
}
return MustMagicHandler(func(form magic.Form[TestFormAdvanced]) {
// Test successful extraction
if len(form.Data.Names) != 2 || form.Data.Names[0] != "alice" || form.Data.Names[1] != "bob" {
panic(fmt.Sprintf("expected names [alice bob], got %v", form.Data.Names))
}
if form.Data.Ages[0] != 25 || form.Data.Ages[1] != 30 {
panic(fmt.Sprintf("expected ages [25 30], got %v", form.Data.Ages))
}
if form.Data.ID != "abcdef01" {
panic(fmt.Sprintf("expected id 'abcdef01', got '%s'", form.Data.ID))
}
if form.Data.Optional == nil || *form.Data.Optional != "present" {
panic(fmt.Sprintf("expected optional 'present', got %v", form.Data.Optional))
}
if len(form.Data.Numbers) != 3 || form.Data.Numbers[0] != 1 || form.Data.Numbers[1] != 2 || form.Data.Numbers[2] != 3 {
panic(fmt.Sprintf("expected numbers [1 2 3], got %v", form.Data.Numbers))
}
if form.Data.Required != "required_value" {
panic(fmt.Sprintf("expected required 'required_value', got '%s'", form.Data.Required))
}
if form.Data.FixedSize[0] != true || form.Data.FixedSize[1] != false || form.Data.FixedSize[2] != true {
panic(fmt.Sprintf("expected fixed_size [true false true], got %v", form.Data.FixedSize))
}
})
},
MakeRequest: func(base string) *http.Request {
body := strings.NewReader("names=alice&names=bob&ages=25&ages=30&id=abcdef01&optional=present&numbers=1&numbers=2&numbers=3&required=required_value&fixed_size=true&fixed_size=false&fixed_size=true")
req, err := http.NewRequest("POST", base+"/test-form-advanced", body)
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
},
{
Pattern: "/test-form-advanced-errors",
CreateHandler: func() http.Handler {
type TestFormAdvancedErrors struct {
Required string `form:"required"`
}
return MustMagicHandler(func(form magic.Form[TestFormAdvancedErrors]) {
// This handler should not be reached in case of errors,
// the magic handler should return an error before calling this.
panic("handler should not be called for error test cases")
})
},
MakeRequest: func(base string) *http.Request {
// This request is intentionally malformed to trigger errors
body := strings.NewReader("names=alice&ages=25&ages=30&id=invalid-id&numbers=1&required=required_value&fixed_size=true&fixed_size=false") // Missing one fixed_size value, invalid ID
req, err := http.NewRequest("POST", base+"/test-form-advanced-errors", body)
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
CheckResponse: func(resp *http.Response) error {
// Expecting an error response due to invalid form data
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
// Further checks on the error message body could be added here
return nil
},
},
{
Pattern: "/test-json-non-struct",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(body magic.Json[string]) string {
return "Received: " + body.Data
})
},
MakeRequest: func(base string) *http.Request {
body := strings.NewReader(`"a simple string"`)
req, err := http.NewRequest("POST", base+"/test-json-non-struct", body)
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/json")
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode)
}
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
bodyString := string(bodyBytes)
// The response is the string "Received: " + the JSON string,
// which includes the quotes from the original JSON body.
expected := `"Received: \"a simple string\""`
if !strings.Contains(bodyString, expected) {
return fmt.Errorf("expected response body to contain '%s', got '%s'", expected, bodyString)
}
return nil
},
},
{
Pattern: "/test-from-request-error",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(e ErrorFromRequest) {
panic("handler should not be called for error test cases")
})
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusInternalServerError {
return fmt.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-ptr-from-request-error",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(e *PtrErrorFromRequest) {
panic("handler should not be called for error test cases")
})
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusInternalServerError {
return fmt.Errorf("expected status %d, got %d", http.StatusInternalServerError, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-form-invalid-int",
CreateHandler: func() http.Handler {
type TestForm struct {
Age int `form:"age"`
}
return MustMagicHandler(func(form magic.Form[TestForm]) {
panic("handler should not be called for invalid input")
})
},
MakeRequest: func(base string) *http.Request {
body := strings.NewReader("age=abc")
req, err := http.NewRequest("POST", base+"/test-form-invalid-int", body)
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-form-array-too-few",
CreateHandler: func() http.Handler {
type TestForm struct {
Ages [2]int `form:"ages"`
}
return MustMagicHandler(func(form magic.Form[TestForm]) {
panic("handler should not be called for invalid input")
})
},
MakeRequest: func(base string) *http.Request {
body := strings.NewReader("ages=25")
req, err := http.NewRequest("POST", base+"/test-form-array-too-few", body)
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-query-slice-invalid-entry",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(query magic.Query[QueryTypes]) {
panic("handler should not be called for invalid input")
})
},
MakeRequest: func(base string) *http.Request {
req, err := http.NewRequest("GET", base+"/test-query-slice-invalid-entry?int_slice=1&int_slice=abc&int_slice=3", nil)
if err != nil {
panic(err)
}
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
{
Pattern: "/test-double-body-read",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(body io.Reader, jsonBody magic.Json[HelloWorldRequest]) {
// This handler should panic because we are trying to read the body twice.
panic("handler should not be called")
})
},
ExpectPanicOnCreateHandler: true,
},
{
Pattern: "/websocket-fail",
CreateHandler: func() http.Handler {
return MustMagicHandler(func(ws *websocket.Conn) {
// This handler should not be called if websocket upgrade fails.
panic("handler should not be called")
})
},
MakeRequest: func(base string) *http.Request {
// A regular GET request without the upgrade headers will fail the websocket upgrade.
req, err := http.NewRequest("GET", base+"/websocket-fail", nil)
if err != nil {
panic(err)
}
return req
},
CheckResponse: func(resp *http.Response) error {
if resp.StatusCode != http.StatusBadRequest {
return fmt.Errorf("expected status %d, got %d", http.StatusBadRequest, resp.StatusCode)
}
return nil
},
},
}
)

37
magic/autodecode.go Normal file
View File

@ -0,0 +1,37 @@
package magic
import (
"net/http"
"strings"
)
type ErrUknownContentType string
func (err ErrUknownContentType) Error() string {
return "unknown content-type: " + string(err)
}
type AutoDecode[T any] struct {
Data T
}
func (d *AutoDecode[T]) FromRequest(r *http.Request) error {
contentType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
switch contentType {
case "application/json":
decoder := &Json[T]{}
if err := decoder.FromRequest(r); err != nil {
return err
}
d.Data = decoder.Data
case "multipart/form-data", "application/x-www-form-urlencoded":
decoder := &Form[T]{}
if err := decoder.FromRequest(r); err != nil {
return err
}
d.Data = decoder.Data
}
return ErrUknownContentType(contentType)
}
func (d AutoDecode[T]) TakeRequestBody() {}

View File

@ -4,6 +4,8 @@ import (
"reflect"
"strconv"
"strings"
"github.com/go-playground/validator/v10"
)
var (
@ -54,21 +56,21 @@ var (
if v, err := strconv.ParseInt(s, 10, 16); err != nil {
return nil, err
} else {
return int8(v), nil
return int16(v), nil
}
},
reflect.Int32: func(s string) (any, error) {
if v, err := strconv.ParseInt(s, 10, 32); err != nil {
return nil, err
} else {
return int8(v), nil
return int32(v), nil
}
},
reflect.Int64: func(s string) (any, error) {
if v, err := strconv.ParseInt(s, 10, 64); err != nil {
return nil, err
} else {
return int8(v), nil
return v, nil
}
},
@ -108,4 +110,24 @@ var (
}
},
}
defaultValidator = NewValidator("")
)
func NewValidator(nameFromTag string) *validator.Validate {
validate := validator.New(
validator.WithRequiredStructEnabled(),
)
if nameFromTag != "" {
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get(nameFromTag), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
return validate
}

View File

@ -11,6 +11,8 @@ import (
var (
extractors = make(map[reflect.Type]func(*http.Request) (any, error))
extractorsTakeBody = make(map[reflect.Type]bool)
extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error))
)
@ -18,6 +20,12 @@ type FromRequest interface {
FromRequest(*http.Request) error
}
// Marker interface
type TakeRequestBody interface {
TakeRequestBody()
}
// Marker interface
type TakeResponseWriter interface {
TakeResponseWriter(http.ResponseWriter)
}
@ -43,6 +51,42 @@ func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error))
RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor)
}
func RegisterExtractorThatTakesBody(o any, extractor func(*http.Request) (any, error)) {
if t, ok := o.(reflect.Type); ok {
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
extractors[t] = extractor
extractorsTakeBody[t] = true
} else if v, ok := o.(reflect.Value); ok {
slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value")
extractors[v.Type()] = extractor
extractorsTakeBody[v.Type()] = true
} else {
t := reflect.TypeOf(o)
slog.With("type", t.String(), "value", o).Debug("extractor type registered with object")
extractors[t] = extractor
extractorsTakeBody[t] = true
}
}
func RegisterExtractorThatTakesBodyGeneric[T any](extractor func(*http.Request) (any, error)) {
var pointerToT *T
RegisterExtractorThatTakesBody(reflect.TypeOf(pointerToT).Elem(), extractor)
}
func IsTakeBody(t reflect.Type) bool {
if isTakeBody := extractorsTakeBody[t]; isTakeBody {
return isTakeBody
}
if _, ok := util.Implements[TakeRequestBody](t); ok {
return ok
}
return false
}
func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) {
if t, ok := o.(reflect.Type); ok {
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
@ -70,62 +114,57 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
_, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem())
isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter
}
if _, ok := util.Implements[FromRequest](t); ok {
fromRequestType := reflect.TypeOf((*FromRequest)(nil)).Elem()
if t.Implements(fromRequestType) {
return func(w http.ResponseWriter, r *http.Request) (any, error) {
val := reflect.New(t).Elem()
if t.Kind() == reflect.Pointer {
// T.Implement(FromRequest) and T is Pointer
// create new actual T and call T.FromRequest on it
// if error happens, no error is returned. just return nil
return func(w http.ResponseWriter, r *http.Request) (any, error) {
// var t T
// return t, t.FromRequest(r)
z := reflect.New(t.Elem())
val = reflect.New(t.Elem())
}
if isTakeResponseWriter {
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
if taker, ok := val.Addr().Interface().(TakeResponseWriter); ok {
taker.TakeResponseWriter(w)
}
}
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
if err, ok := results[0].Interface().(error); ok && err != nil {
if errResponse, ok := err.(ErrorResponse); ok {
return reflect.Zero(t).Interface(), errResponse
if fromR, ok := val.Interface().(FromRequest); ok {
if err := fromR.FromRequest(r); err != nil {
return nil, err
}
return reflect.Zero(t).Interface(), nil
} else if val.CanAddr() {
if fromR, ok := val.Addr().Interface().(FromRequest); ok {
if err := fromR.FromRequest(r); err != nil {
return nil, err
}
return z.Interface(), nil
}
}
if t.Kind() == reflect.Pointer {
return val.Interface(), nil
}
return val.Addr().Interface(), nil
}, isTakeResponseWriter
}
// T.Implement(FromRequest) and T is not Pointer
// create zero T and call T.FromRequest directly on it
if reflect.PointerTo(t).Implements(fromRequestType) {
return func(w http.ResponseWriter, r *http.Request) (any, error) {
// var t T
// return t, t.FromRequest(r)
z := reflect.Zero(t)
val := reflect.New(t)
if isTakeResponseWriter {
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
if taker, ok := val.Interface().(TakeResponseWriter); ok {
taker.TakeResponseWriter(w)
}
}
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
if err, ok := results[0].Interface().(error); ok {
return z.Interface(), err
if fromR, ok := val.Interface().(FromRequest); ok {
if err := fromR.FromRequest(r); err != nil {
return nil, err
}
return z.Interface(), nil
}, isTakeResponseWriter
} else if v, ok := util.PointerImplements[FromRequest](t); ok {
return func(w http.ResponseWriter, r *http.Request) (any, error) {
// var t T
// return t, t.FromRequest(r)
z := reflect.New(v.Type().Elem())
if isTakeResponseWriter {
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
}
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
if err := results[0].Interface(); err == nil {
return z.Elem().Interface(), nil
}
return z.Elem().Interface(), results[0].Interface().(error)
return val.Elem().Interface(), nil
}, isTakeResponseWriter
}
@ -142,7 +181,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
if err != nil {
return nil, err
}
return &v, nil
// Create a pointer to the value
vp := reflect.New(reflect.TypeOf(v))
vp.Elem().Set(reflect.ValueOf(v))
return vp.Interface(), nil
}, false
}
}
@ -160,7 +202,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
if err != nil {
return nil, err
}
return &v, nil
// Create a pointer to the value
vp := reflect.New(reflect.TypeOf(v))
vp.Elem().Set(reflect.ValueOf(v))
return vp.Interface(), nil
}, true
}
}

206
magic/form.go Normal file
View File

@ -0,0 +1,206 @@
package magic
import (
"fmt"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
)
var (
cachedStructFormFields = make(map[reflect.Type][]structFormFields)
cachedUnsupportedFormFields = make(map[reflect.Type]structFormFields)
)
type Form[T any] struct {
Data T
}
type structFormFields struct {
PathKey string
FieldKey string
Type reflect.Type
MinValues, MaxValues int
Match *regexp.Regexp
}
type UnsupportedFormType struct {
inner reflect.Type
}
func (err UnsupportedFormType) Error() string {
return "unsupported type for Form: " + err.inner.String()
}
type FormValueNotFitIn struct {
Count int
field structFormFields
}
func (err FormValueNotFitIn) Error() string {
return fmt.Sprintf("Form %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinValues, err.field.MaxValues, err.Count)
}
type InvalidFormValue struct {
Match *regexp.Regexp
Value string
}
func (err InvalidFormValue) Error() string {
return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value)
}
func findReflectFormFields(v reflect.Value) ([]structFormFields, error) {
if cached, ok := cachedStructFormFields[v.Type()]; ok {
return cached, nil
}
if cached, ok := cachedUnsupportedFormFields[v.Type()]; ok {
return nil, UnsupportedPathValueType{inner: cached.Type}
}
var fields []structFormFields
if v.Kind() == reflect.Struct {
t := v.Type()
for idx := 0; idx < t.NumField(); idx++ {
field := t.Field(idx)
pathKey := field.Name
var match *regexp.Regexp
if tags, ok := field.Tag.Lookup("form"); ok {
parts := strings.Split(strings.TrimSpace(tags), ",")
if len(parts) == 0 {
// do nothing
} else {
if parts[0] != "" {
pathKey = parts[0]
}
for _, part := range parts[1:] {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "match=") {
part := strings.TrimPrefix(part, "match=")
if unquoted, err := strconv.Unquote(part); err != nil {
return nil, err
} else {
part = unquoted
}
exp, err := regexp.Compile(part)
if err != nil {
return nil, err
}
match = exp
}
}
}
}
fieldType := field.Type
maxValues, minValues := 1, 1
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
minValues = 0
} else if fieldType.Kind() == reflect.Array {
maxValues, minValues = fieldType.Len(), fieldType.Len()
fieldType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Slice {
fieldType = fieldType.Elem()
maxValues, minValues = -1, 0
}
if pathValueConvertTable[fieldType.Kind()] == nil {
if _, ok := fieldType.MethodByName("FromString"); !ok {
cachedUnsupportedPathValueFields[t] = field
return nil, UnsupportedPathValueType{inner: field.Type}
}
}
fields = append(fields, structFormFields{
PathKey: pathKey,
FieldKey: field.Name,
Type: fieldType,
MaxValues: maxValues,
MinValues: minValues,
Match: match,
})
cachedStructFormFields[t] = fields
}
} else {
return nil, UnsupportedFormType{inner: v.Type()}
}
return fields, nil
}
func (form *Form[T]) FromRequest(r *http.Request) error {
v := reflect.ValueOf(form).Elem().FieldByName("Data")
fields, err := findReflectFormFields(v)
if err != nil {
return err
}
if err := r.ParseMultipartForm(32 << 20); err != nil {
return err
}
q := r.PostForm
for _, field := range fields {
values := q[field.PathKey]
if len(values) < field.MinValues || (field.MaxValues > 0 && len(values) > field.MaxValues) {
return FormValueNotFitIn{Count: len(values), field: field}
}
structField := v.FieldByName(field.FieldKey)
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
if structField.Cap() < len(values) {
structField.SetCap(len(values))
}
structField.SetLen(len(values))
}
for idx, value := range values {
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
if idx != len(values)-1 {
// only the last takes effect
continue
}
}
// check match regexp
if field.Match != nil && !field.Match.MatchString(value) {
return InvalidFormValue{Match: field.Match, Value: value}
}
// convert to elem value
convert := pathValueConvertTable[field.Type.Kind()]
result, err := convert(value)
if err != nil {
return err
}
reflectValue := reflect.ValueOf(result)
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
if field.MinValues == 0 {
newReflectValue := reflect.New(field.Type)
newReflectValue.Elem().Set(reflectValue)
reflectValue = newReflectValue
}
structField.Set(reflectValue)
} else {
structField.Index(idx).Set(reflectValue)
}
}
}
return nil
}
func (form Form[T]) TakeRequestBody() {}

View File

@ -2,8 +2,15 @@ package magic
import (
"encoding/json"
"io"
"net/http"
"reflect"
"strings"
"github.com/go-playground/validator/v10"
)
var (
jsonValidator = NewValidator("json")
)
type JsonDecodeError struct {
@ -18,10 +25,25 @@ func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) {
rw.WriteHeader(http.StatusBadRequest)
}
type Json[T any] struct {
Data T
type ValidateError []string
func (err ValidateError) Error() string {
return "invalid or missing fields: [" + strings.Join(err, ", ") + "]"
}
func (v ValidateError) WriteResponse(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(v.Error()))
}
type Json[T any] struct {
Data T `json:"data"`
}
// TakeRequestBody is a marker method to indicate that this extractor takes the request body.
func (j Json[T]) TakeRequestBody() {}
func NewJson[T any](data T) Json[T] {
return Json[T]{
Data: data,
@ -29,35 +51,41 @@ func NewJson[T any](data T) Json[T] {
}
func (data *Json[T]) FromRequest(request *http.Request) error {
if request.Body == nil {
panic("body is already taken")
}
bodyReader := request.Body
defer bodyReader.Close()
request.Body = http.NoBody
request.Body = nil
if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil {
return JsonDecodeError{inner: err}
}
value := reflect.ValueOf(data.Data)
if value.Kind() == reflect.Struct || (value.Kind() == reflect.Ptr && value.Elem().Kind() == reflect.Struct) {
if err := jsonValidator.StructCtx(request.Context(), data.Data); err == nil {
return nil
} else if validateErr, isValidateErr := err.(validator.ValidationErrors); isValidateErr {
fields := make([]string, len(validateErr))
for idx, err := range validateErr {
fields[idx] = err.Field()
}
return ValidateError(fields)
} else {
return err
}
}
return nil
}
type counter struct {
counter int64
}
func (c *counter) Write(b []byte) (int, error) {
n := len(b)
c.counter += int64(n)
return n, nil
}
func (container Json[T]) Write(w io.Writer) (int64, error) {
counter := &counter{}
err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data)
return counter.counter, err
}
func (json Json[T]) WriteResponse(w http.ResponseWriter) {
json.Write(w)
func (data *Json[T]) WriteResponse(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
if data == nil {
return
}
json.NewEncoder(w).Encode(data)
}

View File

@ -8,5 +8,6 @@ import (
type Map map[string]any
func (m Map) RespondWriter(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(m)
}

View File

@ -159,7 +159,7 @@ func (query *Query[T]) FromRequest(r *http.Request) error {
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
if structField.Cap() < len(values) {
structField.SetCap(len(values))
structField.Grow(len(values) - structField.Cap())
}
structField.SetLen(len(values))
}

View File

@ -2,7 +2,10 @@ package magic
import (
"context"
"io"
"net/http"
"git.jeffthecoder.xyz/public/lazyhandler/middleware/cleanup"
)
func init() {
@ -12,4 +15,49 @@ func init() {
RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) {
return r.Context(), nil
})
// RegisterExtractorThatTakesBodyGeneric for io.ReadCloser extracts the request body.
//
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
// is used, r.Body will be set to nil to prevent it from being read multiple times.
// Any subsequent attempt to read the body will fail.
//
// A cleanup function is automatically registered to close the body once the request
// is finished. This requires the cleanup middleware to be present in the handler chain.
RegisterExtractorThatTakesBodyGeneric[io.ReadCloser](func(r *http.Request) (any, error) {
if r.Body == nil {
panic("body is already taken")
}
body := r.Body
r.Body = nil
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
body.Close()
}))
return body, nil
})
// RegisterExtractorThatTakesBodyGeneric for io.Reader extracts the request body.
//
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
// is used, r.Body will be set to nil to prevent it from being read multiple times.
// Any subsequent attempt to read the body will fail.
//
// A cleanup function is automatically registered to close the body once the request
// is finished. This requires the cleanup middleware to be present in the handler chain.
RegisterExtractorThatTakesBodyGeneric[io.Reader](func(r *http.Request) (any, error) {
if r.Body == nil {
panic("body is already taken")
}
body := r.Body
r.Body = nil
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
body.Close()
}))
return body, nil
})
}

View File

@ -22,6 +22,26 @@ type ErrorResponse interface {
WriteResponse(http.ResponseWriter)
}
func WrapSimpleErrorWithStatus(message string, status int) ErrorResponse {
return simpleErrorResponseWithoutBody{
message: message,
status: status,
}
}
type simpleErrorResponseWithoutBody struct {
message string
status int
}
func (err simpleErrorResponseWithoutBody) Error() string {
return err.message
}
func (err simpleErrorResponseWithoutBody) WriteResponse(w http.ResponseWriter) {
w.WriteHeader(err.status)
}
func init() {
RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) {
return w, nil

View File

@ -2,12 +2,8 @@ package magic
import "net/http"
type State[T any] struct {
Data T
}
func RegisterState[T any](data T) {
RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) {
return State[T]{Data: data}, nil
RegisterExtractor(data, func(r *http.Request) (any, error) {
return data, nil
})
}

View File

@ -0,0 +1,61 @@
package cleanup
import (
"context"
"log/slog"
"net/http"
"reflect"
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
)
type ctxKey int
const (
cleanupCtxKey ctxKey = iota
)
type CleanupContext struct {
funcs []Cleanup
}
type Cleanup interface {
Name() string
Cleanup()
}
type CleanupFunc func()
func (fn CleanupFunc) Name() string {
return reflect.TypeOf(fn).String()
}
func (fn CleanupFunc) Cleanup() {
defer func() {
if v := recover(); v != nil {
slog.With("v", v).Warn("cleanup panicked")
}
}()
fn()
}
func Collect() middleware.Middleware {
return middleware.WrapFunc(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := &CleanupContext{}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cleanupCtxKey, ctx)))
for _, cleanupFunc := range ctx.funcs {
defer cleanupFunc.Cleanup()
}
})
})
}
// Register adds a Cleanup function to the CleanupContext in the provided context.
// If the CleanupContext is not found in the context, the Cleanup function is not registered.
func Register(ctx context.Context, c Cleanup) {
if ctx, ok := ctx.Value(cleanupCtxKey).(*CleanupContext); ok {
ctx.funcs = append(ctx.funcs, c)
}
}

View File

@ -0,0 +1,54 @@
package cleanup
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCleanupMiddleware(t *testing.T) {
var executionOrder []string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Register(r.Context(), CleanupFunc(func() {
executionOrder = append(executionOrder, "first")
}))
Register(r.Context(), CleanupFunc(func() {
executionOrder = append(executionOrder, "second")
}))
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
middleware := Collect().WrapHandler(handler)
middleware.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
rr.Code, http.StatusOK)
}
expectedOrder := "second,first"
actualOrder := strings.Join(executionOrder, ",")
if actualOrder != expectedOrder {
t.Errorf("cleanup functions executed in wrong order: got %s want %s",
actualOrder, expectedOrder)
}
}
func TestCleanupMissingContext(t *testing.T) {
// This test ensures that Register does not panic when the context is missing.
// The function should fail silently.
defer func() {
if r := recover(); r != nil {
t.Errorf("The code panicked when it should not have")
}
}()
req := httptest.NewRequest("GET", "/", nil)
// No middleware, so no context
Register(req.Context(), CleanupFunc(func() {}))
}

View File

@ -75,6 +75,10 @@ func (log Log) WrapHandler(next http.Handler) http.Handler {
})
}
// Logger retrieves the slog.Logger from the request context.
// If the logger is not found in the context (e.g., the httplog middleware is not used),
// it creates and returns a new logger with basic request information.
// Using the httplog middleware is recommended to ensure the configured logger is available.
func Logger(r *http.Request) *slog.Logger {
if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok {
return logger.With("time", time.Now())

View File

@ -0,0 +1,85 @@
package httplog
import (
"bufio"
"bytes"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"testing"
)
type mockHijacker struct {
*httptest.ResponseRecorder
hijacked bool
}
func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
m.hijacked = true
// Return dummy values, not used in this test
return nil, nil, nil
}
func TestResponseRecorder_Hijack(t *testing.T) {
recorder := &responseRecorder{
ResponseWriter: &mockHijacker{ResponseRecorder: httptest.NewRecorder()},
}
_, _, err := recorder.Hijack()
if err != nil {
t.Fatalf("Hijack failed: %v", err)
}
if recorder.ResponseWriter != nil {
t.Error("ResponseWriter should be nil after Hijack")
}
}
func TestLogger(t *testing.T) {
var buf bytes.Buffer
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil)))
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger := Logger(r)
logger.Info("test message")
w.WriteHeader(http.StatusOK)
})
t.Run("with middleware", func(t *testing.T) {
buf.Reset()
logMiddleware := Log{LogStart: true, LogFinish: true}
wrappedHandler := logMiddleware.WrapHandler(handler)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
t.Error("expected start log message, but not found")
}
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
t.Error("expected handler log message, but not found")
}
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
t.Error("expected finish log message, but not found")
}
})
t.Run("without middleware", func(t *testing.T) {
buf.Reset()
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
t.Error("expected handler log message, but not found")
}
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
t.Error("unexpected start log message found")
}
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
t.Error("unexpected finish log message found")
}
})
}

View File

@ -42,3 +42,20 @@ func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
func Debug() middleware.Middleware {
return Recover(DebugPanicHandler)
}
// ProductionPanicHandler is a PanicResponseFunc that provides a generic error message
// in production environments and logs the detailed panic error.
func ProductionPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
httplog.Logger(r).With("panic", err).Error("request panicked")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]any{
"error": "Internal Server Error",
})
}
// Production returns a middleware that recovers from panics and handles them
// with ProductionPanicHandler.
func Production() middleware.Middleware {
return Recover(ProductionPanicHandler)
}

View File

@ -0,0 +1,56 @@
package recover
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func panicHandler(w http.ResponseWriter, r *http.Request) {
panic("test panic")
}
func TestDebugRecover(t *testing.T) {
handler := http.HandlerFunc(panicHandler)
wrappedHandler := Debug().WrapHandler(handler)
req := httptest.NewRequest("GET", "/panic", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
}
var resp map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response body: %v", err)
}
if resp["error"] != "test panic" {
t.Errorf("expected error 'test panic', got '%v'", resp["error"])
}
}
func TestProductionRecover(t *testing.T) {
handler := http.HandlerFunc(panicHandler)
wrappedHandler := Production().WrapHandler(handler)
req := httptest.NewRequest("GET", "/panic", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
}
var resp map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response body: %v", err)
}
if resp["error"] != "Internal Server Error" {
t.Errorf("expected error 'Internal Server Error', got '%v'", resp["error"])
}
}