sync from project

This commit is contained in:
2025-06-18 10:12:19 +08:00
parent 61ffeeb3b8
commit fb579e8689
20 changed files with 1332 additions and 103 deletions

37
magic/autodecode.go Normal file
View File

@ -0,0 +1,37 @@
package magic
import (
"net/http"
"strings"
)
type ErrUknownContentType string
func (err ErrUknownContentType) Error() string {
return "unknown content-type: " + string(err)
}
type AutoDecode[T any] struct {
Data T
}
func (d *AutoDecode[T]) FromRequest(r *http.Request) error {
contentType, _, _ := strings.Cut(r.Header.Get("Content-Type"), ";")
switch contentType {
case "application/json":
decoder := &Json[T]{}
if err := decoder.FromRequest(r); err != nil {
return err
}
d.Data = decoder.Data
case "multipart/form-data", "application/x-www-form-urlencoded":
decoder := &Form[T]{}
if err := decoder.FromRequest(r); err != nil {
return err
}
d.Data = decoder.Data
}
return ErrUknownContentType(contentType)
}
func (d AutoDecode[T]) TakeRequestBody() {}

View File

@ -4,6 +4,8 @@ import (
"reflect"
"strconv"
"strings"
"github.com/go-playground/validator/v10"
)
var (
@ -54,21 +56,21 @@ var (
if v, err := strconv.ParseInt(s, 10, 16); err != nil {
return nil, err
} else {
return int8(v), nil
return int16(v), nil
}
},
reflect.Int32: func(s string) (any, error) {
if v, err := strconv.ParseInt(s, 10, 32); err != nil {
return nil, err
} else {
return int8(v), nil
return int32(v), nil
}
},
reflect.Int64: func(s string) (any, error) {
if v, err := strconv.ParseInt(s, 10, 64); err != nil {
return nil, err
} else {
return int8(v), nil
return v, nil
}
},
@ -108,4 +110,24 @@ var (
}
},
}
defaultValidator = NewValidator("")
)
func NewValidator(nameFromTag string) *validator.Validate {
validate := validator.New(
validator.WithRequiredStructEnabled(),
)
if nameFromTag != "" {
validate.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get(nameFromTag), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
}
return validate
}

View File

