diff --git a/rest/server.go b/rest/server.go index 0d0bad13..4c7a30ff 100644 --- a/rest/server.go +++ b/rest/server.go @@ -126,7 +126,7 @@ func WithChain(chn chain.Chain) RunOption { func WithCors(origin ...string) RunOption { return func(server *Server) { server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...)) - server.Use(cors.Middleware(nil, origin...)) + server.router = newCorsRouter(server.router, nil, origin...) } } @@ -136,7 +136,7 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt origin ...string) RunOption { return func(server *Server) { server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...)) - server.Use(cors.Middleware(middlewareFn, origin...)) + server.router = newCorsRouter(server.router, middlewareFn, origin...) } } @@ -291,3 +291,19 @@ func validateSecret(secret string) { panic("secret's length can't be less than 8") } } + +type corsRouter struct { + httpx.Router + middleware Middleware +} + +func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router { + return &corsRouter{ + Router: router, + middleware: cors.Middleware(headerFn, origins...), + } +} + +func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.middleware(c.Router.ServeHTTP)(w, r) +} diff --git a/rest/server_test.go b/rest/server_test.go index bc99f125..f356ff4a 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/httpx" + "github.com/zeromicro/go-zero/rest/internal/cors" "github.com/zeromicro/go-zero/rest/router" ) @@ -515,3 +516,23 @@ func TestServer_WithChain(t *testing.T) { rt.ServeHTTP(httptest.NewRecorder(), req) assert.Equal(t, int32(5), atomic.LoadInt32(&called)) } + +func TestServer_WithCors(t *testing.T) { + var called int32 + middleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + }) + } + r := router.NewRouter() + assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler()))) + + cr := &corsRouter{ + Router: r, + middleware: cors.Middleware(nil, "*"), + } + req := httptest.NewRequest(http.MethodOptions, "/", nil) + cr.ServeHTTP(httptest.NewRecorder(), req) + assert.Equal(t, int32(0), atomic.LoadInt32(&called)) +}