initial commit
This commit is contained in:
commit
61ffeeb3b8
10
go.mod
Normal file
10
go.mod
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
module git.jeffthecoder.xyz/public/lazyhandler
|
||||||
|
|
||||||
|
go 1.24.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/go-session/session/v3 v3.2.1
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
)
|
||||||
|
|
||||||
|
require github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 // indirect
|
22
go.sum
Normal file
22
go.sum
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6 h1:FCLDGi1EmB7JzjVVYNZiqc/zAJj2BQ5M0lfkVOxbfs8=
|
||||||
|
github.com/bytedance/gopkg v0.0.0-20221122125632-68358b8ecec6/go.mod h1:5FoAH5xUHHCMDvQPy1rnj8moqLkLHFaDVBjHhcFwEi0=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/go-session/session/v3 v3.2.1 h1:APQf5JFW84+bhbqRjEZO8J+IppSgT1jMQTFI/XVyIFY=
|
||||||
|
github.com/go-session/session/v3 v3.2.1/go.mod h1:RftEBbyuzqkNCAxIrCLJe+rfBqB/4G11qxq9KYKrx4M=
|
||||||
|
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0=
|
||||||
|
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
|
||||||
|
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0=
|
||||||
|
github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo=
|
||||||
|
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
|
||||||
|
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
199
inject.go
Normal file
199
inject.go
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
package lazyhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/magic"
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ErrArgumentIsNotExtractable int
|
||||||
|
|
||||||
|
func (err ErrArgumentIsNotExtractable) Error() string {
|
||||||
|
return fmt.Sprintf("argument %v is not extractable", int(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrReturnValueNotConvertableIntoResponsePart int
|
||||||
|
|
||||||
|
func (err ErrReturnValueNotConvertableIntoResponsePart) Error() string {
|
||||||
|
return fmt.Sprintf("return value %v is not convertable into response part", int(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotAFunc = errors.New("not a function")
|
||||||
|
ErrDuplicateResponseWriterExtractor = errors.New("duplicate response writer extractor")
|
||||||
|
ErrResponseWriterCannotBeExtracted = errors.New("http.ResponseWriter extractor must not exists if function has return value")
|
||||||
|
)
|
||||||
|
|
||||||
|
func canConvert[T any](o any) bool {
|
||||||
|
t := reflect.TypeOf((*T)(nil)).Elem()
|
||||||
|
if reflectValue, ok := o.(reflect.Value); ok {
|
||||||
|
return reflectValue.CanConvert(t)
|
||||||
|
}
|
||||||
|
if reflectType, ok := o.(reflect.Type); ok {
|
||||||
|
return reflectType.ConvertibleTo(t)
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(o).CanConvert(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func MagicHandler(fn any) (http.Handler, error) {
|
||||||
|
funcValue := reflect.ValueOf(fn)
|
||||||
|
t := funcValue.Type()
|
||||||
|
|
||||||
|
if t.Kind() != reflect.Func {
|
||||||
|
return nil, ErrNotAFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
responseWriterExtracted := -1
|
||||||
|
|
||||||
|
extractors := make([]func(http.ResponseWriter, *http.Request) (any, error), t.NumIn())
|
||||||
|
for idx := 0; idx < t.NumIn(); idx++ {
|
||||||
|
in := t.In(idx)
|
||||||
|
|
||||||
|
extractor, isTakeResponseWriter := magic.GetExtractor(in)
|
||||||
|
if extractor == nil {
|
||||||
|
return nil, ErrArgumentIsNotExtractable(idx)
|
||||||
|
}
|
||||||
|
if isTakeResponseWriter {
|
||||||
|
if responseWriterExtracted >= 0 {
|
||||||
|
return nil, ErrDuplicateResponseWriterExtractor
|
||||||
|
}
|
||||||
|
responseWriterExtracted = idx
|
||||||
|
}
|
||||||
|
extractors[idx] = extractor
|
||||||
|
}
|
||||||
|
|
||||||
|
// http.ResponseWriter extractor must not exists if function has return value
|
||||||
|
if responseWriterExtracted >= 0 && t.NumOut() > 0 {
|
||||||
|
return nil, ErrResponseWriterCannotBeExtracted
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx := 0; idx < t.NumOut(); idx++ {
|
||||||
|
out := t.Out(idx)
|
||||||
|
|
||||||
|
// int(status) || string(body)
|
||||||
|
if out.Kind() == reflect.Int || out.Kind() == reflect.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// []byte(body)
|
||||||
|
if out.Kind() == reflect.Slice && out.Elem().Kind() == reflect.Uint8 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// [T] map[string]T(header)
|
||||||
|
if out.Kind() == reflect.Map && out.Key().Kind() == reflect.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := util.Implements[io.Reader](out); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := util.Implements[magic.RespondWriter](out); ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// last is error
|
||||||
|
if _, ok := util.Implements[error](out); ok && idx == t.NumOut()-1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrReturnValueNotConvertableIntoResponsePart(idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var closers []io.Closer
|
||||||
|
defer func() {
|
||||||
|
for _, closer := range closers {
|
||||||
|
closer.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
in := make([]reflect.Value, len(extractors))
|
||||||
|
for idx, extractor := range extractors {
|
||||||
|
v, err := extractor(w, r)
|
||||||
|
if err != nil {
|
||||||
|
if errResponse, ok := err.(magic.ErrorResponse); ok {
|
||||||
|
errResponse.WriteResponse(w)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(magic.Map{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if closer, ok := v.(io.Closer); ok {
|
||||||
|
defer closer.Close()
|
||||||
|
}
|
||||||
|
in[idx] = reflect.ValueOf(v)
|
||||||
|
}
|
||||||
|
values := funcValue.Call(in)
|
||||||
|
|
||||||
|
if numValues := len(values); numValues > 0 {
|
||||||
|
lastValue := values[numValues-1].Interface()
|
||||||
|
if err, isError := lastValue.(error); isError {
|
||||||
|
if err != nil {
|
||||||
|
if errResponse, ok := lastValue.(magic.ErrorResponse); ok {
|
||||||
|
errResponse.WriteResponse(w)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
json.NewEncoder(w).Encode(magic.Map{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
values = values[:numValues-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
obj := value.Interface()
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Int {
|
||||||
|
w.WriteHeader(int(value.Int()))
|
||||||
|
} else if value.Kind() == reflect.String {
|
||||||
|
w.Write([]byte(value.String()))
|
||||||
|
} else if canConvert[[]byte](value) {
|
||||||
|
w.Write(value.Bytes())
|
||||||
|
} else if reader, ok := util.Implements[io.Reader](value); ok {
|
||||||
|
io.Copy(w, reader)
|
||||||
|
} else if responseWriter, ok := util.Implements[magic.RespondWriter](value); ok {
|
||||||
|
if responseWriter == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
responseWriter.WriteResponse(w)
|
||||||
|
} else if headers, ok := obj.(http.Header); ok { // not
|
||||||
|
for name, values := range headers {
|
||||||
|
for idx, value := range values {
|
||||||
|
if idx == 0 {
|
||||||
|
w.Header().Set(name, value)
|
||||||
|
} else {
|
||||||
|
w.Header().Add(name, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if value.Kind() == reflect.Map {
|
||||||
|
for _, key := range value.MapKeys() {
|
||||||
|
value := value.MapIndex(key).Interface()
|
||||||
|
w.Header().Set(key.String(), fmt.Sprint(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustMagicHandler(fn any) http.Handler {
|
||||||
|
if handler, err := MagicHandler(fn); err != nil {
|
||||||
|
panic("can not be converted to http.Handler: " + fmt.Sprintf("%T", fn) + ": " + err.Error())
|
||||||
|
} else {
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
356
inject_test.go
Normal file
356
inject_test.go
Normal file
@ -0,0 +1,356 @@
|
|||||||
|
package lazyhandler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/magic"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StatusPathValues struct {
|
||||||
|
Status int `pathvalue:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HelloWorldRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
type HelloWorldResponse struct {
|
||||||
|
Greeting string `json:"greeting"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HelloWorldParams struct {
|
||||||
|
Name *string `pathvalue:"name,match=" query:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultipleParams struct {
|
||||||
|
Numbers []int
|
||||||
|
Strings []string
|
||||||
|
OptionalNumber *int
|
||||||
|
OptionalString *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotAFunc(t *testing.T) {
|
||||||
|
v, err := MagicHandler(1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expect not ok, get %v", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptyFunc(t *testing.T) {
|
||||||
|
MagicHandler(func() {
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type SomeType int
|
||||||
|
|
||||||
|
func (obj SomeType) Func() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMethod(t *testing.T) {
|
||||||
|
obj := SomeType(0)
|
||||||
|
MagicHandler(obj.Func)
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
Path string
|
||||||
|
Pattern string
|
||||||
|
|
||||||
|
CreateHandler func() http.Handler
|
||||||
|
ExpectPanicOnCreateHandler bool
|
||||||
|
|
||||||
|
NoCheck bool
|
||||||
|
MakeRequest func(string) *http.Request
|
||||||
|
CheckResponse func(*http.Response) error
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
cases = []testCase{
|
||||||
|
{
|
||||||
|
Path: "/",
|
||||||
|
Pattern: "/{$}",
|
||||||
|
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func() (int, string, error) {
|
||||||
|
return 200, "hi", nil
|
||||||
|
})
|
||||||
|
},
|
||||||
|
MakeRequest: func(base string) *http.Request {
|
||||||
|
r, err := http.NewRequest("GET", base+"/", nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/error-checked-before-other-return-values",
|
||||||
|
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func() (int, string, error) {
|
||||||
|
// it checks errors first
|
||||||
|
return 200, "hi", fmt.Errorf("it is handled before other return values")
|
||||||
|
})
|
||||||
|
},
|
||||||
|
MakeRequest: func(base string) *http.Request {
|
||||||
|
r, err := http.NewRequest("GET", base+"/return-error", nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/return-nil",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func() (magic.RespondWriter, error) {
|
||||||
|
// no response body, 200
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/no-return",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func() {
|
||||||
|
// return nothing at all
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/no-error",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func() int {
|
||||||
|
// well, error is not needed
|
||||||
|
return http.StatusNoContent
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/respond-json",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func() magic.Json[map[string]any] {
|
||||||
|
return magic.Json[map[string]any]{
|
||||||
|
Data: map[string]any{
|
||||||
|
"message": "hello, world",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/extract-json",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(body magic.Json[HelloWorldRequest]) magic.Json[HelloWorldResponse] {
|
||||||
|
return magic.Json[HelloWorldResponse]{
|
||||||
|
Data: HelloWorldResponse{
|
||||||
|
Greeting: "hello, " + body.Data.Name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/extract-optional",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
// the parameter accepts a pointer to extractor, which indicates that it is optional
|
||||||
|
return MustMagicHandler(func(body *magic.Json[HelloWorldRequest]) magic.Json[HelloWorldResponse] {
|
||||||
|
name := "world"
|
||||||
|
if body != nil {
|
||||||
|
name = body.Data.Name
|
||||||
|
}
|
||||||
|
return magic.Json[HelloWorldResponse]{
|
||||||
|
Data: HelloWorldResponse{
|
||||||
|
Greeting: "hello, " + name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/inject-request-respond-status-from-http-request/418", // yeah...i'm a teapot, lol
|
||||||
|
Pattern: "/inject-request-respond-status-from-http-request/{status}",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(r *http.Request) (int, error) {
|
||||||
|
// inject *http.Request, respond with status only
|
||||||
|
return strconv.Atoi(r.PathValue("status"))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/inject-path-value-respond-status/418", // yeah...i'm a teapot, lol
|
||||||
|
Pattern: "/inject-path-value-respond-status/{status}",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(pathValues magic.PathValue[StatusPathValues]) int {
|
||||||
|
return pathValues.Data.Status
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/inject-optional-path-value/rose", // name extracted
|
||||||
|
Pattern: "/inject-optional-path-value/{name}",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(pathValues magic.PathValue[HelloWorldParams]) magic.Json[HelloWorldResponse] {
|
||||||
|
name := "world"
|
||||||
|
if pathValues.Data.Name != nil {
|
||||||
|
name = *pathValues.Data.Name
|
||||||
|
}
|
||||||
|
return magic.Json[HelloWorldResponse]{
|
||||||
|
Data: HelloWorldResponse{
|
||||||
|
Greeting: "hello, " + name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/inject-optional-path-value", // name fallback to world
|
||||||
|
Pattern: "/inject-optional-path-value",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(pathValues magic.PathValue[HelloWorldParams]) magic.Json[HelloWorldResponse] {
|
||||||
|
name := "world"
|
||||||
|
if pathValues.Data.Name != nil {
|
||||||
|
name = *pathValues.Data.Name
|
||||||
|
}
|
||||||
|
return magic.Json[HelloWorldResponse]{
|
||||||
|
Data: HelloWorldResponse{
|
||||||
|
Greeting: "hello, " + name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/direct-response-writer-extractor",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(w http.ResponseWriter) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/hahaha-it-act-just-like-http-handlerfunc",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/websocket",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(ws *websocket.Conn) {
|
||||||
|
|
||||||
|
})
|
||||||
|
},
|
||||||
|
NoCheck: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Pattern: "/multiple-respond-extractor-should-panic-on-create",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(conn *websocket.Conn) magic.Json[HelloWorldResponse] {
|
||||||
|
return magic.Json[HelloWorldResponse]{}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
ExpectPanicOnCreateHandler: true,
|
||||||
|
},
|
||||||
|
// {
|
||||||
|
// Pattern: "/hello-query",
|
||||||
|
// CreateHandler: func() http.Handler {
|
||||||
|
// return MustMagicHandler(func(query magic.Query[HelloWorldParams]) magic.Json[HelloWorldResponse] {
|
||||||
|
// name := "world"
|
||||||
|
// if query.Data.Name != nil {
|
||||||
|
// name = *query.Data.Name
|
||||||
|
// }
|
||||||
|
// return magic.Json[HelloWorldResponse]{
|
||||||
|
// Data: HelloWorldResponse{
|
||||||
|
// Greeting: "hello, " + name,
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
{
|
||||||
|
Pattern: "/test-query",
|
||||||
|
CreateHandler: func() http.Handler {
|
||||||
|
return MustMagicHandler(func(query magic.Query[MultipleParams]) {
|
||||||
|
log.Println(query.Data)
|
||||||
|
})
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestActualFunctions(t *testing.T) {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
t.Run("register-routes", func(t *testing.T) {
|
||||||
|
for _, c := range cases {
|
||||||
|
path := c.Path
|
||||||
|
if path == "" {
|
||||||
|
path = c.Pattern
|
||||||
|
}
|
||||||
|
t.Run(strings.TrimPrefix(path, "/"), func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if o := recover(); (o != nil) != c.ExpectPanicOnCreateHandler {
|
||||||
|
if c.ExpectPanicOnCreateHandler {
|
||||||
|
t.Fatal("expect panic but panic did not occur")
|
||||||
|
} else {
|
||||||
|
t.Fatalf("panic not expected: %v", o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if handler := c.CreateHandler(); handler != nil {
|
||||||
|
mux.Handle(c.Pattern, handler)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("check-actual-result", func(t *testing.T) {
|
||||||
|
server := httptest.NewTLSServer(mux)
|
||||||
|
client := server.Client()
|
||||||
|
for _, c := range cases {
|
||||||
|
if c.ExpectPanicOnCreateHandler || c.NoCheck {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
path := c.Path
|
||||||
|
if path == "" {
|
||||||
|
path = c.Pattern
|
||||||
|
}
|
||||||
|
t.Run(strings.TrimPrefix(path, "/"), func(t *testing.T) {
|
||||||
|
var req *http.Request
|
||||||
|
|
||||||
|
if c.MakeRequest != nil {
|
||||||
|
req = c.MakeRequest(server.URL)
|
||||||
|
} else {
|
||||||
|
request, err := http.NewRequest("GET", server.URL+path, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
req = request
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
log.Println(resp)
|
||||||
|
if bytes, err := io.ReadAll(resp.Body); err == nil {
|
||||||
|
log.Println(string(bytes))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
169
magic/extractors.go
Normal file
169
magic/extractors.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
extractors = make(map[reflect.Type]func(*http.Request) (any, error))
|
||||||
|
|
||||||
|
extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error))
|
||||||
|
)
|
||||||
|
|
||||||
|
type FromRequest interface {
|
||||||
|
FromRequest(*http.Request) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TakeResponseWriter interface {
|
||||||
|
TakeResponseWriter(http.ResponseWriter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterExtractor(o any, extractor func(*http.Request) (any, error)) {
|
||||||
|
if t, ok := o.(reflect.Type); ok {
|
||||||
|
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
|
||||||
|
extractors[t] = extractor
|
||||||
|
} else if v, ok := o.(reflect.Value); ok {
|
||||||
|
slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value")
|
||||||
|
|
||||||
|
extractors[v.Type()] = extractor
|
||||||
|
} else {
|
||||||
|
t := reflect.TypeOf(o)
|
||||||
|
slog.With("type", t.String(), "value", o).Debug("extractor type registered with object")
|
||||||
|
|
||||||
|
extractors[t] = extractor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error)) {
|
||||||
|
var pointerToT *T
|
||||||
|
RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) {
|
||||||
|
if t, ok := o.(reflect.Type); ok {
|
||||||
|
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
|
||||||
|
extractorsTakesResponseWriter[t] = extractor
|
||||||
|
} else if v, ok := o.(reflect.Value); ok {
|
||||||
|
slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value")
|
||||||
|
|
||||||
|
extractorsTakesResponseWriter[v.Type()] = extractor
|
||||||
|
} else {
|
||||||
|
t := reflect.TypeOf(o)
|
||||||
|
slog.With("type", t.String(), "value", o).Debug("extractor type registered with object")
|
||||||
|
|
||||||
|
extractorsTakesResponseWriter[t] = extractor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterExtractorThatTakesResponseWriterGeneric[T any](extractor func(http.ResponseWriter, *http.Request) (any, error)) {
|
||||||
|
var pointerToT *T
|
||||||
|
RegisterExtractorThatTakesResponseWriter(reflect.TypeOf(pointerToT).Elem(), extractor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any, error), bool) {
|
||||||
|
_, isTakeResponseWriter := util.Implements[TakeResponseWriter](t)
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
_, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem())
|
||||||
|
isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter
|
||||||
|
}
|
||||||
|
if _, ok := util.Implements[FromRequest](t); ok {
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
// T.Implement(FromRequest) and T is Pointer
|
||||||
|
// create new actual T and call T.FromRequest on it
|
||||||
|
// if error happens, no error is returned. just return nil
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
// var t T
|
||||||
|
// return t, t.FromRequest(r)
|
||||||
|
z := reflect.New(t.Elem())
|
||||||
|
|
||||||
|
if isTakeResponseWriter {
|
||||||
|
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||||
|
}
|
||||||
|
|
||||||
|
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||||
|
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||||
|
if errResponse, ok := err.(ErrorResponse); ok {
|
||||||
|
return reflect.Zero(t).Interface(), errResponse
|
||||||
|
}
|
||||||
|
return reflect.Zero(t).Interface(), nil
|
||||||
|
}
|
||||||
|
return z.Interface(), nil
|
||||||
|
}, isTakeResponseWriter
|
||||||
|
}
|
||||||
|
// T.Implement(FromRequest) and T is not Pointer
|
||||||
|
// create zero T and call T.FromRequest directly on it
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
// var t T
|
||||||
|
// return t, t.FromRequest(r)
|
||||||
|
z := reflect.Zero(t)
|
||||||
|
|
||||||
|
if isTakeResponseWriter {
|
||||||
|
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||||
|
}
|
||||||
|
|
||||||
|
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||||
|
if err, ok := results[0].Interface().(error); ok {
|
||||||
|
return z.Interface(), err
|
||||||
|
}
|
||||||
|
return z.Interface(), nil
|
||||||
|
}, isTakeResponseWriter
|
||||||
|
} else if v, ok := util.PointerImplements[FromRequest](t); ok {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
// var t T
|
||||||
|
// return t, t.FromRequest(r)
|
||||||
|
z := reflect.New(v.Type().Elem())
|
||||||
|
|
||||||
|
if isTakeResponseWriter {
|
||||||
|
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||||
|
}
|
||||||
|
|
||||||
|
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||||
|
if err := results[0].Interface(); err == nil {
|
||||||
|
return z.Elem().Interface(), nil
|
||||||
|
}
|
||||||
|
return z.Elem().Interface(), results[0].Interface().(error)
|
||||||
|
}, isTakeResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
if extractor, ok := extractors[t]; ok {
|
||||||
|
return func(_ http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
return extractor(r)
|
||||||
|
}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
if extractor, ok := extractors[t.Elem()]; ok {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
v, err := extractor(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &v, nil
|
||||||
|
}, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if extractor, ok := extractorsTakesResponseWriter[t]; ok {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
return extractor(w, r)
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Kind() == reflect.Pointer {
|
||||||
|
if extractor, ok := extractorsTakesResponseWriter[t.Elem()]; ok {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
v, err := extractor(w, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &v, nil
|
||||||
|
}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, isTakeResponseWriter
|
||||||
|
}
|
63
magic/json.go
Normal file
63
magic/json.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JsonDecodeError struct {
|
||||||
|
inner error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err JsonDecodeError) Error() string {
|
||||||
|
return "failed to decode json: " + err.inner.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) {
|
||||||
|
rw.WriteHeader(http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Json[T any] struct {
|
||||||
|
Data T
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewJson[T any](data T) Json[T] {
|
||||||
|
return Json[T]{
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (data *Json[T]) FromRequest(request *http.Request) error {
|
||||||
|
bodyReader := request.Body
|
||||||
|
defer bodyReader.Close()
|
||||||
|
|
||||||
|
request.Body = http.NoBody
|
||||||
|
|
||||||
|
if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil {
|
||||||
|
return JsonDecodeError{inner: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type counter struct {
|
||||||
|
counter int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *counter) Write(b []byte) (int, error) {
|
||||||
|
n := len(b)
|
||||||
|
c.counter += int64(n)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (container Json[T]) Write(w io.Writer) (int64, error) {
|
||||||
|
counter := &counter{}
|
||||||
|
err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data)
|
||||||
|
return counter.counter, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (json Json[T]) WriteResponse(w http.ResponseWriter) {
|
||||||
|
json.Write(w)
|
||||||
|
}
|
12
magic/map.go
Normal file
12
magic/map.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Map map[string]any
|
||||||
|
|
||||||
|
func (m Map) RespondWriter(w http.ResponseWriter) {
|
||||||
|
json.NewEncoder(w).Encode(m)
|
||||||
|
}
|
178
magic/path.go
Normal file
178
magic/path.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cachedStructPathValueFields = make(map[reflect.Type][]structPathValueFields)
|
||||||
|
cachedUnsupportedPathValueFields = make(map[reflect.Type]reflect.StructField)
|
||||||
|
)
|
||||||
|
|
||||||
|
type UnsupportedPathValueType struct {
|
||||||
|
inner reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err UnsupportedPathValueType) Error() string {
|
||||||
|
return "unsupported type for path value: " + err.inner.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type structPathValueFields struct {
|
||||||
|
PathKey string
|
||||||
|
FieldKey string
|
||||||
|
|
||||||
|
Type reflect.Type
|
||||||
|
Optional bool
|
||||||
|
Match *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
type PathValue[T any] struct {
|
||||||
|
Data T
|
||||||
|
}
|
||||||
|
|
||||||
|
type PathValueNotFound struct {
|
||||||
|
Key string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err PathValueNotFound) Error() string {
|
||||||
|
return "path value not found: " + err.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
type InvalidPathValueType struct {
|
||||||
|
Kind reflect.Kind
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err InvalidPathValueType) Error() string {
|
||||||
|
return fmt.Sprintf("invalid value for kind %v: %v", err.Kind, err.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
type InvalidPathValue struct {
|
||||||
|
Match *regexp.Regexp
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err InvalidPathValue) Error() string {
|
||||||
|
return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findReflectPathValueFields(v reflect.Value) ([]structPathValueFields, error) {
|
||||||
|
if cached, ok := cachedStructPathValueFields[v.Type()]; ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
if cached, ok := cachedUnsupportedPathValueFields[v.Type()]; ok {
|
||||||
|
return nil, UnsupportedPathValueType{inner: cached.Type}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fields []structPathValueFields
|
||||||
|
if v.Kind() == reflect.Struct {
|
||||||
|
t := v.Type()
|
||||||
|
for idx := 0; idx < t.NumField(); idx++ {
|
||||||
|
field := t.Field(idx)
|
||||||
|
|
||||||
|
pathKey := field.Name
|
||||||
|
var match *regexp.Regexp
|
||||||
|
|
||||||
|
if tags, ok := field.Tag.Lookup("pathvalue"); ok {
|
||||||
|
parts := strings.Split(strings.TrimSpace(tags), ",")
|
||||||
|
if len(parts) == 0 {
|
||||||
|
// do nothing
|
||||||
|
} else {
|
||||||
|
if parts[0] != "" {
|
||||||
|
pathKey = parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range parts[1:] {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "match=") {
|
||||||
|
part := strings.TrimPrefix(part, "match=")
|
||||||
|
if unquoted, err := strconv.Unquote(part); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
part = unquoted
|
||||||
|
}
|
||||||
|
exp, err := regexp.Compile(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
match = exp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldType := field.Type
|
||||||
|
optional := false
|
||||||
|
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
optional = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if pathValueConvertTable[fieldType.Kind()] == nil {
|
||||||
|
if _, ok := fieldType.MethodByName("FromString"); !ok {
|
||||||
|
cachedUnsupportedPathValueFields[t] = field
|
||||||
|
return nil, UnsupportedPathValueType{inner: field.Type}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fields = append(fields, structPathValueFields{
|
||||||
|
PathKey: pathKey,
|
||||||
|
FieldKey: field.Name,
|
||||||
|
|
||||||
|
Type: fieldType,
|
||||||
|
Optional: optional,
|
||||||
|
|
||||||
|
Match: match,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
cachedStructPathValueFields[t] = fields
|
||||||
|
} else {
|
||||||
|
return nil, UnsupportedPathValueType{inner: v.Type()}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pathValue *PathValue[T]) FromRequest(r *http.Request) error {
|
||||||
|
v := reflect.ValueOf(pathValue).Elem().FieldByName("Data")
|
||||||
|
fields, err := findReflectPathValueFields(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range fields {
|
||||||
|
str := r.PathValue(field.PathKey)
|
||||||
|
if str == "" {
|
||||||
|
if !field.Optional {
|
||||||
|
return PathValueNotFound{Key: field.PathKey}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if field.Match != nil && !field.Match.MatchString(str) {
|
||||||
|
return InvalidPathValue{Match: field.Match, Value: str}
|
||||||
|
}
|
||||||
|
convert := pathValueConvertTable[field.Type.Kind()]
|
||||||
|
result, err := convert(str)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.ValueOf(result)
|
||||||
|
|
||||||
|
if field.Optional {
|
||||||
|
newReflectedValue := reflect.New(reflectValue.Type())
|
||||||
|
newReflectedValue.Elem().Set(reflect.ValueOf(result))
|
||||||
|
reflectValue = newReflectedValue
|
||||||
|
}
|
||||||
|
v.FieldByName(field.FieldKey).Set(reflectValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
111
magic/path_query_values.go
Normal file
111
magic/path_query_values.go
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
pathValueConvertTable = map[reflect.Kind]func(string) (any, error){
|
||||||
|
reflect.String: func(s string) (any, error) {
|
||||||
|
return s, nil
|
||||||
|
},
|
||||||
|
reflect.Bool: func(s string) (any, error) {
|
||||||
|
switch strings.ToLower(s) {
|
||||||
|
case "0", "false", "no", "n":
|
||||||
|
return false, nil
|
||||||
|
case "1", "true", "yes", "y":
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return nil, InvalidPathValueType{Kind: reflect.Bool, Value: s}
|
||||||
|
},
|
||||||
|
|
||||||
|
reflect.Float64: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseFloat(s, 64); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Float32: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseFloat(s, 32); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return float32(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
reflect.Int: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return int(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Int8: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseInt(s, 10, 8); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return int8(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Int16: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseInt(s, 10, 16); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return int8(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Int32: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseInt(s, 10, 32); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return int8(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Int64: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return int8(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
reflect.Uint: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseUint(s, 10, 64); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return uint(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Uint8: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseUint(s, 10, 8); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return uint8(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Uint16: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseUint(s, 10, 16); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return uint16(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Uint32: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseUint(s, 10, 32); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return uint32(v), nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
reflect.Uint64: func(s string) (any, error) {
|
||||||
|
if v, err := strconv.ParseUint(s, 10, 64); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
201
magic/query.go
Normal file
201
magic/query.go
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cachedStructQueryFields = make(map[reflect.Type][]structQueryFields)
|
||||||
|
cachedUnsupportedQueryFields = make(map[reflect.Type]structQueryFields)
|
||||||
|
)
|
||||||
|
|
||||||
|
type Query[T any] struct {
|
||||||
|
Data T
|
||||||
|
}
|
||||||
|
|
||||||
|
type structQueryFields struct {
|
||||||
|
PathKey string
|
||||||
|
FieldKey string
|
||||||
|
|
||||||
|
Type reflect.Type
|
||||||
|
MinQuery, MaxQuery int
|
||||||
|
|
||||||
|
Match *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
type UnsupportedQueryType struct {
|
||||||
|
inner reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err UnsupportedQueryType) Error() string {
|
||||||
|
return "unsupported type for query: " + err.inner.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryValueNotFitIn struct {
|
||||||
|
Count int
|
||||||
|
|
||||||
|
field structQueryFields
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err QueryValueNotFitIn) Error() string {
|
||||||
|
return fmt.Sprintf("query %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinQuery, err.field.MaxQuery, err.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
type InvalidQueryValue struct {
|
||||||
|
Match *regexp.Regexp
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err InvalidQueryValue) Error() string {
|
||||||
|
return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findReflectQueryFields(v reflect.Value) ([]structQueryFields, error) {
|
||||||
|
if cached, ok := cachedStructQueryFields[v.Type()]; ok {
|
||||||
|
return cached, nil
|
||||||
|
}
|
||||||
|
if cached, ok := cachedUnsupportedQueryFields[v.Type()]; ok {
|
||||||
|
return nil, UnsupportedPathValueType{inner: cached.Type}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fields []structQueryFields
|
||||||
|
if v.Kind() == reflect.Struct {
|
||||||
|
t := v.Type()
|
||||||
|
for idx := 0; idx < t.NumField(); idx++ {
|
||||||
|
field := t.Field(idx)
|
||||||
|
|
||||||
|
pathKey := field.Name
|
||||||
|
var match *regexp.Regexp
|
||||||
|
|
||||||
|
if tags, ok := field.Tag.Lookup("query"); ok {
|
||||||
|
parts := strings.Split(strings.TrimSpace(tags), ",")
|
||||||
|
if len(parts) == 0 {
|
||||||
|
// do nothing
|
||||||
|
} else {
|
||||||
|
if parts[0] != "" {
|
||||||
|
pathKey = parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, part := range parts[1:] {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "match=") {
|
||||||
|
part := strings.TrimPrefix(part, "match=")
|
||||||
|
if unquoted, err := strconv.Unquote(part); err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
part = unquoted
|
||||||
|
}
|
||||||
|
exp, err := regexp.Compile(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
match = exp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldType := field.Type
|
||||||
|
maxQuery, minQuery := 1, 1
|
||||||
|
|
||||||
|
if fieldType.Kind() == reflect.Ptr {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
minQuery = 0
|
||||||
|
} else if fieldType.Kind() == reflect.Array {
|
||||||
|
maxQuery, minQuery = fieldType.Len(), fieldType.Len()
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
} else if fieldType.Kind() == reflect.Slice {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
maxQuery, minQuery = -1, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if pathValueConvertTable[fieldType.Kind()] == nil {
|
||||||
|
if _, ok := fieldType.MethodByName("FromString"); !ok {
|
||||||
|
cachedUnsupportedPathValueFields[t] = field
|
||||||
|
return nil, UnsupportedPathValueType{inner: field.Type}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fields = append(fields, structQueryFields{
|
||||||
|
PathKey: pathKey,
|
||||||
|
FieldKey: field.Name,
|
||||||
|
|
||||||
|
Type: fieldType,
|
||||||
|
MaxQuery: maxQuery,
|
||||||
|
MinQuery: minQuery,
|
||||||
|
|
||||||
|
Match: match,
|
||||||
|
})
|
||||||
|
|
||||||
|
cachedStructQueryFields[t] = fields
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, UnsupportedQueryType{inner: v.Type()}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (query *Query[T]) FromRequest(r *http.Request) error {
|
||||||
|
v := reflect.ValueOf(query).Elem().FieldByName("Data")
|
||||||
|
fields, err := findReflectQueryFields(v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
for _, field := range fields {
|
||||||
|
values := q[field.PathKey]
|
||||||
|
if len(values) < field.MinQuery || (field.MaxQuery > 0 && len(values) > field.MaxQuery) {
|
||||||
|
return QueryValueNotFitIn{Count: len(values), field: field}
|
||||||
|
}
|
||||||
|
|
||||||
|
structField := v.FieldByName(field.FieldKey)
|
||||||
|
|
||||||
|
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
|
||||||
|
if structField.Cap() < len(values) {
|
||||||
|
structField.SetCap(len(values))
|
||||||
|
}
|
||||||
|
structField.SetLen(len(values))
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, value := range values {
|
||||||
|
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
|
||||||
|
if idx != len(values)-1 {
|
||||||
|
// only the last takes effect
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// check match regexp
|
||||||
|
if field.Match != nil && !field.Match.MatchString(value) {
|
||||||
|
return InvalidPathValue{Match: field.Match, Value: value}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert to elem value
|
||||||
|
convert := pathValueConvertTable[field.Type.Kind()]
|
||||||
|
result, err := convert(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reflectValue := reflect.ValueOf(result)
|
||||||
|
|
||||||
|
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
|
||||||
|
if field.MinQuery == 0 {
|
||||||
|
newReflectValue := reflect.New(field.Type)
|
||||||
|
newReflectValue.Elem().Set(reflectValue)
|
||||||
|
reflectValue = newReflectValue
|
||||||
|
}
|
||||||
|
structField.Set(reflectValue)
|
||||||
|
} else {
|
||||||
|
structField.Index(idx).Set(reflectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
15
magic/request.go
Normal file
15
magic/request.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterExtractorGeneric[*http.Request](func(r *http.Request) (any, error) {
|
||||||
|
return r, nil
|
||||||
|
})
|
||||||
|
RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) {
|
||||||
|
return r.Context(), nil
|
||||||
|
})
|
||||||
|
}
|
41
magic/response.go
Normal file
41
magic/response.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RespondWriter interface {
|
||||||
|
WriteResponse(http.ResponseWriter)
|
||||||
|
}
|
||||||
|
|
||||||
|
type RespondWriterFunc func(http.ResponseWriter)
|
||||||
|
|
||||||
|
func (fn RespondWriterFunc) WriteResponse(w http.ResponseWriter) {
|
||||||
|
fn(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorResponse interface {
|
||||||
|
error
|
||||||
|
WriteResponse(http.ResponseWriter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
return w, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
upgrader := &websocket.Upgrader{
|
||||||
|
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
|
||||||
|
w.WriteHeader(status)
|
||||||
|
json.NewEncoder(w).Encode(Map{
|
||||||
|
"error": reason.Error(),
|
||||||
|
})
|
||||||
|
},
|
||||||
|
}
|
||||||
|
RegisterExtractorThatTakesResponseWriterGeneric[*websocket.Conn](func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||||
|
return upgrader.Upgrade(w, r, nil)
|
||||||
|
})
|
||||||
|
}
|
13
magic/state.go
Normal file
13
magic/state.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package magic
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type State[T any] struct {
|
||||||
|
Data T
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterState[T any](data T) {
|
||||||
|
RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) {
|
||||||
|
return State[T]{Data: data}, nil
|
||||||
|
})
|
||||||
|
}
|
95
middleware/httplog/log.go
Normal file
95
middleware/httplog/log.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package httplog
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/magic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ctxKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
loggerKey ctxKey = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotHijackable = errors.New("not a hijacker")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Log struct {
|
||||||
|
LogStart bool
|
||||||
|
LogStartLevel slog.Level
|
||||||
|
|
||||||
|
LogFinish bool
|
||||||
|
LogFinishLevel slog.Level
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseRecorder struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
StatusCode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (recorder *responseRecorder) WriteHeader(statusCode int) {
|
||||||
|
recorder.StatusCode = statusCode
|
||||||
|
|
||||||
|
recorder.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (recorder *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := recorder.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
recorder.ResponseWriter = nil
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, ErrNotHijackable
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ http.Hijacker = &responseRecorder{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (log Log) WrapHandler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
args := []any{
|
||||||
|
"remote_addr", r.RemoteAddr,
|
||||||
|
"host", r.Host,
|
||||||
|
"path", r.URL.Path,
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
if log.LogStart {
|
||||||
|
slog.With(append(args, "time", startTime)...).Log(r.Context(), log.LogStartLevel, "request")
|
||||||
|
}
|
||||||
|
recorder := &responseRecorder{ResponseWriter: w, StatusCode: 200}
|
||||||
|
next.ServeHTTP(recorder, r.WithContext(context.WithValue(r.Context(), loggerKey, slog.With(args...))))
|
||||||
|
if log.LogFinish && recorder.ResponseWriter != nil {
|
||||||
|
finishTime := time.Now()
|
||||||
|
slog.With(append(args, "time", finishTime, "duration", finishTime.Sub(startTime), "status", recorder.StatusCode)...).Log(r.Context(), log.LogFinishLevel, "response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Logger(r *http.Request) *slog.Logger {
|
||||||
|
if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok {
|
||||||
|
return logger.With("time", time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
return slog.With(
|
||||||
|
"remote_addr", r.RemoteAddr,
|
||||||
|
"host", r.Host,
|
||||||
|
"path", r.URL.Path,
|
||||||
|
"time", time.Now(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterExtractor() {
|
||||||
|
magic.RegisterExtractorGeneric[*slog.Logger](func(r *http.Request) (any, error) {
|
||||||
|
return Logger(r), nil
|
||||||
|
})
|
||||||
|
}
|
13
middleware/middleware.go
Normal file
13
middleware/middleware.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type Middleware interface {
|
||||||
|
WrapHandler(next http.Handler) http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
type WrapFunc func(next http.Handler) http.Handler
|
||||||
|
|
||||||
|
func (wrap WrapFunc) WrapHandler(next http.Handler) http.Handler {
|
||||||
|
return wrap(next)
|
||||||
|
}
|
44
middleware/recover/recover.go
Normal file
44
middleware/recover/recover.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package recover
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/middleware/httplog"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PanicResponseFunc func(http.ResponseWriter, *http.Request, any)
|
||||||
|
|
||||||
|
func recoverer(w http.ResponseWriter, r *http.Request, fn PanicResponseFunc) {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
httplog.Logger(r).With("panic", err).Error("request panicked")
|
||||||
|
|
||||||
|
if fn != nil {
|
||||||
|
fn(w, r, err)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Recover(responseFunc PanicResponseFunc) middleware.Middleware {
|
||||||
|
return middleware.WrapFunc(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer recoverer(w, r, responseFunc)
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"error": err,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Debug() middleware.Middleware {
|
||||||
|
return Recover(DebugPanicHandler)
|
||||||
|
}
|
55
middleware/session/session.go
Normal file
55
middleware/session/session.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/magic"
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
|
||||||
|
|
||||||
|
"github.com/go-session/session/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ctxKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionKey ctxKey = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSessionStoreNotInitialized = errors.New("no session store in request context")
|
||||||
|
)
|
||||||
|
|
||||||
|
func Session(sessionStore session.ManagerStore) middleware.Middleware {
|
||||||
|
manager := session.NewManager(
|
||||||
|
session.SetStore(sessionStore),
|
||||||
|
)
|
||||||
|
|
||||||
|
return middleware.WrapFunc(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
store, err := manager.Start(r.Context(), w, r)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer store.Save()
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), sessionKey, store)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSession(r *http.Request) (session.Store, error) {
|
||||||
|
if store, ok := r.Context().Value(sessionKey).(session.Store); ok {
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
return nil, ErrSessionStoreNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterExtractor() {
|
||||||
|
magic.RegisterExtractorGeneric[session.Store](func(r *http.Request) (any, error) {
|
||||||
|
store, err := GetSession(r)
|
||||||
|
return store, err
|
||||||
|
})
|
||||||
|
}
|
18
middleware/slash/slash.go
Normal file
18
middleware/slash/slash.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
package slash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StripSlash() middleware.Middleware {
|
||||||
|
return middleware.WrapFunc(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path[len(r.URL.Path)-1] == '/' && len(r.URL.Path) > 1 {
|
||||||
|
r.URL.Path = r.URL.Path[:len(r.URL.Path)-1]
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
16
middleware/use.go
Normal file
16
middleware/use.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Use(handler http.Handler, middlewares ...Middleware) http.Handler {
|
||||||
|
slices.Reverse(middlewares)
|
||||||
|
|
||||||
|
for _, middleware := range middlewares {
|
||||||
|
handler = middleware.WrapHandler(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler
|
||||||
|
}
|
45
util/implements.go
Normal file
45
util/implements.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Implements get reflect.Zero(T) and (o).Implements(T)
|
||||||
|
func Implements[T any](o any) (T, bool) {
|
||||||
|
iface := reflect.TypeOf((*T)(nil)).Elem()
|
||||||
|
|
||||||
|
reflectType := reflect.TypeOf(o)
|
||||||
|
|
||||||
|
if alreadyReflectType, ok := o.(reflect.Type); ok {
|
||||||
|
reflectType = alreadyReflectType
|
||||||
|
o = reflect.Zero(reflectType).Interface()
|
||||||
|
} else if alreadyReflectValue, ok := o.(reflect.Value); ok {
|
||||||
|
reflectType = alreadyReflectValue.Type()
|
||||||
|
o = alreadyReflectValue.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflectType.Implements(iface) {
|
||||||
|
if reflectType.Kind() == reflect.Interface {
|
||||||
|
return *reflect.New(iface).Interface().(*T), true
|
||||||
|
}
|
||||||
|
return o.(T), true
|
||||||
|
} else {
|
||||||
|
return *reflect.New(iface).Interface().(*T), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PointerImplements get reflect.New(T) and (&o).Implements(T)
|
||||||
|
func PointerImplements[I any](o any) (reflect.Value, bool) {
|
||||||
|
iface := reflect.TypeOf((*I)(nil)).Elem()
|
||||||
|
|
||||||
|
reflectType := reflect.TypeOf(&o)
|
||||||
|
if alreadyReflectType, ok := o.(reflect.Type); ok {
|
||||||
|
reflectType = reflect.PointerTo(alreadyReflectType)
|
||||||
|
} else if alreadyReflectValue, ok := o.(reflect.Value); ok {
|
||||||
|
reflectType = reflect.PointerTo(alreadyReflectValue.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectValue := reflect.New(reflectType.Elem())
|
||||||
|
|
||||||
|
return reflectValue, reflectType.Implements(iface)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user