feat: support CORS by using rest.WithCors(...) (#1212)

* feat: support CORS by using rest.WithCors(...)

* chore: add comments

* refactor: lowercase unexported methods

* ci: fix lint errors
This commit is contained in:
Kevin Wan
2021-11-07 22:42:40 +08:00
committed by GitHub
parent e8efcef108
commit c28e01fed3
9 changed files with 238 additions and 145 deletions

View File

@@ -10,21 +10,18 @@ import (
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/rest/handler"
"github.com/tal-tech/go-zero/rest/httpx"
"github.com/tal-tech/go-zero/rest/internal/cors"
"github.com/tal-tech/go-zero/rest/router"
)
type (
runOptions struct {
start func(*engine) error
}
// RunOption defines the method to customize a Server.
RunOption func(*Server)
// A Server is a http server.
Server struct {
ngin *engine
opts runOptions
ngin *engine
router httpx.Router
}
)
@@ -48,12 +45,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
}
server := &Server{
ngin: newEngine(c),
opts: runOptions{
start: func(ng *engine) error {
return ng.Start()
},
},
ngin: newEngine(c),
router: router.NewRouter(),
}
for _, opt := range opts {
@@ -71,7 +64,7 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
for _, opt := range opts {
opt(&r)
}
s.ngin.AddRoutes(r)
s.ngin.addRoutes(r)
}
// AddRoute adds given route into the Server.
@@ -83,7 +76,7 @@ func (s *Server) AddRoute(r Route, opts ...RouteOption) {
// Graceful shutdown is enabled by default.
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
func (s *Server) Start() {
handleError(s.opts.start(s.ngin))
handleError(s.ngin.start(s.router))
}
// Stop stops the Server.
@@ -103,6 +96,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
}
}
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
func WithCors(origin ...string) RunOption {
return func(server *Server) {
server.router.SetNotAllowedHandler(cors.Handler(origin...))
server.Use(cors.Middleware(origin...))
}
}
// WithJwt returns a func to enable jwt authentication in given route.
func WithJwt(secret string) RouteOption {
return func(r *featuredRoutes) {
@@ -151,16 +152,16 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
func WithNotFoundHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotFoundHandler(handler)
return WithRouter(rt)
return func(server *Server) {
server.router.SetNotFoundHandler(handler)
}
}
// WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
func WithNotAllowedHandler(handler http.Handler) RunOption {
rt := router.NewRouter()
rt.SetNotAllowedHandler(handler)
return WithRouter(rt)
return func(server *Server) {
server.router.SetNotAllowedHandler(handler)
}
}
// WithPrefix adds group as a prefix to the route paths.
@@ -189,9 +190,7 @@ func WithPriority() RouteOption {
// WithRouter returns a RunOption that make server run with given router.
func WithRouter(router httpx.Router) RunOption {
return func(server *Server) {
server.opts.start = func(ng *engine) error {
return ng.StartWithRouter(router)
}
server.router = router
}
}
@@ -222,14 +221,14 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
return func(srv *Server) {
srv.ngin.SetUnauthorizedCallback(callback)
srv.ngin.setUnauthorizedCallback(callback)
}
}
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
return func(srv *Server) {
srv.ngin.SetUnsignedCallback(callback)
srv.ngin.setUnsignedCallback(callback)
}
}