commit 61ffeeb3b835a9bb6142b844a0a4099d27fff352 Author: guochao Date: Sat Feb 22 23:00:56 2025 +0800 initial commit diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3326100 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module git.jeffthecoder.xyz/public/lazyhandler + +go 1.24.0 + +require ( + github.com/go-session/session/v3 v3.2.1 + github.com/gorilla/websocket v1.5.3 +) + +require github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6e9d68c --- /dev/null +++ b/go.sum @@ -0,0 +1,22 @@ +github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8= +github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY= +github.com/go-session/session/v3 v3.2.1/go.mod h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M= +github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= +github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= +github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inject.go b/inject.go new file mode 100644 index 0000000..0799aa4 --- /dev/null +++ b/inject.go @@ -0,0 +1,199 @@ +package lazyhandler + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "reflect" + + "git.jeffthecoder.xyz/public/lazyhandler/magic" + "git.jeffthecoder.xyz/public/lazyhandler/util" +) + +type ErrArgumentIsNotExtractable int + +func (err ErrArgumentIsNotExtractable) Error() string { + return fmt.Sprintf("argument %v is not extractable", int(err)) +} + +type ErrReturnValueNotConvertableIntoResponsePart int + +func (err ErrReturnValueNotConvertableIntoResponsePart) Error() string { + return fmt.Sprintf("return value %v is not convertable into response part", int(err)) +} + +var ( + ErrNotAFunc = errors.New("not a function") + ErrDuplicateResponseWriterExtractor = errors.New("duplicate response writer extractor") + ErrResponseWriterCannotBeExtracted = errors.New("http.ResponseWriter extractor must not exists if function has return value") +) + +func canConvert[T any](o any) bool { + t := reflect.TypeOf((*T)(nil)).Elem() + if reflectValue, ok := o.(reflect.Value); ok { + return reflectValue.CanConvert(t) + } + if reflectType, ok := o.(reflect.Type); ok { + return reflectType.ConvertibleTo(t) + } + return reflect.ValueOf(o).CanConvert(t) +} + +func MagicHandler(fn any) (http.Handler, error) { + funcValue := reflect.ValueOf(fn) + t := funcValue.Type() + + if t.Kind() != reflect.Func { + return nil, ErrNotAFunc + } + + responseWriterExtracted := -1 + + extractors := make([]func(http.ResponseWriter, *http.Request) (any, error), t.NumIn()) + for idx := 0; idx < t.NumIn(); idx++ { + in := t.In(idx) + + extractor, isTakeResponseWriter := magic.GetExtractor(in) + if extractor == nil { + return nil, ErrArgumentIsNotExtractable(idx) + } + if isTakeResponseWriter { + if responseWriterExtracted >= 0 { + return nil, ErrDuplicateResponseWriterExtractor + } + responseWriterExtracted = idx + } + extractors[idx] = extractor + } + + // http.ResponseWriter extractor must not exists if function has return value + if responseWriterExtracted >= 0 && t.NumOut() > 0 { + return nil, ErrResponseWriterCannotBeExtracted + } + + for idx := 0; idx < t.NumOut(); idx++ { + out := t.Out(idx) + + // int(status) || string(body) + if out.Kind() == reflect.Int || out.Kind() == reflect.String { + continue + } + + // []byte(body) + if out.Kind() == reflect.Slice && out.Elem().Kind() == reflect.Uint8 { + continue + } + + // [T] map[string]T(header) + if out.Kind() == reflect.Map && out.Key().Kind() == reflect.String { + continue + } + + if _, ok := util.Implements[io.Reader](out); ok { + continue + } + + if _, ok := util.Implements[magic.RespondWriter](out); ok { + continue + } + + // last is error + if _, ok := util.Implements[error](out); ok && idx == t.NumOut()-1 { + continue + } + + return nil, ErrReturnValueNotConvertableIntoResponsePart(idx) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var closers []io.Closer + defer func() { + for _, closer := range closers { + closer.Close() + } + }() + in := make([]reflect.Value, len(extractors)) + for idx, extractor := range extractors { + v, err := extractor(w, r) + if err != nil { + if errResponse, ok := err.(magic.ErrorResponse); ok { + errResponse.WriteResponse(w) + } else { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(magic.Map{ + "error": err.Error(), + }) + } + return + } + if closer, ok := v.(io.Closer); ok { + defer closer.Close() + } + in[idx] = reflect.ValueOf(v) + } + values := funcValue.Call(in) + + if numValues := len(values); numValues > 0 { + lastValue := values[numValues-1].Interface() + if err, isError := lastValue.(error); isError { + if err != nil { + if errResponse, ok := lastValue.(magic.ErrorResponse); ok { + errResponse.WriteResponse(w) + } else { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(magic.Map{ + "error": err.Error(), + }) + } + return + } + values = values[:numValues-1] + } + } + + for _, value := range values { + obj := value.Interface() + + if value.Kind() == reflect.Int { + w.WriteHeader(int(value.Int())) + } else if value.Kind() == reflect.String { + w.Write([]byte(value.String())) + } else if canConvert[[]byte](value) { + w.Write(value.Bytes()) + } else if reader, ok := util.Implements[io.Reader](value); ok { + io.Copy(w, reader) + } else if responseWriter, ok := util.Implements[magic.RespondWriter](value); ok { + if responseWriter == nil { + continue + } + responseWriter.WriteResponse(w) + } else if headers, ok := obj.(http.Header); ok { // not + for name, values := range headers { + for idx, value := range values { + if idx == 0 { + w.Header().Set(name, value) + } else { + w.Header().Add(name, value) + } + } + } + } else if value.Kind() == reflect.Map { + for _, key := range value.MapKeys() { + value := value.MapIndex(key).Interface() + w.Header().Set(key.String(), fmt.Sprint(value)) + } + } + } + }), nil +} + +func MustMagicHandler(fn any) http.Handler { + if handler, err := MagicHandler(fn); err != nil { + panic("can not be converted to http.Handler: " + fmt.Sprintf("%T", fn) + ": " + err.Error()) + } else { + return handler + } + +} diff --git a/inject_test.go b/inject_test.go new file mode 100644 index 0000000..b2f786f --- /dev/null +++ b/inject_test.go @@ -0,0 +1,356 @@ +package lazyhandler + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "git.jeffthecoder.xyz/public/lazyhandler/magic" + "github.com/gorilla/websocket" +) + +type StatusPathValues struct { + Status int `pathvalue:"status"` +} + +type HelloWorldRequest struct { + Name string `json:"name"` +} +type HelloWorldResponse struct { + Greeting string `json:"greeting"` +} + +type HelloWorldParams struct { + Name *string `pathvalue:"name,match=" query:"name"` +} + +type MultipleParams struct { + Numbers []int + Strings []string + OptionalNumber *int + OptionalString *string +} + +func TestNotAFunc(t *testing.T) { + v, err := MagicHandler(1) + if err == nil { + t.Fatalf("expect not ok, get %v", v) + } +} + +func TestEmptyFunc(t *testing.T) { + MagicHandler(func() { + + }) +} + +type SomeType int + +func (obj SomeType) Func() { + +} + +func TestMethod(t *testing.T) { + obj := SomeType(0) + MagicHandler(obj.Func) +} + +type testCase struct { + Path string + Pattern string + + CreateHandler func() http.Handler + ExpectPanicOnCreateHandler bool + + NoCheck bool + MakeRequest func(string) *http.Request + CheckResponse func(*http.Response) error +} + +var ( + cases = []testCase{ + { + Path: "/", + Pattern: "/{$}", + + CreateHandler: func() http.Handler { + return MustMagicHandler(func() (int, string, error) { + return 200, "hi", nil + }) + }, + MakeRequest: func(base string) *http.Request { + r, err := http.NewRequest("GET", base+"/", nil) + if err != nil { + panic(err) + } + return r + }, + }, + { + Pattern: "/error-checked-before-other-return-values", + + CreateHandler: func() http.Handler { + return MustMagicHandler(func() (int, string, error) { + // it checks errors first + return 200, "hi", fmt.Errorf("it is handled before other return values") + }) + }, + MakeRequest: func(base string) *http.Request { + r, err := http.NewRequest("GET", base+"/return-error", nil) + if err != nil { + panic(err) + } + return r + }, + }, + { + Pattern: "/return-nil", + CreateHandler: func() http.Handler { + return MustMagicHandler(func() (magic.RespondWriter, error) { + // no response body, 200 + return nil, nil + }) + }, + }, + { + Pattern: "/no-return", + CreateHandler: func() http.Handler { + return MustMagicHandler(func() { + // return nothing at all + }) + }, + }, + { + Pattern: "/no-error", + CreateHandler: func() http.Handler { + return MustMagicHandler(func() int { + // well, error is not needed + return http.StatusNoContent + }) + }, + }, + { + Pattern: "/respond-json", + CreateHandler: func() http.Handler { + return MustMagicHandler(func() magic.Json[map[string]any] { + return magic.Json[map[string]any]{ + Data: map[string]any{ + "message": "hello, world", + }, + } + }) + }, + }, + { + Pattern: "/extract-json", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(body magic.Json[HelloWorldRequest]) magic.Json[HelloWorldResponse] { + return magic.Json[HelloWorldResponse]{ + Data: HelloWorldResponse{ + Greeting: "hello, " + body.Data.Name, + }, + } + }) + }, + }, + { + Pattern: "/extract-optional", + CreateHandler: func() http.Handler { + // the parameter accepts a pointer to extractor, which indicates that it is optional + return MustMagicHandler(func(body *magic.Json[HelloWorldRequest]) magic.Json[HelloWorldResponse] { + name := "world" + if body != nil { + name = body.Data.Name + } + return magic.Json[HelloWorldResponse]{ + Data: HelloWorldResponse{ + Greeting: "hello, " + name, + }, + } + }) + }, + }, + { + Path: "/inject-request-respond-status-from-http-request/418", // yeah...i'm a teapot, lol + Pattern: "/inject-request-respond-status-from-http-request/{status}", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(r *http.Request) (int, error) { + // inject *http.Request, respond with status only + return strconv.Atoi(r.PathValue("status")) + }) + }, + }, + { + Path: "/inject-path-value-respond-status/418", // yeah...i'm a teapot, lol + Pattern: "/inject-path-value-respond-status/{status}", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(pathValues magic.PathValue[StatusPathValues]) int { + return pathValues.Data.Status + }) + }, + }, + { + Path: "/inject-optional-path-value/rose", // name extracted + Pattern: "/inject-optional-path-value/{name}", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(pathValues magic.PathValue[HelloWorldParams]) magic.Json[HelloWorldResponse] { + name := "world" + if pathValues.Data.Name != nil { + name = *pathValues.Data.Name + } + return magic.Json[HelloWorldResponse]{ + Data: HelloWorldResponse{ + Greeting: "hello, " + name, + }, + } + }) + }, + }, + { + Path: "/inject-optional-path-value", // name fallback to world + Pattern: "/inject-optional-path-value", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(pathValues magic.PathValue[HelloWorldParams]) magic.Json[HelloWorldResponse] { + name := "world" + if pathValues.Data.Name != nil { + name = *pathValues.Data.Name + } + return magic.Json[HelloWorldResponse]{ + Data: HelloWorldResponse{ + Greeting: "hello, " + name, + }, + } + }) + }, + }, + { + Pattern: "/direct-response-writer-extractor", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(w http.ResponseWriter) { + w.WriteHeader(http.StatusNoContent) + }) + }, + }, + { + Pattern: "/hahaha-it-act-just-like-http-handlerfunc", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + }) + }, + }, + { + Pattern: "/websocket", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(ws *websocket.Conn) { + + }) + }, + NoCheck: true, + }, + { + Pattern: "/multiple-respond-extractor-should-panic-on-create", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(conn *websocket.Conn) magic.Json[HelloWorldResponse] { + return magic.Json[HelloWorldResponse]{} + }) + }, + ExpectPanicOnCreateHandler: true, + }, + // { + // Pattern: "/hello-query", + // CreateHandler: func() http.Handler { + // return MustMagicHandler(func(query magic.Query[HelloWorldParams]) magic.Json[HelloWorldResponse] { + // name := "world" + // if query.Data.Name != nil { + // name = *query.Data.Name + // } + // return magic.Json[HelloWorldResponse]{ + // Data: HelloWorldResponse{ + // Greeting: "hello, " + name, + // }, + // } + // }) + // }, + // }, + { + Pattern: "/test-query", + CreateHandler: func() http.Handler { + return MustMagicHandler(func(query magic.Query[MultipleParams]) { + log.Println(query.Data) + }) + }, + }, + } +) + +func TestActualFunctions(t *testing.T) { + mux := http.NewServeMux() + + t.Run("register-routes", func(t *testing.T) { + for _, c := range cases { + path := c.Path + if path == "" { + path = c.Pattern + } + t.Run(strings.TrimPrefix(path, "/"), func(t *testing.T) { + defer func() { + if o := recover(); (o != nil) != c.ExpectPanicOnCreateHandler { + if c.ExpectPanicOnCreateHandler { + t.Fatal("expect panic but panic did not occur") + } else { + t.Fatalf("panic not expected: %v", o) + } + } + }() + if handler := c.CreateHandler(); handler != nil { + mux.Handle(c.Pattern, handler) + } + }) + } + }) + + t.Run("check-actual-result", func(t *testing.T) { + server := httptest.NewTLSServer(mux) + client := server.Client() + for _, c := range cases { + if c.ExpectPanicOnCreateHandler || c.NoCheck { + continue + } + + path := c.Path + if path == "" { + path = c.Pattern + } + t.Run(strings.TrimPrefix(path, "/"), func(t *testing.T) { + var req *http.Request + + if c.MakeRequest != nil { + req = c.MakeRequest(server.URL) + } else { + request, err := http.NewRequest("GET", server.URL+path, nil) + if err != nil { + t.Fatal(err) + } + req = request + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + log.Println(resp) + if bytes, err := io.ReadAll(resp.Body); err == nil { + log.Println(string(bytes)) + } + }) + + } + + }) +} diff --git a/magic/extractors.go b/magic/extractors.go new file mode 100644 index 0000000..a07dd18 --- /dev/null +++ b/magic/extractors.go @@ -0,0 +1,169 @@ +package magic + +import ( + "log/slog" + "net/http" + "reflect" + + "git.jeffthecoder.xyz/public/lazyhandler/util" +) + +var ( + extractors = make(map[reflect.Type]func(*http.Request) (any, error)) + + extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error)) +) + +type FromRequest interface { + FromRequest(*http.Request) error +} + +type TakeResponseWriter interface { + TakeResponseWriter(http.ResponseWriter) +} + +func RegisterExtractor(o any, extractor func(*http.Request) (any, error)) { + if t, ok := o.(reflect.Type); ok { + slog.With("type", t.String()).Debug("extractor type registered with reflect.Type") + extractors[t] = extractor + } else if v, ok := o.(reflect.Value); ok { + slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value") + + extractors[v.Type()] = extractor + } else { + t := reflect.TypeOf(o) + slog.With("type", t.String(), "value", o).Debug("extractor type registered with object") + + extractors[t] = extractor + } +} + +func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error)) { + var pointerToT *T + RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor) +} + +func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) { + if t, ok := o.(reflect.Type); ok { + slog.With("type", t.String()).Debug("extractor type registered with reflect.Type") + extractorsTakesResponseWriter[t] = extractor + } else if v, ok := o.(reflect.Value); ok { + slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value") + + extractorsTakesResponseWriter[v.Type()] = extractor + } else { + t := reflect.TypeOf(o) + slog.With("type", t.String(), "value", o).Debug("extractor type registered with object") + + extractorsTakesResponseWriter[t] = extractor + } +} + +func RegisterExtractorThatTakesResponseWriterGeneric[T any](extractor func(http.ResponseWriter, *http.Request) (any, error)) { + var pointerToT *T + RegisterExtractorThatTakesResponseWriter(reflect.TypeOf(pointerToT).Elem(), extractor) +} + +func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any, error), bool) { + _, isTakeResponseWriter := util.Implements[TakeResponseWriter](t) + if t.Kind() == reflect.Pointer { + _, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem()) + isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter + } + if _, ok := util.Implements[FromRequest](t); ok { + if t.Kind() == reflect.Pointer { + // T.Implement(FromRequest) and T is Pointer + // create new actual T and call T.FromRequest on it + // if error happens, no error is returned. just return nil + return func(w http.ResponseWriter, r *http.Request) (any, error) { + // var t T + // return t, t.FromRequest(r) + z := reflect.New(t.Elem()) + + if isTakeResponseWriter { + z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)}) + } + + results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)}) + if err, ok := results[0].Interface().(error); ok && err != nil { + if errResponse, ok := err.(ErrorResponse); ok { + return reflect.Zero(t).Interface(), errResponse + } + return reflect.Zero(t).Interface(), nil + } + return z.Interface(), nil + }, isTakeResponseWriter + } + // T.Implement(FromRequest) and T is not Pointer + // create zero T and call T.FromRequest directly on it + return func(w http.ResponseWriter, r *http.Request) (any, error) { + // var t T + // return t, t.FromRequest(r) + z := reflect.Zero(t) + + if isTakeResponseWriter { + z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)}) + } + + results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)}) + if err, ok := results[0].Interface().(error); ok { + return z.Interface(), err + } + return z.Interface(), nil + }, isTakeResponseWriter + } else if v, ok := util.PointerImplements[FromRequest](t); ok { + return func(w http.ResponseWriter, r *http.Request) (any, error) { + // var t T + // return t, t.FromRequest(r) + z := reflect.New(v.Type().Elem()) + + if isTakeResponseWriter { + z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)}) + } + + results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)}) + if err := results[0].Interface(); err == nil { + return z.Elem().Interface(), nil + } + return z.Elem().Interface(), results[0].Interface().(error) + }, isTakeResponseWriter + } + + if extractor, ok := extractors[t]; ok { + return func(_ http.ResponseWriter, r *http.Request) (any, error) { + return extractor(r) + }, false + } + + if t.Kind() == reflect.Pointer { + if extractor, ok := extractors[t.Elem()]; ok { + return func(w http.ResponseWriter, r *http.Request) (any, error) { + v, err := extractor(r) + if err != nil { + return nil, err + } + return &v, nil + }, false + } + } + + if extractor, ok := extractorsTakesResponseWriter[t]; ok { + return func(w http.ResponseWriter, r *http.Request) (any, error) { + return extractor(w, r) + }, true + } + + if t.Kind() == reflect.Pointer { + if extractor, ok := extractorsTakesResponseWriter[t.Elem()]; ok { + return func(w http.ResponseWriter, r *http.Request) (any, error) { + v, err := extractor(w, r) + if err != nil { + return nil, err + } + return &v, nil + }, true + } + } + + return nil, isTakeResponseWriter +} diff --git a/magic/json.go b/magic/json.go new file mode 100644 index 0000000..e3f3ccf --- /dev/null +++ b/magic/json.go @@ -0,0 +1,63 @@ +package magic + +import ( + "encoding/json" + "io" + "net/http" +) + +type JsonDecodeError struct { + inner error +} + +func (err JsonDecodeError) Error() string { + return "failed to decode json: " + err.inner.Error() +} + +func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) { + rw.WriteHeader(http.StatusBadRequest) +} + +type Json[T any] struct { + Data T +} + +func NewJson[T any](data T) Json[T] { + return Json[T]{ + Data: data, + } +} + +func (data *Json[T]) FromRequest(request *http.Request) error { + bodyReader := request.Body + defer bodyReader.Close() + + request.Body = http.NoBody + + if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil { + return JsonDecodeError{inner: err} + } + + return nil +} + +type counter struct { + counter int64 +} + +func (c *counter) Write(b []byte) (int, error) { + n := len(b) + c.counter += int64(n) + + return n, nil +} + +func (container Json[T]) Write(w io.Writer) (int64, error) { + counter := &counter{} + err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data) + return counter.counter, err +} + +func (json Json[T]) WriteResponse(w http.ResponseWriter) { + json.Write(w) +} diff --git a/magic/map.go b/magic/map.go new file mode 100644 index 0000000..8393682 --- /dev/null +++ b/magic/map.go @@ -0,0 +1,12 @@ +package magic + +import ( + "encoding/json" + "net/http" +) + +type Map map[string]any + +func (m Map) RespondWriter(w http.ResponseWriter) { + json.NewEncoder(w).Encode(m) +} diff --git a/magic/path.go b/magic/path.go new file mode 100644 index 0000000..7068a0b --- /dev/null +++ b/magic/path.go @@ -0,0 +1,178 @@ +package magic + +import ( + "fmt" + "net/http" + "reflect" + "regexp" + "strconv" + "strings" +) + +var ( + cachedStructPathValueFields = make(map[reflect.Type][]structPathValueFields) + cachedUnsupportedPathValueFields = make(map[reflect.Type]reflect.StructField) +) + +type UnsupportedPathValueType struct { + inner reflect.Type +} + +func (err UnsupportedPathValueType) Error() string { + return "unsupported type for path value: " + err.inner.String() +} + +type structPathValueFields struct { + PathKey string + FieldKey string + + Type reflect.Type + Optional bool + Match *regexp.Regexp +} + +type PathValue[T any] struct { + Data T +} + +type PathValueNotFound struct { + Key string +} + +func (err PathValueNotFound) Error() string { + return "path value not found: " + err.Key +} + +type InvalidPathValueType struct { + Kind reflect.Kind + Value string +} + +func (err InvalidPathValueType) Error() string { + return fmt.Sprintf("invalid value for kind %v: %v", err.Kind, err.Value) +} + +type InvalidPathValue struct { + Match *regexp.Regexp + Value string +} + +func (err InvalidPathValue) Error() string { + return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value) +} + +func findReflectPathValueFields(v reflect.Value) ([]structPathValueFields, error) { + if cached, ok := cachedStructPathValueFields[v.Type()]; ok { + return cached, nil + } + if cached, ok := cachedUnsupportedPathValueFields[v.Type()]; ok { + return nil, UnsupportedPathValueType{inner: cached.Type} + } + + var fields []structPathValueFields + if v.Kind() == reflect.Struct { + t := v.Type() + for idx := 0; idx < t.NumField(); idx++ { + field := t.Field(idx) + + pathKey := field.Name + var match *regexp.Regexp + + if tags, ok := field.Tag.Lookup("pathvalue"); ok { + parts := strings.Split(strings.TrimSpace(tags), ",") + if len(parts) == 0 { + // do nothing + } else { + if parts[0] != "" { + pathKey = parts[0] + } + + for _, part := range parts[1:] { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "match=") { + part := strings.TrimPrefix(part, "match=") + if unquoted, err := strconv.Unquote(part); err != nil { + return nil, err + } else { + part = unquoted + } + exp, err := regexp.Compile(part) + if err != nil { + return nil, err + } + match = exp + } + } + } + + } + + fieldType := field.Type + optional := false + + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + optional = true + } + + if pathValueConvertTable[fieldType.Kind()] == nil { + if _, ok := fieldType.MethodByName("FromString"); !ok { + cachedUnsupportedPathValueFields[t] = field + return nil, UnsupportedPathValueType{inner: field.Type} + } + } + + fields = append(fields, structPathValueFields{ + PathKey: pathKey, + FieldKey: field.Name, + + Type: fieldType, + Optional: optional, + + Match: match, + }) + } + cachedStructPathValueFields[t] = fields + } else { + return nil, UnsupportedPathValueType{inner: v.Type()} + } + + return fields, nil +} + +func (pathValue *PathValue[T]) FromRequest(r *http.Request) error { + v := reflect.ValueOf(pathValue).Elem().FieldByName("Data") + fields, err := findReflectPathValueFields(v) + if err != nil { + return err + } + + for _, field := range fields { + str := r.PathValue(field.PathKey) + if str == "" { + if !field.Optional { + return PathValueNotFound{Key: field.PathKey} + } + continue + } + if field.Match != nil && !field.Match.MatchString(str) { + return InvalidPathValue{Match: field.Match, Value: str} + } + convert := pathValueConvertTable[field.Type.Kind()] + result, err := convert(str) + if err != nil { + return err + } + + reflectValue := reflect.ValueOf(result) + + if field.Optional { + newReflectedValue := reflect.New(reflectValue.Type()) + newReflectedValue.Elem().Set(reflect.ValueOf(result)) + reflectValue = newReflectedValue + } + v.FieldByName(field.FieldKey).Set(reflectValue) + } + + return nil +} diff --git a/magic/path_query_values.go b/magic/path_query_values.go new file mode 100644 index 0000000..49e10c7 --- /dev/null +++ b/magic/path_query_values.go @@ -0,0 +1,111 @@ +package magic + +import ( + "reflect" + "strconv" + "strings" +) + +var ( + pathValueConvertTable = map[reflect.Kind]func(string) (any, error){ + reflect.String: func(s string) (any, error) { + return s, nil + }, + reflect.Bool: func(s string) (any, error) { + switch strings.ToLower(s) { + case "0", "false", "no", "n": + return false, nil + case "1", "true", "yes", "y": + return true, nil + } + return nil, InvalidPathValueType{Kind: reflect.Bool, Value: s} + }, + + reflect.Float64: func(s string) (any, error) { + if v, err := strconv.ParseFloat(s, 64); err != nil { + return nil, err + } else { + return v, nil + } + }, + reflect.Float32: func(s string) (any, error) { + if v, err := strconv.ParseFloat(s, 32); err != nil { + return nil, err + } else { + return float32(v), nil + } + }, + + reflect.Int: func(s string) (any, error) { + if v, err := strconv.ParseInt(s, 10, 64); err != nil { + return nil, err + } else { + return int(v), nil + } + }, + reflect.Int8: func(s string) (any, error) { + if v, err := strconv.ParseInt(s, 10, 8); err != nil { + return nil, err + } else { + return int8(v), nil + } + }, + reflect.Int16: func(s string) (any, error) { + if v, err := strconv.ParseInt(s, 10, 16); err != nil { + return nil, err + } else { + return int8(v), nil + } + }, + reflect.Int32: func(s string) (any, error) { + if v, err := strconv.ParseInt(s, 10, 32); err != nil { + return nil, err + } else { + return int8(v), nil + } + }, + reflect.Int64: func(s string) (any, error) { + if v, err := strconv.ParseInt(s, 10, 64); err != nil { + return nil, err + } else { + return int8(v), nil + } + }, + + reflect.Uint: func(s string) (any, error) { + if v, err := strconv.ParseUint(s, 10, 64); err != nil { + return nil, err + } else { + return uint(v), nil + } + }, + reflect.Uint8: func(s string) (any, error) { + if v, err := strconv.ParseUint(s, 10, 8); err != nil { + return nil, err + } else { + return uint8(v), nil + } + }, + reflect.Uint16: func(s string) (any, error) { + if v, err := strconv.ParseUint(s, 10, 16); err != nil { + return nil, err + } else { + return uint16(v), nil + } + }, + reflect.Uint32: func(s string) (any, error) { + if v, err := strconv.ParseUint(s, 10, 32); err != nil { + return nil, err + } else { + return uint32(v), nil + } + }, + reflect.Uint64: func(s string) (any, error) { + if v, err := strconv.ParseUint(s, 10, 64); err != nil { + return nil, err + } else { + return v, nil + } + }, + } +) diff --git a/magic/query.go b/magic/query.go new file mode 100644 index 0000000..7eb5669 --- /dev/null +++ b/magic/query.go @@ -0,0 +1,201 @@ +package magic + +import ( + "fmt" + "net/http" + "reflect" + "regexp" + "strconv" + "strings" +) + +var ( + cachedStructQueryFields = make(map[reflect.Type][]structQueryFields) + cachedUnsupportedQueryFields = make(map[reflect.Type]structQueryFields) +) + +type Query[T any] struct { + Data T +} + +type structQueryFields struct { + PathKey string + FieldKey string + + Type reflect.Type + MinQuery, MaxQuery int + + Match *regexp.Regexp +} + +type UnsupportedQueryType struct { + inner reflect.Type +} + +func (err UnsupportedQueryType) Error() string { + return "unsupported type for query: " + err.inner.String() +} + +type QueryValueNotFitIn struct { + Count int + + field structQueryFields +} + +func (err QueryValueNotFitIn) Error() string { + return fmt.Sprintf("query %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinQuery, err.field.MaxQuery, err.Count) +} + +type InvalidQueryValue struct { + Match *regexp.Regexp + Value string +} + +func (err InvalidQueryValue) Error() string { + return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value) +} + +func findReflectQueryFields(v reflect.Value) ([]structQueryFields, error) { + if cached, ok := cachedStructQueryFields[v.Type()]; ok { + return cached, nil + } + if cached, ok := cachedUnsupportedQueryFields[v.Type()]; ok { + return nil, UnsupportedPathValueType{inner: cached.Type} + } + + var fields []structQueryFields + if v.Kind() == reflect.Struct { + t := v.Type() + for idx := 0; idx < t.NumField(); idx++ { + field := t.Field(idx) + + pathKey := field.Name + var match *regexp.Regexp + + if tags, ok := field.Tag.Lookup("query"); ok { + parts := strings.Split(strings.TrimSpace(tags), ",") + if len(parts) == 0 { + // do nothing + } else { + if parts[0] != "" { + pathKey = parts[0] + } + + for _, part := range parts[1:] { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "match=") { + part := strings.TrimPrefix(part, "match=") + if unquoted, err := strconv.Unquote(part); err != nil { + return nil, err + } else { + part = unquoted + } + exp, err := regexp.Compile(part) + if err != nil { + return nil, err + } + match = exp + } + } + } + } + + fieldType := field.Type + maxQuery, minQuery := 1, 1 + + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + minQuery = 0 + } else if fieldType.Kind() == reflect.Array { + maxQuery, minQuery = fieldType.Len(), fieldType.Len() + fieldType = fieldType.Elem() + } else if fieldType.Kind() == reflect.Slice { + fieldType = fieldType.Elem() + maxQuery, minQuery = -1, 0 + } + + if pathValueConvertTable[fieldType.Kind()] == nil { + if _, ok := fieldType.MethodByName("FromString"); !ok { + cachedUnsupportedPathValueFields[t] = field + return nil, UnsupportedPathValueType{inner: field.Type} + } + } + fields = append(fields, structQueryFields{ + PathKey: pathKey, + FieldKey: field.Name, + + Type: fieldType, + MaxQuery: maxQuery, + MinQuery: minQuery, + + Match: match, + }) + + cachedStructQueryFields[t] = fields + } + } else { + return nil, UnsupportedQueryType{inner: v.Type()} + } + + return fields, nil +} + +func (query *Query[T]) FromRequest(r *http.Request) error { + v := reflect.ValueOf(query).Elem().FieldByName("Data") + fields, err := findReflectQueryFields(v) + if err != nil { + return err + } + + q := r.URL.Query() + + for _, field := range fields { + values := q[field.PathKey] + if len(values) < field.MinQuery || (field.MaxQuery > 0 && len(values) > field.MaxQuery) { + return QueryValueNotFitIn{Count: len(values), field: field} + } + + structField := v.FieldByName(field.FieldKey) + + if structField.Kind() == reflect.Slice && structField.Len() < len(values) { + if structField.Cap() < len(values) { + structField.SetCap(len(values)) + } + structField.SetLen(len(values)) + } + + for idx, value := range values { + if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice { + if idx != len(values)-1 { + // only the last takes effect + continue + } + } + // check match regexp + if field.Match != nil && !field.Match.MatchString(value) { + return InvalidPathValue{Match: field.Match, Value: value} + } + + // convert to elem value + convert := pathValueConvertTable[field.Type.Kind()] + result, err := convert(value) + if err != nil { + return err + } + reflectValue := reflect.ValueOf(result) + + if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice { + if field.MinQuery == 0 { + newReflectValue := reflect.New(field.Type) + newReflectValue.Elem().Set(reflectValue) + reflectValue = newReflectValue + } + structField.Set(reflectValue) + } else { + structField.Index(idx).Set(reflectValue) + } + } + } + + return nil +} diff --git a/magic/request.go b/magic/request.go new file mode 100644 index 0000000..c7ff700 --- /dev/null +++ b/magic/request.go @@ -0,0 +1,15 @@ +package magic + +import ( + "context" + "net/http" +) + +func init() { + RegisterExtractorGeneric[*http.Request](func(r *http.Request) (any, error) { + return r, nil + }) + RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) { + return r.Context(), nil + }) +} diff --git a/magic/response.go b/magic/response.go new file mode 100644 index 0000000..5cff84f --- /dev/null +++ b/magic/response.go @@ -0,0 +1,41 @@ +package magic + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/websocket" +) + +type RespondWriter interface { + WriteResponse(http.ResponseWriter) +} + +type RespondWriterFunc func(http.ResponseWriter) + +func (fn RespondWriterFunc) WriteResponse(w http.ResponseWriter) { + fn(w) +} + +type ErrorResponse interface { + error + WriteResponse(http.ResponseWriter) +} + +func init() { + RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) { + return w, nil + }) + + upgrader := &websocket.Upgrader{ + Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { + w.WriteHeader(status) + json.NewEncoder(w).Encode(Map{ + "error": reason.Error(), + }) + }, + } + RegisterExtractorThatTakesResponseWriterGeneric[*websocket.Conn](func(w http.ResponseWriter, r *http.Request) (any, error) { + return upgrader.Upgrade(w, r, nil) + }) +} diff --git a/magic/state.go b/magic/state.go new file mode 100644 index 0000000..443217a --- /dev/null +++ b/magic/state.go @@ -0,0 +1,13 @@ +package magic + +import "net/http" + +type State[T any] struct { + Data T +} + +func RegisterState[T any](data T) { + RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) { + return State[T]{Data: data}, nil + }) +} diff --git a/middleware/httplog/log.go b/middleware/httplog/log.go new file mode 100644 index 0000000..fa7aa74 --- /dev/null +++ b/middleware/httplog/log.go @@ -0,0 +1,95 @@ +package httplog + +import ( + "bufio" + "context" + "errors" + "log/slog" + "net" + "net/http" + "time" + + "git.jeffthecoder.xyz/public/lazyhandler/magic" +) + +type ctxKey int + +const ( + loggerKey ctxKey = iota +) + +var ( + ErrNotHijackable = errors.New("not a hijacker") +) + +type Log struct { + LogStart bool + LogStartLevel slog.Level + + LogFinish bool + LogFinishLevel slog.Level +} + +type responseRecorder struct { + http.ResponseWriter + StatusCode int +} + +func (recorder *responseRecorder) WriteHeader(statusCode int) { + recorder.StatusCode = statusCode + + recorder.ResponseWriter.WriteHeader(statusCode) +} + +func (recorder *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := recorder.ResponseWriter.(http.Hijacker); ok { + recorder.ResponseWriter = nil + return hijacker.Hijack() + } + + return nil, nil, ErrNotHijackable +} + +var ( + _ http.Hijacker = &responseRecorder{} +) + +func (log Log) WrapHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + args := []any{ + "remote_addr", r.RemoteAddr, + "host", r.Host, + "path", r.URL.Path, + } + + startTime := time.Now() + if log.LogStart { + slog.With(append(args, "time", startTime)...).Log(r.Context(), log.LogStartLevel, "request") + } + recorder := &responseRecorder{ResponseWriter: w, StatusCode: 200} + next.ServeHTTP(recorder, r.WithContext(context.WithValue(r.Context(), loggerKey, slog.With(args...)))) + if log.LogFinish && recorder.ResponseWriter != nil { + finishTime := time.Now() + slog.With(append(args, "time", finishTime, "duration", finishTime.Sub(startTime), "status", recorder.StatusCode)...).Log(r.Context(), log.LogFinishLevel, "response") + } + }) +} + +func Logger(r *http.Request) *slog.Logger { + if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok { + return logger.With("time", time.Now()) + } + + return slog.With( + "remote_addr", r.RemoteAddr, + "host", r.Host, + "path", r.URL.Path, + "time", time.Now(), + ) +} + +func RegisterExtractor() { + magic.RegisterExtractorGeneric[*slog.Logger](func(r *http.Request) (any, error) { + return Logger(r), nil + }) +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..63ecbcd --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,13 @@ +package middleware + +import "net/http" + +type Middleware interface { + WrapHandler(next http.Handler) http.Handler +} + +type WrapFunc func(next http.Handler) http.Handler + +func (wrap WrapFunc) WrapHandler(next http.Handler) http.Handler { + return wrap(next) +} diff --git a/middleware/recover/recover.go b/middleware/recover/recover.go new file mode 100644 index 0000000..9103654 --- /dev/null +++ b/middleware/recover/recover.go @@ -0,0 +1,44 @@ +package recover + +import ( + "encoding/json" + "net/http" + + "git.jeffthecoder.xyz/public/lazyhandler/middleware" + "git.jeffthecoder.xyz/public/lazyhandler/middleware/httplog" +) + +type PanicResponseFunc func(http.ResponseWriter, *http.Request, any) + +func recoverer(w http.ResponseWriter, r *http.Request, fn PanicResponseFunc) { + if err := recover(); err != nil { + httplog.Logger(r).With("panic", err).Error("request panicked") + + if fn != nil { + fn(w, r, err) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + } +} + +func Recover(responseFunc PanicResponseFunc) middleware.Middleware { + return middleware.WrapFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer recoverer(w, r, responseFunc) + + next.ServeHTTP(w, r) + }) + }) +} + +func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) { + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]any{ + "error": err, + }) +} + +func Debug() middleware.Middleware { + return Recover(DebugPanicHandler) +} diff --git a/middleware/session/session.go b/middleware/session/session.go new file mode 100644 index 0000000..eebf320 --- /dev/null +++ b/middleware/session/session.go @@ -0,0 +1,55 @@ +package session + +import ( + "context" + "errors" + "net/http" + + "git.jeffthecoder.xyz/public/lazyhandler/magic" + "git.jeffthecoder.xyz/public/lazyhandler/middleware" + + "github.com/go-session/session/v3" +) + +type ctxKey int + +const ( + sessionKey ctxKey = iota +) + +var ( + ErrSessionStoreNotInitialized = errors.New("no session store in request context") +) + +func Session(sessionStore session.ManagerStore) middleware.Middleware { + manager := session.NewManager( + session.SetStore(sessionStore), + ) + + return middleware.WrapFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + store, err := manager.Start(r.Context(), w, r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + defer store.Save() + + next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), sessionKey, store))) + }) + }) +} + +func GetSession(r *http.Request) (session.Store, error) { + if store, ok := r.Context().Value(sessionKey).(session.Store); ok { + return store, nil + } + return nil, ErrSessionStoreNotInitialized +} + +func RegisterExtractor() { + magic.RegisterExtractorGeneric[session.Store](func(r *http.Request) (any, error) { + store, err := GetSession(r) + return store, err + }) +} diff --git a/middleware/slash/slash.go b/middleware/slash/slash.go new file mode 100644 index 0000000..0c55b8d --- /dev/null +++ b/middleware/slash/slash.go @@ -0,0 +1,18 @@ +package slash + +import ( + "net/http" + + "git.jeffthecoder.xyz/public/lazyhandler/middleware" +) + +func StripSlash() middleware.Middleware { + return middleware.WrapFunc(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path[len(r.URL.Path)-1] == '/' && len(r.URL.Path) > 1 { + r.URL.Path = r.URL.Path[:len(r.URL.Path)-1] + } + next.ServeHTTP(w, r) + }) + }) +} diff --git a/middleware/use.go b/middleware/use.go new file mode 100644 index 0000000..6abea8e --- /dev/null +++ b/middleware/use.go @@ -0,0 +1,16 @@ +package middleware + +import ( + "net/http" + "slices" +) + +func Use(handler http.Handler, middlewares ...Middleware) http.Handler { + slices.Reverse(middlewares) + + for _, middleware := range middlewares { + handler = middleware.WrapHandler(handler) + } + + return handler +} diff --git a/util/implements.go b/util/implements.go new file mode 100644 index 0000000..811f76f --- /dev/null +++ b/util/implements.go @@ -0,0 +1,45 @@ +package util + +import ( + "reflect" +) + +// Implements get reflect.Zero(T) and (o).Implements(T) +func Implements[T any](o any) (T, bool) { + iface := reflect.TypeOf((*T)(nil)).Elem() + + reflectType := reflect.TypeOf(o) + + if alreadyReflectType, ok := o.(reflect.Type); ok { + reflectType = alreadyReflectType + o = reflect.Zero(reflectType).Interface() + } else if alreadyReflectValue, ok := o.(reflect.Value); ok { + reflectType = alreadyReflectValue.Type() + o = alreadyReflectValue.Interface() + } + + if reflectType.Implements(iface) { + if reflectType.Kind() == reflect.Interface { + return *reflect.New(iface).Interface().(*T), true + } + return o.(T), true + } else { + return *reflect.New(iface).Interface().(*T), false + } +} + +// PointerImplements get reflect.New(T) and (&o).Implements(T) +func PointerImplements[I any](o any) (reflect.Value, bool) { + iface := reflect.TypeOf((*I)(nil)).Elem() + + reflectType := reflect.TypeOf(&o) + if alreadyReflectType, ok := o.(reflect.Type); ok { + reflectType = reflect.PointerTo(alreadyReflectType) + } else if alreadyReflectValue, ok := o.(reflect.Value); ok { + reflectType = reflect.PointerTo(alreadyReflectValue.Type()) + } + + reflectValue := reflect.New(reflectType.Elem()) + + return reflectValue, reflectType.Implements(iface) +}