feat: log 404 requests with traceid (#1554)
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal"
|
"github.com/zeromicro/go-zero/rest/internal"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
// use 1000m to represent 100%
|
// use 1000m to represent 100%
|
||||||
@@ -154,6 +155,27 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
|
|||||||
return ng.shedder
|
return ng.shedder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||||
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
chain := alice.New(
|
||||||
|
handler.TracingHandler(ng.conf.Name, ""),
|
||||||
|
ng.getLogHandler(),
|
||||||
|
)
|
||||||
|
|
||||||
|
var h http.Handler
|
||||||
|
if next != nil {
|
||||||
|
h = chain.Then(next)
|
||||||
|
} else {
|
||||||
|
h = chain.Then(http.NotFoundHandler())
|
||||||
|
}
|
||||||
|
|
||||||
|
cw := response.NewHeaderOnceResponseWriter(w)
|
||||||
|
h.ServeHTTP(cw, r)
|
||||||
|
cw.WriteHeader(http.StatusNotFound)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (ng *engine) setTlsConfig(cfg *tls.Config) {
|
func (ng *engine) setTlsConfig(cfg *tls.Config) {
|
||||||
ng.tlsConfig = cfg
|
ng.tlsConfig = cfg
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
package rest
|
package rest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/conf"
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewEngine(t *testing.T) {
|
func TestNewEngine(t *testing.T) {
|
||||||
@@ -190,6 +194,75 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEngine_notFoundHandler(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
ng := newEngine(RestConf{})
|
||||||
|
ts := httptest.NewServer(ng.notFoundHandler(nil))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := ts.Client()
|
||||||
|
err := func(ctx context.Context) error {
|
||||||
|
req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
res, err := client.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||||
|
return res.Body.Close()
|
||||||
|
}(context.Background())
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_notFoundHandlerNotNil(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
ng := newEngine(RestConf{})
|
||||||
|
var called int32
|
||||||
|
ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
})))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := ts.Client()
|
||||||
|
err := func(ctx context.Context) error {
|
||||||
|
req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
res, err := client.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusNotFound, res.StatusCode)
|
||||||
|
return res.Body.Close()
|
||||||
|
}(context.Background())
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int32(1), atomic.LoadInt32(&called))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEngine_notFoundHandlerNotNilWriteHeader(t *testing.T) {
|
||||||
|
logx.Disable()
|
||||||
|
|
||||||
|
ng := newEngine(RestConf{})
|
||||||
|
var called int32
|
||||||
|
ts := httptest.NewServer(ng.notFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
w.WriteHeader(http.StatusExpectationFailed)
|
||||||
|
})))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := ts.Client()
|
||||||
|
err := func(ctx context.Context) error {
|
||||||
|
req, err := http.NewRequest("GET", ts.URL+"/bad", nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
res, err := client.Do(req)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, http.StatusExpectationFailed, res.StatusCode)
|
||||||
|
return res.Body.Close()
|
||||||
|
}(context.Background())
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int32(1), atomic.LoadInt32(&called))
|
||||||
|
}
|
||||||
|
|
||||||
type mockedRouter struct{}
|
type mockedRouter struct{}
|
||||||
|
|
||||||
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
func (m mockedRouter) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
"github.com/zeromicro/go-zero/rest/token"
|
"github.com/zeromicro/go-zero/rest/token"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,7 +104,7 @@ func detailAuthLog(r *http.Request, reason string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
|
func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback UnauthorizedCallback) {
|
||||||
writer := newGuardedResponseWriter(w)
|
writer := response.NewHeaderOnceResponseWriter(w)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
detailAuthLog(r, err.Error())
|
detailAuthLog(r, err.Error())
|
||||||
@@ -121,47 +120,3 @@ func unauthorized(w http.ResponseWriter, r *http.Request, err error, callback Un
|
|||||||
// if user not setting HTTP header, we set header with 401
|
// if user not setting HTTP header, we set header with 401
|
||||||
writer.WriteHeader(http.StatusUnauthorized)
|
writer.WriteHeader(http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
type guardedResponseWriter struct {
|
|
||||||
writer http.ResponseWriter
|
|
||||||
wroteHeader bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newGuardedResponseWriter(w http.ResponseWriter) *guardedResponseWriter {
|
|
||||||
return &guardedResponseWriter{
|
|
||||||
writer: w,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (grw *guardedResponseWriter) Flush() {
|
|
||||||
if flusher, ok := grw.writer.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (grw *guardedResponseWriter) Header() http.Header {
|
|
||||||
return grw.writer.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack implements the http.Hijacker interface.
|
|
||||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
|
||||||
func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
if hijacked, ok := grw.writer.(http.Hijacker); ok {
|
|
||||||
return hijacked.Hijack()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil, errors.New("server doesn't support hijacking")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
|
||||||
return grw.writer.Write(body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (grw *guardedResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
if grw.wroteHeader {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
grw.wroteHeader = true
|
|
||||||
grw.writer.WriteHeader(statusCode)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -90,26 +90,6 @@ func TestAuthHandler_NilError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthHandler_Flush(t *testing.T) {
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
handler := newGuardedResponseWriter(resp)
|
|
||||||
handler.Flush()
|
|
||||||
assert.True(t, resp.Flushed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAuthHandler_Hijack(t *testing.T) {
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
writer := newGuardedResponseWriter(resp)
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
writer.Hijack()
|
|
||||||
})
|
|
||||||
|
|
||||||
writer = newGuardedResponseWriter(mockedHijackable{resp})
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
writer.Hijack()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
claims := make(jwt.MapClaims)
|
claims := make(jwt.MapClaims)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/security"
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
const breakerSeparator = "://"
|
const breakerSeparator = "://"
|
||||||
@@ -28,7 +28,7 @@ func BreakerHandler(method, path string, metrics *stat.Metrics) func(http.Handle
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
cw := &response.WithCodeResponseWriter{Writer: w}
|
||||||
defer func() {
|
defer func() {
|
||||||
if cw.Code < http.StatusInternalServerError {
|
if cw.Code < http.StatusInternalServerError {
|
||||||
promise.Accept()
|
promise.Accept()
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/metric"
|
"github.com/zeromicro/go-zero/core/metric"
|
||||||
"github.com/zeromicro/go-zero/core/prometheus"
|
"github.com/zeromicro/go-zero/core/prometheus"
|
||||||
"github.com/zeromicro/go-zero/core/timex"
|
"github.com/zeromicro/go-zero/core/timex"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/security"
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
const serverNamespace = "http_server"
|
const serverNamespace = "http_server"
|
||||||
@@ -41,7 +41,7 @@ func PrometheusHandler(path string) func(http.Handler) http.Handler {
|
|||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
startTime := timex.Now()
|
startTime := timex.Now()
|
||||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
cw := &response.WithCodeResponseWriter{Writer: w}
|
||||||
defer func() {
|
defer func() {
|
||||||
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
|
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), path)
|
||||||
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
|
metricServerReqCodeTotal.Inc(path, strconv.Itoa(cw.Code))
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/internal/security"
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
const serviceType = "api"
|
const serviceType = "api"
|
||||||
@@ -41,7 +41,7 @@ func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Hand
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cw := &security.WithCodeResponseWriter{Writer: w}
|
cw := &response.WithCodeResponseWriter{Writer: w}
|
||||||
defer func() {
|
defer func() {
|
||||||
if cw.Code == http.StatusServiceUnavailable {
|
if cw.Code == http.StatusServiceUnavailable {
|
||||||
promise.Fail()
|
promise.Fail()
|
||||||
|
|||||||
@@ -18,12 +18,16 @@ func TracingHandler(serviceName, path string) func(http.Handler) http.Handler {
|
|||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
|
ctx := propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
|
||||||
|
spanName := path
|
||||||
|
if len(spanName) == 0 {
|
||||||
|
spanName = r.URL.Path
|
||||||
|
}
|
||||||
spanCtx, span := tracer.Start(
|
spanCtx, span := tracer.Start(
|
||||||
ctx,
|
ctx,
|
||||||
path,
|
spanName,
|
||||||
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
|
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
|
||||||
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
|
oteltrace.WithAttributes(semconv.HTTPServerAttributesFromHTTPRequest(
|
||||||
serviceName, path, r)...),
|
serviceName, spanName, r)...),
|
||||||
)
|
)
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
ztrace "github.com/zeromicro/go-zero/core/trace"
|
ztrace "github.com/zeromicro/go-zero/core/trace"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
@@ -21,28 +22,31 @@ func TestOtelHandler(t *testing.T) {
|
|||||||
Sampler: 1.0,
|
Sampler: 1.0,
|
||||||
})
|
})
|
||||||
|
|
||||||
ts := httptest.NewServer(
|
for _, test := range []string{"", "bar"} {
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
t.Run(test, func(t *testing.T) {
|
||||||
ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
|
h := alice.New(TracingHandler("foo", test)).Then(
|
||||||
spanCtx := trace.SpanContextFromContext(ctx)
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
assert.Equal(t, true, spanCtx.IsValid())
|
ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
|
||||||
}),
|
spanCtx := trace.SpanContextFromContext(ctx)
|
||||||
)
|
assert.True(t, spanCtx.IsValid())
|
||||||
defer ts.Close()
|
}))
|
||||||
|
ts := httptest.NewServer(h)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
client := ts.Client()
|
client := ts.Client()
|
||||||
err := func(ctx context.Context) error {
|
err := func(ctx context.Context) error {
|
||||||
ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test")
|
ctx, span := otel.Tracer("httptrace/client").Start(ctx, "test")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", ts.URL, nil)
|
req, _ := http.NewRequest("GET", ts.URL, nil)
|
||||||
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
|
otel.GetTextMapPropagator().Inject(ctx, propagation.HeaderCarrier(req.Header))
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
assert.Equal(t, err, nil)
|
assert.Nil(t, err)
|
||||||
_ = res.Body.Close()
|
return res.Body.Close()
|
||||||
return nil
|
}(context.Background())
|
||||||
}(context.Background())
|
|
||||||
|
|
||||||
assert.Equal(t, err, nil)
|
assert.Nil(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
package cors
|
package cors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/response"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -30,7 +29,7 @@ const (
|
|||||||
// At most one origin can be specified, other origins are ignored if given, default to be *.
|
// At most one origin can be specified, other origins are ignored if given, default to be *.
|
||||||
func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
|
func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
gw := &guardedResponseWriter{w: w}
|
gw := response.NewHeaderOnceResponseWriter(w)
|
||||||
checkAndSetHeaders(gw, r, origins)
|
checkAndSetHeaders(gw, r, origins)
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
fn(gw)
|
fn(gw)
|
||||||
@@ -62,44 +61,6 @@ func Middleware(fn func(w http.Header), origins ...string) func(http.HandlerFunc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type guardedResponseWriter struct {
|
|
||||||
w http.ResponseWriter
|
|
||||||
wroteHeader bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *guardedResponseWriter) Flush() {
|
|
||||||
if flusher, ok := w.w.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *guardedResponseWriter) Header() http.Header {
|
|
||||||
return w.w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack implements the http.Hijacker interface.
|
|
||||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
|
||||||
func (w *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
if hijacked, ok := w.w.(http.Hijacker); ok {
|
|
||||||
return hijacked.Hijack()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil, errors.New("server doesn't support hijacking")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *guardedResponseWriter) Write(bytes []byte) (int, error) {
|
|
||||||
return w.w.Write(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *guardedResponseWriter) WriteHeader(code int) {
|
|
||||||
if w.wroteHeader {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.w.WriteHeader(code)
|
|
||||||
w.wroteHeader = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) {
|
func checkAndSetHeaders(w http.ResponseWriter, r *http.Request, origins []string) {
|
||||||
setVaryHeaders(w, r)
|
setVaryHeaders(w, r)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package cors
|
package cors
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -131,48 +129,3 @@ func TestCorsHandlerWithOrigins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGuardedResponseWriter_Flush(t *testing.T) {
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
|
||||||
handler := NotAllowedHandler(func(w http.ResponseWriter) {
|
|
||||||
w.Header().Set("X-Test", "test")
|
|
||||||
w.WriteHeader(http.StatusServiceUnavailable)
|
|
||||||
_, err := w.Write([]byte("content"))
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
assert.True(t, ok)
|
|
||||||
flusher.Flush()
|
|
||||||
}, "foo.com")
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(resp, req)
|
|
||||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
|
||||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
|
||||||
assert.Equal(t, "content", resp.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGuardedResponseWriter_Hijack(t *testing.T) {
|
|
||||||
resp := httptest.NewRecorder()
|
|
||||||
writer := &guardedResponseWriter{
|
|
||||||
w: resp,
|
|
||||||
}
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
writer.Hijack()
|
|
||||||
})
|
|
||||||
|
|
||||||
writer = &guardedResponseWriter{
|
|
||||||
w: mockedHijackable{resp},
|
|
||||||
}
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
writer.Hijack()
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockedHijackable struct {
|
|
||||||
*httptest.ResponseRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|||||||
57
rest/internal/response/headeronceresponsewriter.go
Normal file
57
rest/internal/response/headeronceresponsewriter.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package response
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderOnceResponseWriter is a http.ResponseWriter implementation
|
||||||
|
// that only the first WriterHeader takes effect.
|
||||||
|
type HeaderOnceResponseWriter struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
wroteHeader bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHeaderOnceResponseWriter returns a HeaderOnceResponseWriter.
|
||||||
|
func NewHeaderOnceResponseWriter(w http.ResponseWriter) http.ResponseWriter {
|
||||||
|
return &HeaderOnceResponseWriter{w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush flushes the response writer.
|
||||||
|
func (w *HeaderOnceResponseWriter) Flush() {
|
||||||
|
if flusher, ok := w.w.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Header returns the http header.
|
||||||
|
func (w *HeaderOnceResponseWriter) Header() http.Header {
|
||||||
|
return w.w.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack implements the http.Hijacker interface.
|
||||||
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
|
func (w *HeaderOnceResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacked, ok := w.w.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes bytes into w.
|
||||||
|
func (w *HeaderOnceResponseWriter) Write(bytes []byte) (int, error) {
|
||||||
|
return w.w.Write(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader writes code into w, and not sealing the writer.
|
||||||
|
func (w *HeaderOnceResponseWriter) WriteHeader(code int) {
|
||||||
|
if w.wroteHeader {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.w.WriteHeader(code)
|
||||||
|
w.wroteHeader = true
|
||||||
|
}
|
||||||
58
rest/internal/response/headeronceresponsewriter_test.go
Normal file
58
rest/internal/response/headeronceresponsewriter_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package response
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHeaderOnceResponseWriter_Flush(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
cw := NewHeaderOnceResponseWriter(w)
|
||||||
|
cw.Header().Set("X-Test", "test")
|
||||||
|
cw.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
cw.WriteHeader(http.StatusExpectationFailed)
|
||||||
|
_, err := cw.Write([]byte("content"))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
flusher, ok := cw.(http.Flusher)
|
||||||
|
assert.True(t, ok)
|
||||||
|
flusher.Flush()
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(resp, req)
|
||||||
|
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||||
|
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||||
|
assert.Equal(t, "content", resp.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeaderOnceResponseWriter_Hijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := &HeaderOnceResponseWriter{
|
||||||
|
w: resp,
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = &HeaderOnceResponseWriter{
|
||||||
|
w: mockedHijackable{resp},
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockedHijackable struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
package security
|
package response
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@@ -27,7 +28,11 @@ func (w *WithCodeResponseWriter) Header() http.Header {
|
|||||||
// Hijack implements the http.Hijacker interface.
|
// Hijack implements the http.Hijacker interface.
|
||||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||||
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
return w.Writer.(http.Hijacker).Hijack()
|
if hijacked, ok := w.Writer.(http.Hijacker); ok {
|
||||||
|
return hijacked.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil, errors.New("server doesn't support hijacking")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write writes bytes into w.
|
// Write writes bytes into w.
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package security
|
package response
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -31,3 +31,20 @@ func TestWithCodeResponseWriter(t *testing.T) {
|
|||||||
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
assert.Equal(t, "test", resp.Header().Get("X-Test"))
|
||||||
assert.Equal(t, "content", resp.Body.String())
|
assert.Equal(t, "content", resp.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithCodeResponseWriter_Hijack(t *testing.T) {
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
writer := &WithCodeResponseWriter{
|
||||||
|
Writer: resp,
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
|
||||||
|
writer = &WithCodeResponseWriter{
|
||||||
|
Writer: mockedHijackable{resp},
|
||||||
|
}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
writer.Hijack()
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -49,6 +49,7 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
|||||||
router: router.NewRouter(),
|
router: router.NewRouter(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(server)
|
opt(server)
|
||||||
}
|
}
|
||||||
@@ -163,7 +164,8 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
|
|||||||
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
|
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
|
||||||
func WithNotFoundHandler(handler http.Handler) RunOption {
|
func WithNotFoundHandler(handler http.Handler) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
server.router.SetNotFoundHandler(handler)
|
notFoundHandler := server.ngin.notFoundHandler(handler)
|
||||||
|
server.router.SetNotFoundHandler(notFoundHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user