lazyhandler/inject.go
2025-02-22 23:06:16 +08:00

200 lines
5.0 KiB
Go

package lazyhandler
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"git.jeffthecoder.xyz/public/lazyhandler/magic"
"git.jeffthecoder.xyz/public/lazyhandler/util"
)
// ErrArgumentIsNotExtractable is returned by MagicHandler when no extractor
// can be found for one of the handler function's parameters. The value is
// the zero-based index of that parameter.
type ErrArgumentIsNotExtractable int

// Error implements the error interface.
func (e ErrArgumentIsNotExtractable) Error() string {
	index := int(e)
	return fmt.Sprintf("argument %d is not extractable", index)
}
// ErrReturnValueNotConvertableIntoResponsePart is returned by MagicHandler
// when one of the handler function's return types cannot be written into the
// response. The value is the zero-based index of that return value.
type ErrReturnValueNotConvertableIntoResponsePart int

// Error implements the error interface.
func (e ErrReturnValueNotConvertableIntoResponsePart) Error() string {
	index := int(e)
	return fmt.Sprintf("return value %d is not convertable into response part", index)
}
// Sentinel errors reported by MagicHandler while validating a function.
var (
	// ErrNotAFunc is returned when the value given to MagicHandler is not a
	// function.
	ErrNotAFunc = errors.New("not a function")

	// ErrDuplicateResponseWriterExtractor is returned when more than one
	// parameter would be extracted from the http.ResponseWriter.
	ErrDuplicateResponseWriterExtractor = errors.New("duplicate response writer extractor")

	// ErrResponseWriterCannotBeExtracted is returned when a function takes the
	// http.ResponseWriter but also declares return values.
	ErrResponseWriterCannotBeExtracted = errors.New("http.ResponseWriter extractor must not exists if function has return value")
)
// canConvert reports whether o can be converted to the type T.
//
// o is interpreted as follows:
//   - a reflect.Value is checked with Value.CanConvert;
//   - a reflect.Type is checked with Type.ConvertibleTo;
//   - any other value is checked through its dynamic value.
//
// A nil o (including a nil reflect.Type stored in it) or an invalid (zero)
// reflect.Value reports false instead of panicking inside package reflect.
func canConvert[T any](o any) bool {
	target := reflect.TypeOf((*T)(nil)).Elem()
	switch v := o.(type) {
	case nil:
		return false
	case reflect.Value:
		// Value.CanConvert calls Value.Type, which panics on the zero Value.
		return v.IsValid() && v.CanConvert(target)
	case reflect.Type:
		return v.ConvertibleTo(target)
	default:
		return reflect.ValueOf(o).CanConvert(target)
	}
}
// MagicHandler turns fn, which must be a function, into an http.Handler.
//
// Every parameter of fn must have an extractor known to magic.GetExtractor;
// on each request the extractors run in parameter order to build the call
// arguments. If an extractor fails, the request is answered with that error
// and fn is not called. At most one parameter may be taken from the
// http.ResponseWriter, and a function that takes it must not return anything.
//
// Return values are applied to the response in declaration order:
//   - int kinds are sent as the status code;
//   - string kinds and values convertible to []byte are written to the body;
//   - io.Reader values are copied into the body (nil readers are skipped) and
//     closed after the response if they also implement io.Closer;
//   - magic.RespondWriter values write themselves (nil ones are skipped);
//   - http.Header and other string-keyed maps set response headers;
//   - a trailing error, if non-nil, is written instead of the other values.
//
// net/http sends the status and headers with the first body write. Status
// and header return values therefore only take effect when they are declared
// before the body-producing ones.
//
// MagicHandler returns ErrNotAFunc, ErrArgumentIsNotExtractable,
// ErrDuplicateResponseWriterExtractor, ErrResponseWriterCannotBeExtracted or
// ErrReturnValueNotConvertableIntoResponsePart when fn's signature is not
// supported.
func MagicHandler(fn any) (http.Handler, error) {
	funcValue := reflect.ValueOf(fn)
	// An untyped nil yields the zero Value, whose Type method would panic.
	if !funcValue.IsValid() || funcValue.Kind() != reflect.Func {
		return nil, ErrNotAFunc
	}
	t := funcValue.Type()

	responseWriterExtracted := -1
	extractors := make([]func(http.ResponseWriter, *http.Request) (any, error), t.NumIn())
	for idx := 0; idx < t.NumIn(); idx++ {
		extractor, isTakeResponseWriter := magic.GetExtractor(t.In(idx))
		if extractor == nil {
			return nil, ErrArgumentIsNotExtractable(idx)
		}
		if isTakeResponseWriter {
			if responseWriterExtracted >= 0 {
				return nil, ErrDuplicateResponseWriterExtractor
			}
			responseWriterExtracted = idx
		}
		extractors[idx] = extractor
	}

	// A function that writes to the http.ResponseWriter itself must not also
	// have return values written on its behalf.
	if responseWriterExtracted >= 0 && t.NumOut() > 0 {
		return nil, ErrResponseWriterCannotBeExtracted
	}

	for idx := 0; idx < t.NumOut(); idx++ {
		out := t.Out(idx)
		// int(status) || string(body)
		if out.Kind() == reflect.Int || out.Kind() == reflect.String {
			continue
		}
		// []byte(body)
		if out.Kind() == reflect.Slice && out.Elem().Kind() == reflect.Uint8 {
			continue
		}
		// map[string]T(header)
		if out.Kind() == reflect.Map && out.Key().Kind() == reflect.String {
			continue
		}
		if _, ok := util.Implements[io.Reader](out); ok {
			continue
		}
		if _, ok := util.Implements[magic.RespondWriter](out); ok {
			continue
		}
		// An error is only accepted as the last return value.
		if _, ok := util.Implements[error](out); ok && idx == t.NumOut()-1 {
			continue
		}
		return nil, ErrReturnValueNotConvertableIntoResponsePart(idx)
	}

	// Whether the trailing value is an error is decided from the declared
	// signature, not the returned value. A nil error, or a typed nil pointer
	// stored in it, then counts as success. A last value declared as
	// something else (e.g. io.Reader) is never mistaken for an error just
	// because its dynamic type also implements error.
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	returnsError := t.NumOut() > 0 && t.Out(t.NumOut()-1).Implements(errorType)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Everything that must be closed once the response is finished:
		// extracted arguments and returned readers that implement io.Closer.
		var closers []io.Closer
		defer func() {
			// Close in reverse acquisition order, like stacked defers.
			for i := len(closers) - 1; i >= 0; i-- {
				// Best effort: the response has already been produced.
				_ = closers[i].Close()
			}
		}()

		in := make([]reflect.Value, len(extractors))
		for idx, extractor := range extractors {
			v, err := extractor(w, r)
			if err != nil {
				writeMagicHandlerError(w, err)
				return
			}
			if closer, ok := v.(io.Closer); ok {
				closers = append(closers, closer)
			}
			in[idx] = reflect.ValueOf(v)
		}

		values := funcValue.Call(in)
		if returnsError {
			last := values[len(values)-1]
			if !isNilValue(last) {
				// Safe: the declared type implements error, so does the value.
				writeMagicHandlerError(w, last.Interface().(error))
				return
			}
			values = values[:len(values)-1]
		}

		// Write and copy failures are ignored (best effort): the status line
		// may already be sent, so there is no way to report them to the client.
		for _, value := range values {
			if value.Kind() == reflect.Int {
				w.WriteHeader(int(value.Int()))
			} else if value.Kind() == reflect.String {
				_, _ = w.Write([]byte(value.String()))
			} else if canConvert[[]byte](value) {
				_, _ = w.Write(value.Bytes())
			} else if reader, ok := util.Implements[io.Reader](value); ok {
				if reader == nil {
					continue
				}
				// Returned readers such as files would otherwise leak; close
				// them together with the extracted arguments.
				if closer, ok := reader.(io.Closer); ok {
					closers = append(closers, closer)
				}
				_, _ = io.Copy(w, reader)
			} else if responseWriter, ok := util.Implements[magic.RespondWriter](value); ok {
				if responseWriter != nil {
					responseWriter.WriteResponse(w)
				}
			} else if headers, ok := value.Interface().(http.Header); ok {
				// The first value replaces any existing header; the rest append.
				for name, headerValues := range headers {
					for i, headerValue := range headerValues {
						if i == 0 {
							w.Header().Set(name, headerValue)
						} else {
							w.Header().Add(name, headerValue)
						}
					}
				}
			} else if value.Kind() == reflect.Map {
				// Generic string-keyed map: each entry becomes one header,
				// formatted with fmt.Sprint.
				for _, key := range value.MapKeys() {
					w.Header().Set(key.String(), fmt.Sprint(value.MapIndex(key).Interface()))
				}
			}
		}
	}), nil
}

