sync from project

This commit is contained in:
2025-06-18 10:12:19 +08:00
parent 61ffeeb3b8
commit fb579e8689
20 changed files with 1332 additions and 103 deletions

View File

@ -0,0 +1,61 @@
package cleanup
import (
"context"
"log/slog"
"net/http"
"reflect"
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
)
type ctxKey int
const (
cleanupCtxKey ctxKey = iota
)
type CleanupContext struct {
funcs []Cleanup
}
type Cleanup interface {
Name() string
Cleanup()
}
type CleanupFunc func()
func (fn CleanupFunc) Name() string {
return reflect.TypeOf(fn).String()
}
func (fn CleanupFunc) Cleanup() {
defer func() {
if v := recover(); v != nil {
slog.With("v", v).Warn("cleanup panicked")
}
}()
fn()
}
func Collect() middleware.Middleware {
return middleware.WrapFunc(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := &CleanupContext{}
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cleanupCtxKey, ctx)))
for _, cleanupFunc := range ctx.funcs {
defer cleanupFunc.Cleanup()
}
})
})
}
// Register adds a Cleanup function to the CleanupContext in the provided context.
// If the CleanupContext is not found in the context, the Cleanup function is not registered.
func Register(ctx context.Context, c Cleanup) {
if ctx, ok := ctx.Value(cleanupCtxKey).(*CleanupContext); ok {
ctx.funcs = append(ctx.funcs, c)
}
}

View File

@ -0,0 +1,54 @@
package cleanup
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCleanupMiddleware(t *testing.T) {
var executionOrder []string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Register(r.Context(), CleanupFunc(func() {
executionOrder = append(executionOrder, "first")
}))
Register(r.Context(), CleanupFunc(func() {
executionOrder = append(executionOrder, "second")
}))
w.WriteHeader(http.StatusOK)
})
req := httptest.NewRequest("GET", "/", nil)
rr := httptest.NewRecorder()
middleware := Collect().WrapHandler(handler)
middleware.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
rr.Code, http.StatusOK)
}
expectedOrder := "second,first"
actualOrder := strings.Join(executionOrder, ",")
if actualOrder != expectedOrder {
t.Errorf("cleanup functions executed in wrong order: got %s want %s",
actualOrder, expectedOrder)
}
}
func TestCleanupMissingContext(t *testing.T) {
// This test ensures that Register does not panic when the context is missing.
// The function should fail silently.
defer func() {
if r := recover(); r != nil {
t.Errorf("The code panicked when it should not have")
}
}()
req := httptest.NewRequest("GET", "/", nil)
// No middleware, so no context
Register(req.Context(), CleanupFunc(func() {}))
}

View File

@ -75,6 +75,10 @@ func (log Log) WrapHandler(next http.Handler) http.Handler {
})
}
// Logger retrieves the slog.Logger from the request context.
// If the logger is not found in the context (e.g., the httplog middleware is not used),
// it creates and returns a new logger with basic request information.
// Using the httplog middleware is recommended to ensure the configured logger is available.
func Logger(r *http.Request) *slog.Logger {
if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok {
return logger.With("time", time.Now())

View File

@ -0,0 +1,85 @@
package httplog
import (
"bufio"
"bytes"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"testing"
)
type mockHijacker struct {
*httptest.ResponseRecorder
hijacked bool
}
func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
m.hijacked = true
// Return dummy values, not used in this test
return nil, nil, nil
}
func TestResponseRecorder_Hijack(t *testing.T) {
recorder := &responseRecorder{
ResponseWriter: &mockHijacker{ResponseRecorder: httptest.NewRecorder()},
}
_, _, err := recorder.Hijack()
if err != nil {
t.Fatalf("Hijack failed: %v", err)
}
if recorder.ResponseWriter != nil {
t.Error("ResponseWriter should be nil after Hijack")
}
}
func TestLogger(t *testing.T) {
var buf bytes.Buffer
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil)))
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger := Logger(r)
logger.Info("test message")
w.WriteHeader(http.StatusOK)
})
t.Run("with middleware", func(t *testing.T) {
buf.Reset()
logMiddleware := Log{LogStart: true, LogFinish: true}
wrappedHandler := logMiddleware.WrapHandler(handler)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
t.Error("expected start log message, but not found")
}
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
t.Error("expected handler log message, but not found")
}
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
t.Error("expected finish log message, but not found")
}
})
t.Run("without middleware", func(t *testing.T) {
buf.Reset()
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
t.Error("expected handler log message, but not found")
}
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
t.Error("unexpected start log message found")
}
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
t.Error("unexpected finish log message found")
}
})
}

View File

@ -42,3 +42,20 @@ func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
func Debug() middleware.Middleware {
return Recover(DebugPanicHandler)
}
// ProductionPanicHandler is a PanicResponseFunc that provides a generic error message
// in production environments and logs the detailed panic error.
func ProductionPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
httplog.Logger(r).With("panic", err).Error("request panicked")
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]any{
"error": "Internal Server Error",
})
}
// Production returns a middleware that recovers from panics and handles them
// with ProductionPanicHandler.
func Production() middleware.Middleware {
return Recover(ProductionPanicHandler)
}

View File

@ -0,0 +1,56 @@
package recover
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func panicHandler(w http.ResponseWriter, r *http.Request) {
panic("test panic")
}
func TestDebugRecover(t *testing.T) {
handler := http.HandlerFunc(panicHandler)
wrappedHandler := Debug().WrapHandler(handler)
req := httptest.NewRequest("GET", "/panic", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
}
var resp map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response body: %v", err)
}
if resp["error"] != "test panic" {
t.Errorf("expected error 'test panic', got '%v'", resp["error"])
}
}
func TestProductionRecover(t *testing.T) {
handler := http.HandlerFunc(panicHandler)
wrappedHandler := Production().WrapHandler(handler)
req := httptest.NewRequest("GET", "/panic", nil)
rr := httptest.NewRecorder()
wrappedHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusInternalServerError {
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
}
var resp map[string]any
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response body: %v", err)
}
if resp["error"] != "Internal Server Error" {
t.Errorf("expected error 'Internal Server Error', got '%v'", resp["error"])
}
}