sync from project
This commit is contained in:
37
magic/autodecode.go
Normal file
37
magic/autodecode.go
Normal file
@ -0,0 +1,37 @@
|
||||
package magic
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ErrUknownContentType string
|
||||
|
||||
func (err ErrUknownContentType) Error() string {
|
||||
return "unknown content-type: " + string(err)
|
||||
}
|
||||
|
||||
type AutoDecode[T any] struct {
|
||||
Data T
|
||||
}
|
||||
|
||||
func (d *AutoDecode[T]) FromRequest(r *http.Request) error {
|
||||
contentType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
|
||||
switch contentType {
|
||||
case "application/json":
|
||||
decoder := &Json[T]{}
|
||||
if err := decoder.FromRequest(r); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Data = decoder.Data
|
||||
case "multipart/form-data", "application/x-www-form-urlencoded":
|
||||
decoder := &Form[T]{}
|
||||
if err := decoder.FromRequest(r); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Data = decoder.Data
|
||||
}
|
||||
return ErrUknownContentType(contentType)
|
||||
}
|
||||
|
||||
func (d AutoDecode[T]) TakeRequestBody() {}
|
@ -4,6 +4,8 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -54,21 +56,21 @@ var (
|
||||
if v, err := strconv.ParseInt(s, 10, 16); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return int8(v), nil
|
||||
return int16(v), nil
|
||||
}
|
||||
},
|
||||
reflect.Int32: func(s string) (any, error) {
|
||||
if v, err := strconv.ParseInt(s, 10, 32); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return int8(v), nil
|
||||
return int32(v), nil
|
||||
}
|
||||
},
|
||||
reflect.Int64: func(s string) (any, error) {
|
||||
if v, err := strconv.ParseInt(s, 10, 64); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return int8(v), nil
|
||||
return v, nil
|
||||
}
|
||||
},
|
||||
|
||||
@ -108,4 +110,24 @@ var (
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
defaultValidator = NewValidator("")
|
||||
)
|
||||
|
||||
func NewValidator(nameFromTag string) *validator.Validate {
|
||||
validate := validator.New(
|
||||
validator.WithRequiredStructEnabled(),
|
||||
)
|
||||
|
||||
if nameFromTag != "" {
|
||||
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
|
||||
name := strings.SplitN(fld.Tag.Get(nameFromTag), ",", 2)[0]
|
||||
if name == "-" {
|
||||
return ""
|
||||
}
|
||||
return name
|
||||
})
|
||||
}
|
||||
|
||||
return validate
|
||||
}
|
@ -11,6 +11,8 @@ import (
|
||||
var (
|
||||
extractors = make(map[reflect.Type]func(*http.Request) (any, error))
|
||||
|
||||
extractorsTakeBody = make(map[reflect.Type]bool)
|
||||
|
||||
extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error))
|
||||
)
|
||||
|
||||
@ -18,6 +20,12 @@ type FromRequest interface {
|
||||
FromRequest(*http.Request) error
|
||||
}
|
||||
|
||||
// Marker interface
|
||||
type TakeRequestBody interface {
|
||||
TakeRequestBody()
|
||||
}
|
||||
|
||||
// Marker interface
|
||||
type TakeResponseWriter interface {
|
||||
TakeResponseWriter(http.ResponseWriter)
|
||||
}
|
||||
@ -43,6 +51,42 @@ func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error))
|
||||
RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor)
|
||||
}
|
||||
|
||||
func RegisterExtractorThatTakesBody(o any, extractor func(*http.Request) (any, error)) {
|
||||
if t, ok := o.(reflect.Type); ok {
|
||||
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
|
||||
extractors[t] = extractor
|
||||
extractorsTakeBody[t] = true
|
||||
} else if v, ok := o.(reflect.Value); ok {
|
||||
slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value")
|
||||
|
||||
extractors[v.Type()] = extractor
|
||||
extractorsTakeBody[v.Type()] = true
|
||||
} else {
|
||||
t := reflect.TypeOf(o)
|
||||
slog.With("type", t.String(), "value", o).Debug("extractor type registered with object")
|
||||
|
||||
extractors[t] = extractor
|
||||
extractorsTakeBody[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterExtractorThatTakesBodyGeneric[T any](extractor func(*http.Request) (any, error)) {
|
||||
var pointerToT *T
|
||||
RegisterExtractorThatTakesBody(reflect.TypeOf(pointerToT).Elem(), extractor)
|
||||
}
|
||||
|
||||
func IsTakeBody(t reflect.Type) bool {
|
||||
if isTakeBody := extractorsTakeBody[t]; isTakeBody {
|
||||
return isTakeBody
|
||||
}
|
||||
|
||||
if _, ok := util.Implements[TakeRequestBody](t); ok {
|
||||
return ok
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) {
|
||||
if t, ok := o.(reflect.Type); ok {
|
||||
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
|
||||
@ -70,62 +114,57 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
|
||||
_, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem())
|
||||
isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter
|
||||
}
|
||||
if _, ok := util.Implements[FromRequest](t); ok {
|
||||
if t.Kind() == reflect.Pointer {
|
||||
// T.Implement(FromRequest) and T is Pointer
|
||||
// create new actual T and call T.FromRequest on it
|
||||
// if error happens, no error is returned. just return nil
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
// var t T
|
||||
// return t, t.FromRequest(r)
|
||||
z := reflect.New(t.Elem())
|
||||
|
||||
if isTakeResponseWriter {
|
||||
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||
fromRequestType := reflect.TypeOf((*FromRequest)(nil)).Elem()
|
||||
if t.Implements(fromRequestType) {
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
val := reflect.New(t).Elem()
|
||||
if t.Kind() == reflect.Pointer {
|
||||
val = reflect.New(t.Elem())
|
||||
}
|
||||
|
||||
if isTakeResponseWriter {
|
||||
if taker, ok := val.Addr().Interface().(TakeResponseWriter); ok {
|
||||
taker.TakeResponseWriter(w)
|
||||
}
|
||||
}
|
||||
|
||||
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||
if errResponse, ok := err.(ErrorResponse); ok {
|
||||
return reflect.Zero(t).Interface(), errResponse
|
||||
if fromR, ok := val.Interface().(FromRequest); ok {
|
||||
if err := fromR.FromRequest(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if val.CanAddr() {
|
||||
if fromR, ok := val.Addr().Interface().(FromRequest); ok {
|
||||
if err := fromR.FromRequest(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reflect.Zero(t).Interface(), nil
|
||||
}
|
||||
return z.Interface(), nil
|
||||
}, isTakeResponseWriter
|
||||
}
|
||||
// T.Implement(FromRequest) and T is not Pointer
|
||||
// create zero T and call T.FromRequest directly on it
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
// var t T
|
||||
// return t, t.FromRequest(r)
|
||||
z := reflect.Zero(t)
|
||||
|
||||
if isTakeResponseWriter {
|
||||
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||
}
|
||||
if t.Kind() == reflect.Pointer {
|
||||
|
||||
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err, ok := results[0].Interface().(error); ok {
|
||||
return z.Interface(), err
|
||||
return val.Interface(), nil
|
||||
}
|
||||
return z.Interface(), nil
|
||||
return val.Addr().Interface(), nil
|
||||
}, isTakeResponseWriter
|
||||
} else if v, ok := util.PointerImplements[FromRequest](t); ok {
|
||||
}
|
||||
|
||||
if reflect.PointerTo(t).Implements(fromRequestType) {
|
||||
return func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
// var t T
|
||||
// return t, t.FromRequest(r)
|
||||
z := reflect.New(v.Type().Elem())
|
||||
val := reflect.New(t)
|
||||
|
||||
if isTakeResponseWriter {
|
||||
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
|
||||
if taker, ok := val.Interface().(TakeResponseWriter); ok {
|
||||
taker.TakeResponseWriter(w)
|
||||
}
|
||||
}
|
||||
|
||||
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err := results[0].Interface(); err == nil {
|
||||
return z.Elem().Interface(), nil
|
||||
if fromR, ok := val.Interface().(FromRequest); ok {
|
||||
if err := fromR.FromRequest(r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return z.Elem().Interface(), results[0].Interface().(error)
|
||||
|
||||
return val.Elem().Interface(), nil
|
||||
}, isTakeResponseWriter
|
||||
}
|
||||
|
||||
@ -142,7 +181,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
// Create a pointer to the value
|
||||
vp := reflect.New(reflect.TypeOf(v))
|
||||
vp.Elem().Set(reflect.ValueOf(v))
|
||||
return vp.Interface(), nil
|
||||
}, false
|
||||
}
|
||||
}
|
||||
@ -160,7 +202,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
// Create a pointer to the value
|
||||
vp := reflect.New(reflect.TypeOf(v))
|
||||
vp.Elem().Set(reflect.ValueOf(v))
|
||||
return vp.Interface(), nil
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
206
magic/form.go
Normal file
206
magic/form.go
Normal file
@ -0,0 +1,206 @@
|
||||
package magic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
cachedStructFormFields = make(map[reflect.Type][]structFormFields)
|
||||
cachedUnsupportedFormFields = make(map[reflect.Type]structFormFields)
|
||||
)
|
||||
|
||||
type Form[T any] struct {
|
||||
Data T
|
||||
}
|
||||
|
||||
type structFormFields struct {
|
||||
PathKey string
|
||||
FieldKey string
|
||||
|
||||
Type reflect.Type
|
||||
MinValues, MaxValues int
|
||||
|
||||
Match *regexp.Regexp
|
||||
}
|
||||
|
||||
type UnsupportedFormType struct {
|
||||
inner reflect.Type
|
||||
}
|
||||
|
||||
func (err UnsupportedFormType) Error() string {
|
||||
return "unsupported type for Form: " + err.inner.String()
|
||||
}
|
||||
|
||||
type FormValueNotFitIn struct {
|
||||
Count int
|
||||
|
||||
field structFormFields
|
||||
}
|
||||
|
||||
func (err FormValueNotFitIn) Error() string {
|
||||
return fmt.Sprintf("Form %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinValues, err.field.MaxValues, err.Count)
|
||||
}
|
||||
|
||||
type InvalidFormValue struct {
|
||||
Match *regexp.Regexp
|
||||
Value string
|
||||
}
|
||||
|
||||
func (err InvalidFormValue) Error() string {
|
||||
return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value)
|
||||
}
|
||||
|
||||
func findReflectFormFields(v reflect.Value) ([]structFormFields, error) {
|
||||
if cached, ok := cachedStructFormFields[v.Type()]; ok {
|
||||
return cached, nil
|
||||
}
|
||||
if cached, ok := cachedUnsupportedFormFields[v.Type()]; ok {
|
||||
return nil, UnsupportedPathValueType{inner: cached.Type}
|
||||
}
|
||||
|
||||
var fields []structFormFields
|
||||
if v.Kind() == reflect.Struct {
|
||||
t := v.Type()
|
||||
for idx := 0; idx < t.NumField(); idx++ {
|
||||
field := t.Field(idx)
|
||||
|
||||
pathKey := field.Name
|
||||
var match *regexp.Regexp
|
||||
|
||||
if tags, ok := field.Tag.Lookup("form"); ok {
|
||||
parts := strings.Split(strings.TrimSpace(tags), ",")
|
||||
if len(parts) == 0 {
|
||||
// do nothing
|
||||
} else {
|
||||
if parts[0] != "" {
|
||||
pathKey = parts[0]
|
||||
}
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "match=") {
|
||||
part := strings.TrimPrefix(part, "match=")
|
||||
if unquoted, err := strconv.Unquote(part); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
part = unquoted
|
||||
}
|
||||
exp, err := regexp.Compile(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
match = exp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fieldType := field.Type
|
||||
maxValues, minValues := 1, 1
|
||||
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
minValues = 0
|
||||
} else if fieldType.Kind() == reflect.Array {
|
||||
maxValues, minValues = fieldType.Len(), fieldType.Len()
|
||||
fieldType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Slice {
|
||||
fieldType = fieldType.Elem()
|
||||
maxValues, minValues = -1, 0
|
||||
}
|
||||
|
||||
if pathValueConvertTable[fieldType.Kind()] == nil {
|
||||
if _, ok := fieldType.MethodByName("FromString"); !ok {
|
||||
cachedUnsupportedPathValueFields[t] = field
|
||||
return nil, UnsupportedPathValueType{inner: field.Type}
|
||||
}
|
||||
}
|
||||
fields = append(fields, structFormFields{
|
||||
PathKey: pathKey,
|
||||
FieldKey: field.Name,
|
||||
|
||||
Type: fieldType,
|
||||
MaxValues: maxValues,
|
||||
MinValues: minValues,
|
||||
|
||||
Match: match,
|
||||
})
|
||||
|
||||
cachedStructFormFields[t] = fields
|
||||
}
|
||||
} else {
|
||||
return nil, UnsupportedFormType{inner: v.Type()}
|
||||
}
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func (form *Form[T]) FromRequest(r *http.Request) error {
|
||||
v := reflect.ValueOf(form).Elem().FieldByName("Data")
|
||||
fields, err := findReflectFormFields(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil {
|
||||
return err
|
||||
}
|
||||
q := r.PostForm
|
||||
|
||||
for _, field := range fields {
|
||||
values := q[field.PathKey]
|
||||
if len(values) < field.MinValues || (field.MaxValues > 0 && len(values) > field.MaxValues) {
|
||||
return FormValueNotFitIn{Count: len(values), field: field}
|
||||
}
|
||||
|
||||
structField := v.FieldByName(field.FieldKey)
|
||||
|
||||
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
|
||||
if structField.Cap() < len(values) {
|
||||
structField.SetCap(len(values))
|
||||
}
|
||||
structField.SetLen(len(values))
|
||||
}
|
||||
|
||||
for idx, value := range values {
|
||||
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
|
||||
if idx != len(values)-1 {
|
||||
// only the last takes effect
|
||||
continue
|
||||
}
|
||||
}
|
||||
// check match regexp
|
||||
if field.Match != nil && !field.Match.MatchString(value) {
|
||||
return InvalidFormValue{Match: field.Match, Value: value}
|
||||
}
|
||||
|
||||
// convert to elem value
|
||||
convert := pathValueConvertTable[field.Type.Kind()]
|
||||
result, err := convert(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reflectValue := reflect.ValueOf(result)
|
||||
|
||||
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
|
||||
if field.MinValues == 0 {
|
||||
newReflectValue := reflect.New(field.Type)
|
||||
newReflectValue.Elem().Set(reflectValue)
|
||||
reflectValue = newReflectValue
|
||||
}
|
||||
structField.Set(reflectValue)
|
||||
} else {
|
||||
structField.Index(idx).Set(reflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (form Form[T]) TakeRequestBody() {}
|
@ -2,8 +2,15 @@ package magic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
|
||||
jsonValidator = NewValidator("json")
|
||||
)
|
||||
|
||||
type JsonDecodeError struct {
|
||||
@ -18,10 +25,25 @@ func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
type Json[T any] struct {
|
||||
Data T
|
||||
type ValidateError []string
|
||||
|
||||
func (err ValidateError) Error() string {
|
||||
return "invalid or missing fields: [" + strings.Join(err, ", ") + "]"
|
||||
}
|
||||
|
||||
func (v ValidateError) WriteResponse(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(v.Error()))
|
||||
}
|
||||
|
||||
type Json[T any] struct {
|
||||
Data T `json:"data"`
|
||||
}
|
||||
|
||||
// TakeRequestBody is a marker method to indicate that this extractor takes the request body.
|
||||
func (j Json[T]) TakeRequestBody() {}
|
||||
|
||||
func NewJson[T any](data T) Json[T] {
|
||||
return Json[T]{
|
||||
Data: data,
|
||||
@ -29,35 +51,41 @@ func NewJson[T any](data T) Json[T] {
|
||||
}
|
||||
|
||||
func (data *Json[T]) FromRequest(request *http.Request) error {
|
||||
if request.Body == nil {
|
||||
panic("body is already taken")
|
||||
}
|
||||
|
||||
bodyReader := request.Body
|
||||
defer bodyReader.Close()
|
||||
|
||||
request.Body = http.NoBody
|
||||
request.Body = nil
|
||||
|
||||
if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil {
|
||||
return JsonDecodeError{inner: err}
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(data.Data)
|
||||
if value.Kind() == reflect.Struct || (value.Kind() == reflect.Ptr && value.Elem().Kind() == reflect.Struct) {
|
||||
if err := jsonValidator.StructCtx(request.Context(), data.Data); err == nil {
|
||||
return nil
|
||||
} else if validateErr, isValidateErr := err.(validator.ValidationErrors); isValidateErr {
|
||||
fields := make([]string, len(validateErr))
|
||||
for idx, err := range validateErr {
|
||||
fields[idx] = err.Field()
|
||||
}
|
||||
return ValidateError(fields)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type counter struct {
|
||||
counter int64
|
||||
}
|
||||
|
||||
func (c *counter) Write(b []byte) (int, error) {
|
||||
n := len(b)
|
||||
c.counter += int64(n)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (container Json[T]) Write(w io.Writer) (int64, error) {
|
||||
counter := &counter{}
|
||||
err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data)
|
||||
return counter.counter, err
|
||||
}
|
||||
|
||||
func (json Json[T]) WriteResponse(w http.ResponseWriter) {
|
||||
json.Write(w)
|
||||
func (data *Json[T]) WriteResponse(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if data == nil {
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
@ -8,5 +8,6 @@ import (
|
||||
type Map map[string]any
|
||||
|
||||
func (m Map) RespondWriter(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(m)
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ func (query *Query[T]) FromRequest(r *http.Request) error {
|
||||
|
||||
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
|
||||
if structField.Cap() < len(values) {
|
||||
structField.SetCap(len(values))
|
||||
structField.Grow(len(values) - structField.Cap())
|
||||
}
|
||||
structField.SetLen(len(values))
|
||||
}
|
||||
|
@ -2,7 +2,10 @@ package magic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"git.jeffthecoder.xyz/public/lazyhandler/middleware/cleanup"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -12,4 +15,49 @@ func init() {
|
||||
RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) {
|
||||
return r.Context(), nil
|
||||
})
|
||||
// RegisterExtractorThatTakesBodyGeneric for io.ReadCloser extracts the request body.
|
||||
//
|
||||
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
|
||||
// is used, r.Body will be set to nil to prevent it from being read multiple times.
|
||||
// Any subsequent attempt to read the body will fail.
|
||||
//
|
||||
// A cleanup function is automatically registered to close the body once the request
|
||||
// is finished. This requires the cleanup middleware to be present in the handler chain.
|
||||
RegisterExtractorThatTakesBodyGeneric[io.ReadCloser](func(r *http.Request) (any, error) {
|
||||
if r.Body == nil {
|
||||
panic("body is already taken")
|
||||
}
|
||||
|
||||
body := r.Body
|
||||
r.Body = nil
|
||||
|
||||
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
|
||||
body.Close()
|
||||
}))
|
||||
|
||||
return body, nil
|
||||
})
|
||||
|
||||
// RegisterExtractorThatTakesBodyGeneric for io.Reader extracts the request body.
|
||||
//
|
||||
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
|
||||
// is used, r.Body will be set to nil to prevent it from being read multiple times.
|
||||
// Any subsequent attempt to read the body will fail.
|
||||
//
|
||||
// A cleanup function is automatically registered to close the body once the request
|
||||
// is finished. This requires the cleanup middleware to be present in the handler chain.
|
||||
RegisterExtractorThatTakesBodyGeneric[io.Reader](func(r *http.Request) (any, error) {
|
||||
if r.Body == nil {
|
||||
panic("body is already taken")
|
||||
}
|
||||
|
||||
body := r.Body
|
||||
r.Body = nil
|
||||
|
||||
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
|
||||
body.Close()
|
||||
}))
|
||||
|
||||
return body, nil
|
||||
})
|
||||
}
|
||||
|
@ -22,6 +22,26 @@ type ErrorResponse interface {
|
||||
WriteResponse(http.ResponseWriter)
|
||||
}
|
||||
|
||||
func WrapSimpleErrorWithStatus(message string, status int) ErrorResponse {
|
||||
return simpleErrorResponseWithoutBody{
|
||||
message: message,
|
||||
status: status,
|
||||
}
|
||||
}
|
||||
|
||||
type simpleErrorResponseWithoutBody struct {
|
||||
message string
|
||||
status int
|
||||
}
|
||||
|
||||
func (err simpleErrorResponseWithoutBody) Error() string {
|
||||
return err.message
|
||||
}
|
||||
|
||||
func (err simpleErrorResponseWithoutBody) WriteResponse(w http.ResponseWriter) {
|
||||
w.WriteHeader(err.status)
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) {
|
||||
return w, nil
|
||||
|
@ -2,12 +2,8 @@ package magic
|
||||
|
||||
import "net/http"
|
||||
|
||||
type State[T any] struct {
|
||||
Data T
|
||||
}
|
||||
|
||||
func RegisterState[T any](data T) {
|
||||
RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) {
|
||||
return State[T]{Data: data}, nil
|
||||
RegisterExtractor(data, func(r *http.Request) (any, error) {
|
||||
return data, nil
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user