@@ -126,7 +126,7 @@ func WithChain(chn chain.Chain) RunOption {
|
|||||||
func WithCors(origin ...string) RunOption {
|
func WithCors(origin ...string) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...))
|
||||||
server.Use(cors.Middleware(nil, origin...))
|
server.router = newCorsRouter(server.router, nil, origin...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt
|
|||||||
origin ...string) RunOption {
|
origin ...string) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
|
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...))
|
||||||
server.Use(cors.Middleware(middlewareFn, origin...))
|
server.router = newCorsRouter(server.router, middlewareFn, origin...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,3 +291,19 @@ func validateSecret(secret string) {
|
|||||||
panic("secret's length can't be less than 8")
|
panic("secret's length can't be less than 8")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type corsRouter struct {
|
||||||
|
httpx.Router
|
||||||
|
middleware Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router {
|
||||||
|
return &corsRouter{
|
||||||
|
Router: router,
|
||||||
|
middleware: cors.Middleware(headerFn, origins...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
c.middleware(c.Router.ServeHTTP)(w, r)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/rest/chain"
|
"github.com/zeromicro/go-zero/rest/chain"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
|
"github.com/zeromicro/go-zero/rest/internal/cors"
|
||||||
"github.com/zeromicro/go-zero/rest/router"
|
"github.com/zeromicro/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -515,3 +516,23 @@ func TestServer_WithChain(t *testing.T) {
|
|||||||
rt.ServeHTTP(httptest.NewRecorder(), req)
|
rt.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_WithCors(t *testing.T) {
|
||||||
|
var called int32
|
||||||
|
middleware := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
r := router.NewRouter()
|
||||||
|
assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler())))
|
||||||
|
|
||||||
|
cr := &corsRouter{
|
||||||
|
Router: r,
|
||||||
|
middleware: cors.Middleware(nil, "*"),
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
cr.ServeHTTP(httptest.NewRecorder(), req)
|
||||||
|
assert.Equal(t, int32(0), atomic.LoadInt32(&called))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user