From 842656aa901a3f208462eb4a7a705286199aea6b Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 19 Feb 2022 20:50:33 +0800 Subject: [PATCH] feat: log 404 requests with traceid (#1554) --- rest/engine.go | 22 ++++++ rest/engine_test.go | 73 +++++++++++++++++++ rest/handler/authhandler.go | 49 +------------ rest/handler/authhandler_test.go | 20 ----- rest/handler/breakerhandler.go | 4 +- rest/handler/prometheushandler.go | 4 +- rest/handler/sheddinghandler.go | 4 +- rest/handler/tracinghandler.go | 8 +- rest/handler/tracinghandler_test.go | 44 ++++++----- rest/internal/cors/handlers.go | 45 +----------- rest/internal/cors/handlers_test.go | 47 ------------ .../response/headeronceresponsewriter.go | 57 +++++++++++++++ .../response/headeronceresponsewriter_test.go | 58 +++++++++++++++ .../withcoderesponsewriter.go | 9 ++- .../withcoderesponsewriter_test.go | 19 ++++- rest/server.go | 4 +- 16 files changed, 279 insertions(+), 188 deletions(-) create mode 100644 rest/internal/response/headeronceresponsewriter.go create mode 100644 rest/internal/response/headeronceresponsewriter_test.go rename rest/internal/{security => response}/withcoderesponsewriter.go (85%) rename rest/internal/{security => response}/withcoderesponsewriter_test.go (71%) diff --git a/rest/engine.go b/rest/engine.go index 8ba13378..46e92f0c 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -14,6 +14,7 @@ import ( "github.com/zeromicro/go-zero/rest/handler" "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/internal" + "github.com/zeromicro/go-zero/rest/internal/response" ) // use 1000m to represent 100% @@ -154,6 +155,27 @@ func (ng *engine) getShedder(priority bool) load.Shedder { return ng.shedder } +// notFoundHandler returns a middleware that handles 404 not found requests. +func (ng *engine) notFoundHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + chain := alice.New( + handler.TracingHandler(ng.conf.Name, ""), + ng.getLogHandler(), + ) + + var h http.Handler + if next != nil { + h = chain.Then(next) + } else { + h = chain.Then(http.NotFoundHandler()) + } + + cw := response.NewHeaderOnceResponseWriter(w) + h.ServeHTTP(cw, r) + cw.WriteHeader(http.StatusNotFound) + }) +} + func (ng *engine) setTlsConfig(cfg *tls.Config) { ng.tlsConfig = cfg } diff --git a/rest/engine_test.go b/rest/engine_test.go index 799309e6..ca540149 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -1,13 +1,17 @@ package rest import ( + "context" "errors" "net/http" + "net/http/httptest" + "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/core/logx" ) func TestNewEngine(t *testing.T) { @@ -190,6 +194,75 @@ func TestEngine_checkedTimeout(t *testing.T) { } } +func TestEngine_notFoundHandler(t *testing.T) { + logx.Disable() + + ng := newEngine(RestConf{}) + ts := httptest.NewServer(ng.notFoundHandler(nil)) + defer ts.Close() + + client := ts.Client() + err := func(ctx context.Context) error { + req, err := http.NewRequest("GET", ts.URL+"/bad", nil) + assert.Nil(t, err) + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusNotFound, res.StatusCode) + return res.Body.Close() + }(context.Background()) + + assert.Nil(t, err) +} + +func TestEngine_notFoundHandlerNotNil(t *testing.T) { + logx.Disable() + + ng := newEngine(RestConf{}) + var called int32 + ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + }))) + defer ts.Close() + + client := ts.Client() + err := func(ctx context.Context) error { + req, err := http.NewRequest("GET", ts.URL+"/bad", nil) + assert.Nil(t, err) + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusNotFound, res.StatusCode) + return res.Body.Close() + }(context.Background()) + + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&called)) +} + +func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) { + logx.Disable() + + ng := newEngine(RestConf{}) + var called int32 + ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + w.WriteHeader(http.StatusExpectationFailed) + }))) + defer ts.Close() + + client := ts.Client() + err := func(ctx context.Context) error { + req, err := http.NewRequest("GET", ts.URL+"/bad", nil) + assert.Nil(t, err) + res, err := client.Do(req) + assert.Nil(t, err) + assert.Equal(t, http.StatusExpectationFailed, res.StatusCode) + return res.Body.Close() + }(context.Background()) + + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&called)) +} + type mockedRouter struct{} func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) { diff --git a/rest/handler/authhandler.go b/rest/handler/authhandler.go index 26aacca0..9099e62b 100644 --- a/rest/handler/authhandler.go +++ b/rest/handler/authhandler.go @@ -1,15 +1,14 @@ package handler import ( - "bufio" "context" "errors" - "net" "net/http" "net/http/httputil" "github.com/golang-jwt/jwt/v4" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/rest/internal/response" "github.com/zeromicro/go-zero/rest/token" ) @@ -105,7 +104,7 @@ func detailAuthLog(r *http.Request, reason string) { } func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) { - writer := newGuardedResponseWriter(w) + writer := response.NewHeaderOnceResponseWriter(w) if err != nil { detailAuthLog(r, err.Error()) @@ -121,47 +120,3 @@ func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback Un // if user not setting HTTP header, we set header with 401 writer.WriteHeader(http.StatusUnauthorized) } - -type guardedResponseWriter struct { - writer http.ResponseWriter - wroteHeader bool -} - -func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter { - return &guardedResponseWriter{ - writer: w, - } -} - -func (grw *guardedResponseWriter) Flush() { - if flusher, ok := grw.writer.(http.Flusher); ok { - flusher.Flush() - } -} - -func (grw *guardedResponseWriter) Header() http.Header { - return grw.writer.Header() -} - -// Hijack implements the http.Hijacker interface. -// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. -func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hijacked, ok := grw.writer.(http.Hijacker); ok { - return hijacked.Hijack() - } - - return nil, nil, errors.New("server doesn't support hijacking") -} - -func (grw *guardedResponseWriter) Write(body []byte) (int, error) { - return grw.writer.Write(body) -} - -func (grw *guardedResponseWriter) WriteHeader(statusCode int) { - if grw.wroteHeader { - return - } - - grw.wroteHeader = true - grw.writer.WriteHeader(statusCode) -} diff --git a/rest/handler/authhandler_test.go b/rest/handler/authhandler_test.go index 1587b35f..e9496e10 100644 --- a/rest/handler/authhandler_test.go +++ b/rest/handler/authhandler_test.go @@ -90,26 +90,6 @@ func TestAuthHandler_NilError(t *testing.T) { }) } -func TestAuthHandler_Flush(t *testing.T) { - resp := httptest.NewRecorder() - handler := newGuardedResponseWriter(resp) - handler.Flush() - assert.True(t, resp.Flushed) -} - -func TestAuthHandler_Hijack(t *testing.T) { - resp := httptest.NewRecorder() - writer := newGuardedResponseWriter(resp) - assert.NotPanics(t, func() { - writer.Hijack() - }) - - writer = newGuardedResponseWriter(mockedHijackable{resp}) - assert.NotPanics(t, func() { - writer.Hijack() - }) -} - func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) { now := time.Now().Unix() claims := make(jwt.MapClaims) diff --git a/rest/handler/breakerhandler.go b/rest/handler/breakerhandler.go index 516aba76..e9d7243a 100644 --- a/rest/handler/breakerhandler.go +++ b/rest/handler/breakerhandler.go @@ -9,7 +9,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/rest/httpx" - "github.com/zeromicro/go-zero/rest/internal/security" + "github.com/zeromicro/go-zero/rest/internal/response" ) const breakerSeparator = "://" @@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle return } - cw := &security.WithCodeResponseWriter{Writer: w} + cw := &response.WithCodeResponseWriter{Writer: w} defer func() { if cw.Code < http.StatusInternalServerError { promise.Accept() diff --git a/rest/handler/prometheushandler.go b/rest/handler/prometheushandler.go index a1a286a0..639d59c0 100644 --- a/rest/handler/prometheushandler.go +++ b/rest/handler/prometheushandler.go @@ -8,7 +8,7 @@ import ( "github.com/zeromicro/go-zero/core/metric" "github.com/zeromicro/go-zero/core/prometheus" "github.com/zeromicro/go-zero/core/timex" - "github.com/zeromicro/go-zero/rest/internal/security" + "github.com/zeromicro/go-zero/rest/internal/response" ) const serverNamespace = "http_server" @@ -41,7 +41,7 @@ func PrometheusHandler(path string) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { startTime := timex.Now() - cw := &security.WithCodeResponseWriter{Writer: w} + cw := &response.WithCodeResponseWriter{Writer: w} defer func() { metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path) metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code)) diff --git a/rest/handler/sheddinghandler.go b/rest/handler/sheddinghandler.go index ffbee34f..977824dc 100644 --- a/rest/handler/sheddinghandler.go +++ b/rest/handler/sheddinghandler.go @@ -8,7 +8,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stat" "github.com/zeromicro/go-zero/rest/httpx" - "github.com/zeromicro/go-zero/rest/internal/security" + "github.com/zeromicro/go-zero/rest/internal/response" ) const serviceType = "api" @@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand return } - cw := &security.WithCodeResponseWriter{Writer: w} + cw := &response.WithCodeResponseWriter{Writer: w} defer func() { if cw.Code == http.StatusServiceUnavailable { promise.Fail() diff --git a/rest/handler/tracinghandler.go b/rest/handler/tracinghandler.go index 0a8cba9e..0b1cdf7e 100644 --- a/rest/handler/tracinghandler.go +++ b/rest/handler/tracinghandler.go @@ -18,12 +18,16 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + spanName := path + if len(spanName) == 0 { + spanName = r.URL.Path + } spanCtx, span := tracer.Start( ctx, - path, + spanName, oteltrace.WithSpanKind(oteltrace.SpanKindServer), oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest( - serviceName, path, r)...), + serviceName, spanName, r)...), ) defer span.End() diff --git a/rest/handler/tracinghandler_test.go b/rest/handler/tracinghandler_test.go index b25edaea..94af55e0 100644 --- a/rest/handler/tracinghandler_test.go +++ b/rest/handler/tracinghandler_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + "github.com/justinas/alice" "github.com/stretchr/testify/assert" ztrace "github.com/zeromicro/go-zero/core/trace" "go.opentelemetry.io/otel" @@ -21,28 +22,31 @@ func TestOtelHandler(t *testing.T) { Sampler: 1.0, }) - ts := httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) - spanCtx := trace.SpanContextFromContext(ctx) - assert.Equal(t, true, spanCtx.IsValid()) - }), - ) - defer ts.Close() + for _, test := range []string{"", "bar"} { + t.Run(test, func(t *testing.T) { + h := alice.New(TracingHandler("foo", test)).Then( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + spanCtx := trace.SpanContextFromContext(ctx) + assert.True(t, spanCtx.IsValid()) + })) + ts := httptest.NewServer(h) + defer ts.Close() - client := ts.Client() - err := func(ctx context.Context) error { - ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test") - defer span.End() + client := ts.Client() + err := func(ctx context.Context) error { + ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test") + defer span.End() - req, _ := http.NewRequest("GET", ts.URL, nil) - otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header)) + req, _ := http.NewRequest("GET", ts.URL, nil) + otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header)) - res, err := client.Do(req) - assert.Equal(t, err, nil) - _ = res.Body.Close() - return nil - }(context.Background()) + res, err := client.Do(req) + assert.Nil(t, err) + return res.Body.Close() + }(context.Background()) - assert.Equal(t, err, nil) + assert.Nil(t, err) + }) + } } diff --git a/rest/internal/cors/handlers.go b/rest/internal/cors/handlers.go index c613b27d..7bb3f077 100644 --- a/rest/internal/cors/handlers.go +++ b/rest/internal/cors/handlers.go @@ -1,10 +1,9 @@ package cors import ( - "bufio" - "errors" - "net" "net/http" + + "github.com/zeromicro/go-zero/rest/internal/response" ) const ( @@ -30,7 +29,7 @@ const ( // At most one origin can be specified, other origins are ignored if given, default to be *. func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gw := &guardedResponseWriter{w: w} + gw := response.NewHeaderOnceResponseWriter(w) checkAndSetHeaders(gw, r, origins) if fn != nil { fn(gw) @@ -62,44 +61,6 @@ func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc } } -type guardedResponseWriter struct { - w http.ResponseWriter - wroteHeader bool -} - -func (w *guardedResponseWriter) Flush() { - if flusher, ok := w.w.(http.Flusher); ok { - flusher.Flush() - } -} - -func (w *guardedResponseWriter) Header() http.Header { - return w.w.Header() -} - -// Hijack implements the http.Hijacker interface. -// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. -func (w *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hijacked, ok := w.w.(http.Hijacker); ok { - return hijacked.Hijack() - } - - return nil, nil, errors.New("server doesn't support hijacking") -} - -func (w *guardedResponseWriter) Write(bytes []byte) (int, error) { - return w.w.Write(bytes) -} - -func (w *guardedResponseWriter) WriteHeader(code int) { - if w.wroteHeader { - return - } - - w.w.WriteHeader(code) - w.wroteHeader = true -} - func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) { setVaryHeaders(w, r) diff --git a/rest/internal/cors/handlers_test.go b/rest/internal/cors/handlers_test.go index 047fdb98..dea7bb4e 100644 --- a/rest/internal/cors/handlers_test.go +++ b/rest/internal/cors/handlers_test.go @@ -1,8 +1,6 @@ package cors import ( - "bufio" - "net" "net/http" "net/http/httptest" "testing" @@ -131,48 +129,3 @@ func TestCorsHandlerWithOrigins(t *testing.T) { } } } - -func TestGuardedResponseWriter_Flush(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - handler := NotAllowedHandler(func(w http.ResponseWriter) { - w.Header().Set("X-Test", "test") - w.WriteHeader(http.StatusServiceUnavailable) - _, err := w.Write([]byte("content")) - assert.Nil(t, err) - - flusher, ok := w.(http.Flusher) - assert.True(t, ok) - flusher.Flush() - }, "foo.com") - - resp := httptest.NewRecorder() - handler.ServeHTTP(resp, req) - assert.Equal(t, http.StatusServiceUnavailable, resp.Code) - assert.Equal(t, "test", resp.Header().Get("X-Test")) - assert.Equal(t, "content", resp.Body.String()) -} - -func TestGuardedResponseWriter_Hijack(t *testing.T) { - resp := httptest.NewRecorder() - writer := &guardedResponseWriter{ - w: resp, - } - assert.NotPanics(t, func() { - writer.Hijack() - }) - - writer = &guardedResponseWriter{ - w: mockedHijackable{resp}, - } - assert.NotPanics(t, func() { - writer.Hijack() - }) -} - -type mockedHijackable struct { - *httptest.ResponseRecorder -} - -func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return nil, nil, nil -} diff --git a/rest/internal/response/headeronceresponsewriter.go b/rest/internal/response/headeronceresponsewriter.go new file mode 100644 index 00000000..62ba4442 --- /dev/null +++ b/rest/internal/response/headeronceresponsewriter.go @@ -0,0 +1,57 @@ +package response + +import ( + "bufio" + "errors" + "net" + "net/http" +) + +// HeaderOnceResponseWriter is a http.ResponseWriter implementation +// that only the first WriterHeader takes effect. +type HeaderOnceResponseWriter struct { + w http.ResponseWriter + wroteHeader bool +} + +// NewHeaderOnceResponseWriter returns a HeaderOnceResponseWriter. +func NewHeaderOnceResponseWriter(w http.ResponseWriter) http.ResponseWriter { + return &HeaderOnceResponseWriter{w: w} +} + +// Flush flushes the response writer. +func (w *HeaderOnceResponseWriter) Flush() { + if flusher, ok := w.w.(http.Flusher); ok { + flusher.Flush() + } +} + +// Header returns the http header. +func (w *HeaderOnceResponseWriter) Header() http.Header { + return w.w.Header() +} + +// Hijack implements the http.Hijacker interface. +// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. +func (w *HeaderOnceResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacked, ok := w.w.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") +} + +// Write writes bytes into w. +func (w *HeaderOnceResponseWriter) Write(bytes []byte) (int, error) { + return w.w.Write(bytes) +} + +// WriteHeader writes code into w, and not sealing the writer. +func (w *HeaderOnceResponseWriter) WriteHeader(code int) { + if w.wroteHeader { + return + } + + w.w.WriteHeader(code) + w.wroteHeader = true +} diff --git a/rest/internal/response/headeronceresponsewriter_test.go b/rest/internal/response/headeronceresponsewriter_test.go new file mode 100644 index 00000000..cbc8fe1d --- /dev/null +++ b/rest/internal/response/headeronceresponsewriter_test.go @@ -0,0 +1,58 @@ +package response + +import ( + "bufio" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHeaderOnceResponseWriter_Flush(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cw := NewHeaderOnceResponseWriter(w) + cw.Header().Set("X-Test", "test") + cw.WriteHeader(http.StatusServiceUnavailable) + cw.WriteHeader(http.StatusExpectationFailed) + _, err := cw.Write([]byte("content")) + assert.Nil(t, err) + + flusher, ok := cw.(http.Flusher) + assert.True(t, ok) + flusher.Flush() + }) + + resp := httptest.NewRecorder() + handler.ServeHTTP(resp, req) + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Equal(t, "test", resp.Header().Get("X-Test")) + assert.Equal(t, "content", resp.Body.String()) +} + +func TestHeaderOnceResponseWriter_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := &HeaderOnceResponseWriter{ + w: resp, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = &HeaderOnceResponseWriter{ + w: mockedHijackable{resp}, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) +} + +type mockedHijackable struct { + *httptest.ResponseRecorder +} + +func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +} diff --git a/rest/internal/security/withcoderesponsewriter.go b/rest/internal/response/withcoderesponsewriter.go similarity index 85% rename from rest/internal/security/withcoderesponsewriter.go rename to rest/internal/response/withcoderesponsewriter.go index 2496e301..4fa9631e 100644 --- a/rest/internal/security/withcoderesponsewriter.go +++ b/rest/internal/response/withcoderesponsewriter.go @@ -1,7 +1,8 @@ -package security +package response import ( "bufio" + "errors" "net" "net/http" ) @@ -27,7 +28,11 @@ func (w *WithCodeResponseWriter) Header() http.Header { // Hijack implements the http.Hijacker interface. // This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it. func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.Writer.(http.Hijacker).Hijack() + if hijacked, ok := w.Writer.(http.Hijacker); ok { + return hijacked.Hijack() + } + + return nil, nil, errors.New("server doesn't support hijacking") } // Write writes bytes into w. diff --git a/rest/internal/security/withcoderesponsewriter_test.go b/rest/internal/response/withcoderesponsewriter_test.go similarity index 71% rename from rest/internal/security/withcoderesponsewriter_test.go rename to rest/internal/response/withcoderesponsewriter_test.go index 3a627798..a4fcfae6 100644 --- a/rest/internal/security/withcoderesponsewriter_test.go +++ b/rest/internal/response/withcoderesponsewriter_test.go @@ -1,4 +1,4 @@ -package security +package response import ( "net/http" @@ -31,3 +31,20 @@ func TestWithCodeResponseWriter(t *testing.T) { assert.Equal(t, "test", resp.Header().Get("X-Test")) assert.Equal(t, "content", resp.Body.String()) } + +func TestWithCodeResponseWriter_Hijack(t *testing.T) { + resp := httptest.NewRecorder() + writer := &WithCodeResponseWriter{ + Writer: resp, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) + + writer = &WithCodeResponseWriter{ + Writer: mockedHijackable{resp}, + } + assert.NotPanics(t, func() { + writer.Hijack() + }) +} diff --git a/rest/server.go b/rest/server.go index 391ee75c..0c053ceb 100644 --- a/rest/server.go +++ b/rest/server.go @@ -49,6 +49,7 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) { router: router.NewRouter(), } + opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) for _, opt := range opts { opt(server) } @@ -163,7 +164,8 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route { // WithNotFoundHandler returns a RunOption with not found handler set to given handler. func WithNotFoundHandler(handler http.Handler) RunOption { return func(server *Server) { - server.router.SetNotFoundHandler(handler) + notFoundHandler := server.ngin.notFoundHandler(handler) + server.router.SetNotFoundHandler(notFoundHandler) } }