// writeMagicHandlerError reports err to the client. An error that is, or
// wraps, a magic.ErrorResponse writes its own response. Any other error
// becomes a 400 Bad Request with a JSON body {"error": "<message>"}.
func writeMagicHandlerError(w http.ResponseWriter, err error) {
	var errResponse magic.ErrorResponse
	if errors.As(err, &errResponse) {
		errResponse.WriteResponse(w)
		return
	}
	// Headers must be set before WriteHeader to be sent.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	// Best effort: the status line is already sent, so an encoding failure
	// cannot be reported to the client anymore.
	_ = json.NewEncoder(w).Encode(magic.Map{
		"error": err.Error(),
	})
}

// isNilValue reports whether v holds a nil value. Kinds that cannot be nil
// (structs, strings, numbers, ...) are never nil.
func isNilValue(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return v.IsNil()
	default:
		return false
	}
}
// MustMagicHandler is like MagicHandler but panics when fn cannot be
// converted. The panic value is a string containing fn's dynamic type and
// the error reported by MagicHandler.
func MustMagicHandler(fn any) http.Handler {
	handler, err := MagicHandler(fn)
	if err != nil {
		panic(fmt.Sprintf("can not be converted to http.Handler: %T: %s", fn, err.Error()))
	}
	return handler
}