sync from project
This commit is contained in:
61
middleware/cleanup/cleanup.go
Normal file
61
middleware/cleanup/cleanup.go
Normal file
@ -0,0 +1,61 @@
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"git.jeffthecoder.xyz/public/lazyhandler/middleware"
|
||||
)
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
cleanupCtxKey ctxKey = iota
|
||||
)
|
||||
|
||||
type CleanupContext struct {
|
||||
funcs []Cleanup
|
||||
}
|
||||
|
||||
type Cleanup interface {
|
||||
Name() string
|
||||
Cleanup()
|
||||
}
|
||||
|
||||
type CleanupFunc func()
|
||||
|
||||
func (fn CleanupFunc) Name() string {
|
||||
return reflect.TypeOf(fn).String()
|
||||
}
|
||||
|
||||
func (fn CleanupFunc) Cleanup() {
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
slog.With("v", v).Warn("cleanup panicked")
|
||||
}
|
||||
}()
|
||||
|
||||
fn()
|
||||
}
|
||||
|
||||
func Collect() middleware.Middleware {
|
||||
return middleware.WrapFunc(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := &CleanupContext{}
|
||||
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cleanupCtxKey, ctx)))
|
||||
for _, cleanupFunc := range ctx.funcs {
|
||||
defer cleanupFunc.Cleanup()
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Register adds a Cleanup function to the CleanupContext in the provided context.
|
||||
// If the CleanupContext is not found in the context, the Cleanup function is not registered.
|
||||
func Register(ctx context.Context, c Cleanup) {
|
||||
if ctx, ok := ctx.Value(cleanupCtxKey).(*CleanupContext); ok {
|
||||
ctx.funcs = append(ctx.funcs, c)
|
||||
}
|
||||
}
|
54
middleware/cleanup/cleanup_test.go
Normal file
54
middleware/cleanup/cleanup_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCleanupMiddleware(t *testing.T) {
|
||||
var executionOrder []string
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Register(r.Context(), CleanupFunc(func() {
|
||||
executionOrder = append(executionOrder, "first")
|
||||
}))
|
||||
Register(r.Context(), CleanupFunc(func() {
|
||||
executionOrder = append(executionOrder, "second")
|
||||
}))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
middleware := Collect().WrapHandler(handler)
|
||||
middleware.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
rr.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
expectedOrder := "second,first"
|
||||
actualOrder := strings.Join(executionOrder, ",")
|
||||
if actualOrder != expectedOrder {
|
||||
t.Errorf("cleanup functions executed in wrong order: got %s want %s",
|
||||
actualOrder, expectedOrder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupMissingContext(t *testing.T) {
|
||||
// This test ensures that Register does not panic when the context is missing.
|
||||
// The function should fail silently.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("The code panicked when it should not have")
|
||||
}
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
// No middleware, so no context
|
||||
Register(req.Context(), CleanupFunc(func() {}))
|
||||
}
|
@ -75,6 +75,10 @@ func (log Log) WrapHandler(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// Logger retrieves the slog.Logger from the request context.
|
||||
// If the logger is not found in the context (e.g., the httplog middleware is not used),
|
||||
// it creates and returns a new logger with basic request information.
|
||||
// Using the httplog middleware is recommended to ensure the configured logger is available.
|
||||
func Logger(r *http.Request) *slog.Logger {
|
||||
if logger, ok := r.Context().Value(loggerKey).(*slog.Logger); ok {
|
||||
return logger.With("time", time.Now())
|
||||
|
85
middleware/httplog/log_test.go
Normal file
85
middleware/httplog/log_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package httplog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockHijacker struct {
|
||||
*httptest.ResponseRecorder
|
||||
hijacked bool
|
||||
}
|
||||
|
||||
func (m *mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
m.hijacked = true
|
||||
// Return dummy values, not used in this test
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func TestResponseRecorder_Hijack(t *testing.T) {
|
||||
recorder := &responseRecorder{
|
||||
ResponseWriter: &mockHijacker{ResponseRecorder: httptest.NewRecorder()},
|
||||
}
|
||||
|
||||
_, _, err := recorder.Hijack()
|
||||
if err != nil {
|
||||
t.Fatalf("Hijack failed: %v", err)
|
||||
}
|
||||
|
||||
if recorder.ResponseWriter != nil {
|
||||
t.Error("ResponseWriter should be nil after Hijack")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil)))
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
logger := Logger(r)
|
||||
logger.Info("test message")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("with middleware", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
logMiddleware := Log{LogStart: true, LogFinish: true}
|
||||
wrappedHandler := logMiddleware.WrapHandler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
|
||||
t.Error("expected start log message, but not found")
|
||||
}
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
|
||||
t.Error("expected handler log message, but not found")
|
||||
}
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
|
||||
t.Error("expected finish log message, but not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("without middleware", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !bytes.Contains(buf.Bytes(), []byte("level=INFO msg=\"test message\"")) {
|
||||
t.Error("expected handler log message, but not found")
|
||||
}
|
||||
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=request")) {
|
||||
t.Error("unexpected start log message found")
|
||||
}
|
||||
if bytes.Contains(buf.Bytes(), []byte("level=INFO msg=response")) {
|
||||
t.Error("unexpected finish log message found")
|
||||
}
|
||||
})
|
||||
}
|
@ -42,3 +42,20 @@ func DebugPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
|
||||
func Debug() middleware.Middleware {
|
||||
return Recover(DebugPanicHandler)
|
||||
}
|
||||
|
||||
// ProductionPanicHandler is a PanicResponseFunc that provides a generic error message
|
||||
// in production environments and logs the detailed panic error.
|
||||
func ProductionPanicHandler(w http.ResponseWriter, r *http.Request, err any) {
|
||||
httplog.Logger(r).With("panic", err).Error("request panicked")
|
||||
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": "Internal Server Error",
|
||||
})
|
||||
}
|
||||
|
||||
// Production returns a middleware that recovers from panics and handles them
|
||||
// with ProductionPanicHandler.
|
||||
func Production() middleware.Middleware {
|
||||
return Recover(ProductionPanicHandler)
|
||||
}
|
||||
|
56
middleware/recover/recover_test.go
Normal file
56
middleware/recover/recover_test.go
Normal file
@ -0,0 +1,56 @@
|
||||
package recover
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func panicHandler(w http.ResponseWriter, r *http.Request) {
|
||||
panic("test panic")
|
||||
}
|
||||
|
||||
func TestDebugRecover(t *testing.T) {
|
||||
handler := http.HandlerFunc(panicHandler)
|
||||
wrappedHandler := Debug().WrapHandler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response body: %v", err)
|
||||
}
|
||||
|
||||
if resp["error"] != "test panic" {
|
||||
t.Errorf("expected error 'test panic', got '%v'", resp["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProductionRecover(t *testing.T) {
|
||||
handler := http.HandlerFunc(panicHandler)
|
||||
wrappedHandler := Production().WrapHandler(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/panic", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected status %d, got %d", http.StatusInternalServerError, rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response body: %v", err)
|
||||
}
|
||||
|
||||
if resp["error"] != "Internal Server Error" {
|
||||
t.Errorf("expected error 'Internal Server Error', got '%v'", resp["error"])
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user