@ -11,6 +11,8 @@ import (
var (
extractors = make(map[reflect.Type]func(*http.Request) (any, error))
extractorsTakeBody = make(map[reflect.Type]bool)
extractorsTakesResponseWriter = make(map[reflect.Type]func(http.ResponseWriter, *http.Request) (any, error))
)
@ -18,6 +20,12 @@ type FromRequest interface {
FromRequest(*http.Request) error
}
// Marker interface
type TakeRequestBody interface {
TakeRequestBody()
}
// Marker interface
type TakeResponseWriter interface {
TakeResponseWriter(http.ResponseWriter)
}
@ -43,6 +51,42 @@ func RegisterExtractorGeneric[T any](extractor func(*http.Request) (any, error))
RegisterExtractor(reflect.TypeOf(pointerToT).Elem(), extractor)
}
func RegisterExtractorThatTakesBody(o any, extractor func(*http.Request) (any, error)) {
if t, ok := o.(reflect.Type); ok {
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
extractors[t] = extractor
extractorsTakeBody[t] = true
} else if v, ok := o.(reflect.Value); ok {
slog.With("type", v.Type().String(), "value", v.Interface()).Debug("extractor type registered with reflect.Value")
extractors[v.Type()] = extractor
extractorsTakeBody[v.Type()] = true
} else {
t := reflect.TypeOf(o)
slog.With("type", t.String(), "value", o).Debug("extractor type registered with object")
extractors[t] = extractor
extractorsTakeBody[t] = true
}
}
func RegisterExtractorThatTakesBodyGeneric[T any](extractor func(*http.Request) (any, error)) {
var pointerToT *T
RegisterExtractorThatTakesBody(reflect.TypeOf(pointerToT).Elem(), extractor)
}
func IsTakeBody(t reflect.Type) bool {
if isTakeBody := extractorsTakeBody[t]; isTakeBody {
return isTakeBody
}
if _, ok := util.Implements[TakeRequestBody](t); ok {
return ok
}
return false
}
func RegisterExtractorThatTakesResponseWriter(o any, extractor func(http.ResponseWriter, *http.Request) (any, error)) {
if t, ok := o.(reflect.Type); ok {
slog.With("type", t.String()).Debug("extractor type registered with reflect.Type")
@ -70,62 +114,57 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
_, ptrIsTakeResponseWriter := util.Implements[TakeResponseWriter](t.Elem())
isTakeResponseWriter = isTakeResponseWriter || ptrIsTakeResponseWriter
}
if _, ok := util.Implements[FromRequest](t); ok {
if t.Kind() == reflect.Pointer {
// T.Implement(FromRequest) and T is Pointer
// create new actual T and call T.FromRequest on it
// if error happens, no error is returned. just return nil
return func(w http.ResponseWriter, r *http.Request) (any, error) {
// var t T
// return t, t.FromRequest(r)
z := reflect.New(t.Elem())
if isTakeResponseWriter {
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
fromRequestType := reflect.TypeOf((*FromRequest)(nil)).Elem()
if t.Implements(fromRequestType) {
return func(w http.ResponseWriter, r *http.Request) (any, error) {
val := reflect.New(t).Elem()
if t.Kind() == reflect.Pointer {
val = reflect.New(t.Elem())
}
if isTakeResponseWriter {
if taker, ok := val.Addr().Interface().(TakeResponseWriter); ok {
taker.TakeResponseWriter(w)
}
}
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
if err, ok := results[0].Interface().(error); ok && err != nil {
if errResponse, ok := err.(ErrorResponse); ok {
return reflect.Zero(t).Interface(), errResponse
if fromR, ok := val.Interface().(FromRequest); ok {
if err := fromR.FromRequest(r); err != nil {
return nil, err
}
} else if val.CanAddr() {
if fromR, ok := val.Addr().Interface().(FromRequest); ok {
if err := fromR.FromRequest(r); err != nil {
return nil, err
}
return reflect.Zero(t).Interface(), nil
}
return z.Interface(), nil
}, isTakeResponseWriter
}
// T.Implement(FromRequest) and T is not Pointer
// create zero T and call T.FromRequest directly on it
return func(w http.ResponseWriter, r *http.Request) (any, error) {
// var t T
// return t, t.FromRequest(r)
z := reflect.Zero(t)
if isTakeResponseWriter {
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
}
if t.Kind() == reflect.Pointer {
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
if err, ok := results[0].Interface().(error); ok {
return z.Interface(), err
return val.Interface(), nil
}
return z.Interface(), nil
return val.Addr().Interface(), nil
}, isTakeResponseWriter
} else if v, ok := util.PointerImplements[FromRequest](t); ok {
}
if reflect.PointerTo(t).Implements(fromRequestType) {
return func(w http.ResponseWriter, r *http.Request) (any, error) {
// var t T
// return t, t.FromRequest(r)
z := reflect.New(v.Type().Elem())
val := reflect.New(t)
if isTakeResponseWriter {
z.MethodByName("TakeResponseWriter").Call([]reflect.Value{reflect.ValueOf(w)})
if taker, ok := val.Interface().(TakeResponseWriter); ok {
taker.TakeResponseWriter(w)
}
}
results := z.MethodByName("FromRequest").Call([]reflect.Value{reflect.ValueOf(r)})
if err := results[0].Interface(); err == nil {
return z.Elem().Interface(), nil
if fromR, ok := val.Interface().(FromRequest); ok {
if err := fromR.FromRequest(r); err != nil {
return nil, err
}
}
return z.Elem().Interface(), results[0].Interface().(error)
return val.Elem().Interface(), nil
}, isTakeResponseWriter
}
@ -142,7 +181,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
if err != nil {
return nil, err
}
return &v, nil
// Create a pointer to the value
vp := reflect.New(reflect.TypeOf(v))
vp.Elem().Set(reflect.ValueOf(v))
return vp.Interface(), nil
}, false
}
}
@ -160,7 +202,10 @@ func GetExtractor(t reflect.Type) (func(http.ResponseWriter, *http.Request) (any
if err != nil {
return nil, err
}
return &v, nil
// Create a pointer to the value
vp := reflect.New(reflect.TypeOf(v))
vp.Elem().Set(reflect.ValueOf(v))
return vp.Interface(), nil
}, true
}
}

206
magic/form.go Normal file
View File

@ -0,0 +1,206 @@
package magic
import (
"fmt"
"net/http"
"reflect"
"regexp"
"strconv"
"strings"
)
var (
cachedStructFormFields = make(map[reflect.Type][]structFormFields)
cachedUnsupportedFormFields = make(map[reflect.Type]structFormFields)
)
type Form[T any] struct {
Data T
}
type structFormFields struct {
PathKey string
FieldKey string
Type reflect.Type
MinValues, MaxValues int
Match *regexp.Regexp
}
type UnsupportedFormType struct {
inner reflect.Type
}
func (err UnsupportedFormType) Error() string {
return "unsupported type for Form: " + err.inner.String()
}
type FormValueNotFitIn struct {
Count int
field structFormFields
}
func (err FormValueNotFitIn) Error() string {
return fmt.Sprintf("Form %v is expected to has %v~%v values, got %v", err.field.PathKey, err.field.MinValues, err.field.MaxValues, err.Count)
}
type InvalidFormValue struct {
Match *regexp.Regexp
Value string
}
func (err InvalidFormValue) Error() string {
return fmt.Sprintf("value not matched with regexp %v: %v", err.Match, err.Value)
}
func findReflectFormFields(v reflect.Value) ([]structFormFields, error) {
if cached, ok := cachedStructFormFields[v.Type()]; ok {
return cached, nil
}
if cached, ok := cachedUnsupportedFormFields[v.Type()]; ok {
return nil, UnsupportedPathValueType{inner: cached.Type}
}
var fields []structFormFields
if v.Kind() == reflect.Struct {
t := v.Type()
for idx := 0; idx < t.NumField(); idx++ {
field := t.Field(idx)
pathKey := field.Name
var match *regexp.Regexp
if tags, ok := field.Tag.Lookup("form"); ok {
parts := strings.Split(strings.TrimSpace(tags), ",")
if len(parts) == 0 {
// do nothing
} else {
if parts[0] != "" {
pathKey = parts[0]
}
for _, part := range parts[1:] {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "match=") {
part := strings.TrimPrefix(part, "match=")
if unquoted, err := strconv.Unquote(part); err != nil {
return nil, err
} else {
part = unquoted
}
exp, err := regexp.Compile(part)
if err != nil {
return nil, err
}
match = exp
}
}
}
}
fieldType := field.Type
maxValues, minValues := 1, 1
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
minValues = 0
} else if fieldType.Kind() == reflect.Array {
maxValues, minValues = fieldType.Len(), fieldType.Len()
fieldType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Slice {
fieldType = fieldType.Elem()
maxValues, minValues = -1, 0
}
if pathValueConvertTable[fieldType.Kind()] == nil {
if _, ok := fieldType.MethodByName("FromString"); !ok {
cachedUnsupportedPathValueFields[t] = field
return nil, UnsupportedPathValueType{inner: field.Type}
}
}
fields = append(fields, structFormFields{
PathKey: pathKey,
FieldKey: field.Name,
Type: fieldType,
MaxValues: maxValues,
MinValues: minValues,
Match: match,
})
cachedStructFormFields[t] = fields
}
} else {
return nil, UnsupportedFormType{inner: v.Type()}
}
return fields, nil
}
func (form *Form[T]) FromRequest(r *http.Request) error {
v := reflect.ValueOf(form).Elem().FieldByName("Data")
fields, err := findReflectFormFields(v)
if err != nil {
return err
}
if err := r.ParseMultipartForm(32 << 20); err != nil {
return err
}
q := r.PostForm
for _, field := range fields {
values := q[field.PathKey]
if len(values) < field.MinValues || (field.MaxValues > 0 && len(values) > field.MaxValues) {
return FormValueNotFitIn{Count: len(values), field: field}
}
structField := v.FieldByName(field.FieldKey)
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
if structField.Cap() < len(values) {
structField.SetCap(len(values))
}
structField.SetLen(len(values))
}
for idx, value := range values {
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
if idx != len(values)-1 {
// only the last takes effect
continue
}
}
// check match regexp
if field.Match != nil && !field.Match.MatchString(value) {
return InvalidFormValue{Match: field.Match, Value: value}
}
// convert to elem value
convert := pathValueConvertTable[field.Type.Kind()]
result, err := convert(value)
if err != nil {
return err
}
reflectValue := reflect.ValueOf(result)
if structField.Kind() != reflect.Array && structField.Kind() != reflect.Slice {
if field.MinValues == 0 {
newReflectValue := reflect.New(field.Type)
newReflectValue.Elem().Set(reflectValue)
reflectValue = newReflectValue
}
structField.Set(reflectValue)
} else {
structField.Index(idx).Set(reflectValue)
}
}
}
return nil
}
func (form Form[T]) TakeRequestBody() {}

View File

@ -2,8 +2,15 @@ package magic
import (
"encoding/json"
"io"
"net/http"
"reflect"
"strings"
"github.com/go-playground/validator/v10"
)
var (
jsonValidator = NewValidator("json")
)
type JsonDecodeError struct {
@ -18,10 +25,25 @@ func (err JsonDecodeError) WriteResponse(rw http.ResponseWriter) {
rw.WriteHeader(http.StatusBadRequest)
}
type Json[T any] struct {
Data T
type ValidateError []string
func (err ValidateError) Error() string {
return "invalid or missing fields: [" + strings.Join(err, ", ") + "]"
}
func (v ValidateError) WriteResponse(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(v.Error()))
}
type Json[T any] struct {
Data T `json:"data"`
}
// TakeRequestBody is a marker method to indicate that this extractor takes the request body.
func (j Json[T]) TakeRequestBody() {}
func NewJson[T any](data T) Json[T] {
return Json[T]{
Data: data,
@ -29,35 +51,41 @@ func NewJson[T any](data T) Json[T] {
}
func (data *Json[T]) FromRequest(request *http.Request) error {
if request.Body == nil {
panic("body is already taken")
}
bodyReader := request.Body
defer bodyReader.Close()
request.Body = http.NoBody
request.Body = nil
if err := json.NewDecoder(bodyReader).Decode(&data.Data); err != nil {
return JsonDecodeError{inner: err}
}
value := reflect.ValueOf(data.Data)
if value.Kind() == reflect.Struct || (value.Kind() == reflect.Ptr && value.Elem().Kind() == reflect.Struct) {
if err := jsonValidator.StructCtx(request.Context(), data.Data); err == nil {
return nil
} else if validateErr, isValidateErr := err.(validator.ValidationErrors); isValidateErr {
fields := make([]string, len(validateErr))
for idx, err := range validateErr {
fields[idx] = err.Field()
}
return ValidateError(fields)
} else {
return err
}
}
return nil
}
type counter struct {
counter int64
}
func (c *counter) Write(b []byte) (int, error) {
n := len(b)
c.counter += int64(n)
return n, nil
}
func (container Json[T]) Write(w io.Writer) (int64, error) {
counter := &counter{}
err := json.NewEncoder(io.MultiWriter(counter, w)).Encode(container.Data)
return counter.counter, err
}
func (json Json[T]) WriteResponse(w http.ResponseWriter) {
json.Write(w)
func (data *Json[T]) WriteResponse(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
if data == nil {
return
}
json.NewEncoder(w).Encode(data)
}

View File

@ -8,5 +8,6 @@ import (
type Map map[string]any
func (m Map) RespondWriter(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(m)
}

View File

@ -159,7 +159,7 @@ func (query *Query[T]) FromRequest(r *http.Request) error {
if structField.Kind() == reflect.Slice && structField.Len() < len(values) {
if structField.Cap() < len(values) {
structField.SetCap(len(values))
structField.Grow(len(values) - structField.Cap())
}
structField.SetLen(len(values))
}

View File

@ -2,7 +2,10 @@ package magic
import (
"context"
"io"
"net/http"
"git.jeffthecoder.xyz/public/lazyhandler/middleware/cleanup"
)
func init() {
@ -12,4 +15,49 @@ func init() {
RegisterExtractorGeneric[context.Context](func(r *http.Request) (any, error) {
return r.Context(), nil
})
// RegisterExtractorThatTakesBodyGeneric for io.ReadCloser extracts the request body.
//
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
// is used, r.Body will be set to nil to prevent it from being read multiple times.
// Any subsequent attempt to read the body will fail.
//
// A cleanup function is automatically registered to close the body once the request
// is finished. This requires the cleanup middleware to be present in the handler chain.
RegisterExtractorThatTakesBodyGeneric[io.ReadCloser](func(r *http.Request) (any, error) {
if r.Body == nil {
panic("body is already taken")
}
body := r.Body
r.Body = nil
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
body.Close()
}))
return body, nil
})
// RegisterExtractorThatTakesBodyGeneric for io.Reader extracts the request body.
//
// IMPORTANT: This extractor consumes the request body (r.Body). After this extractor
// is used, r.Body will be set to nil to prevent it from being read multiple times.
// Any subsequent attempt to read the body will fail.
//
// A cleanup function is automatically registered to close the body once the request
// is finished. This requires the cleanup middleware to be present in the handler chain.
RegisterExtractorThatTakesBodyGeneric[io.Reader](func(r *http.Request) (any, error) {
if r.Body == nil {
panic("body is already taken")
}
body := r.Body
r.Body = nil
cleanup.Register(r.Context(), cleanup.CleanupFunc(func() {
body.Close()
}))
return body, nil
})
}

View File

@ -22,6 +22,26 @@ type ErrorResponse interface {
WriteResponse(http.ResponseWriter)
}
func WrapSimpleErrorWithStatus(message string, status int) ErrorResponse {
return simpleErrorResponseWithoutBody{
message: message,
status: status,
}
}
type simpleErrorResponseWithoutBody struct {
message string
status int
}
func (err simpleErrorResponseWithoutBody) Error() string {
return err.message
}
func (err simpleErrorResponseWithoutBody) WriteResponse(w http.ResponseWriter) {
w.WriteHeader(err.status)
}
func init() {
RegisterExtractorThatTakesResponseWriterGeneric[http.ResponseWriter](func(w http.ResponseWriter, r *http.Request) (any, error) {
return w, nil

View File

@ -2,12 +2,8 @@ package magic
import "net/http"
type State[T any] struct {
Data T
}
func RegisterState[T any](data T) {
RegisterExtractor(State[T]{}, func(r *http.Request) (any, error) {
return State[T]{Data: data}, nil
RegisterExtractor(data, func(r *http.Request) (any, error) {
return data, nil
})
}