feat: support CORS by using rest.WithCors(...) (#1212)
* feat: support CORS by using rest.WithCors(...) * chore: add comments * refactor: lowercase unexported methods * ci: fix lint errors
This commit is contained in:
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/tal-tech/go-zero/rest/handler"
|
"github.com/tal-tech/go-zero/rest/handler"
|
||||||
"github.com/tal-tech/go-zero/rest/httpx"
|
"github.com/tal-tech/go-zero/rest/httpx"
|
||||||
"github.com/tal-tech/go-zero/rest/internal"
|
"github.com/tal-tech/go-zero/rest/internal"
|
||||||
"github.com/tal-tech/go-zero/rest/router"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// use 1000m to represent 100%
|
// use 1000m to represent 100%
|
||||||
@@ -47,39 +46,10 @@ func newEngine(c RestConf) *engine {
|
|||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) AddRoutes(r featuredRoutes) {
|
func (ng *engine) addRoutes(r featuredRoutes) {
|
||||||
ng.routes = append(ng.routes, r)
|
ng.routes = append(ng.routes, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
|
||||||
ng.unauthorizedCallback = callback
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
|
|
||||||
ng.unsignedCallback = callback
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) Start() error {
|
|
||||||
return ng.StartWithRouter(router.NewRouter())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) StartWithRouter(router httpx.Router) error {
|
|
||||||
if err := ng.bindRoutes(router); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
|
||||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
|
|
||||||
}
|
|
||||||
|
|
||||||
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
|
|
||||||
ng.conf.KeyFile, router, func(srv *http.Server) {
|
|
||||||
if ng.tlsConfig != nil {
|
|
||||||
srv.TLSConfig = ng.tlsConfig
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
||||||
if fr.jwt.enabled {
|
if fr.jwt.enabled {
|
||||||
@@ -188,6 +158,14 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
|
|||||||
ng.tlsConfig = cfg
|
ng.tlsConfig = cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||||
|
ng.unauthorizedCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
|
||||||
|
ng.unsignedCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
||||||
if !signature.enabled {
|
if !signature.enabled {
|
||||||
return func(chain alice.Chain) alice.Chain {
|
return func(chain alice.Chain) alice.Chain {
|
||||||
@@ -228,6 +206,23 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ng *engine) start(router httpx.Router) error {
|
||||||
|
if err := ng.bindRoutes(router); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||||
|
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
|
||||||
|
ng.conf.KeyFile, router, func(srv *http.Server) {
|
||||||
|
if ng.tlsConfig != nil {
|
||||||
|
srv.TLSConfig = ng.tlsConfig
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (ng *engine) use(middleware Middleware) {
|
func (ng *engine) use(middleware Middleware) {
|
||||||
ng.middlewares = append(ng.middlewares, middleware)
|
ng.middlewares = append(ng.middlewares, middleware)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -144,13 +144,13 @@ Verbose: true
|
|||||||
var cnf RestConf
|
var cnf RestConf
|
||||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
|
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(yaml), &cnf))
|
||||||
ng := newEngine(cnf)
|
ng := newEngine(cnf)
|
||||||
ng.AddRoutes(route)
|
ng.addRoutes(route)
|
||||||
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
assert.NotNil(t, ng.StartWithRouter(mockedRouter{}))
|
assert.NotNil(t, ng.start(mockedRouter{}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
package rest
|
|
||||||
|
|
||||||
import "net/http"
|
|
||||||
|
|
||||||
const (
|
|
||||||
allowOrigin = "Access-Control-Allow-Origin"
|
|
||||||
allOrigins = "*"
|
|
||||||
allowMethods = "Access-Control-Allow-Methods"
|
|
||||||
allowHeaders = "Access-Control-Allow-Headers"
|
|
||||||
headers = "Content-Type, Content-Length, Origin"
|
|
||||||
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CorsHandler handles cross domain OPTIONS requests.
|
|
||||||
// At most one origin can be specified, other origins are ignored if given.
|
|
||||||
func CorsHandler(origins ...string) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if len(origins) > 0 {
|
|
||||||
w.Header().Set(allowOrigin, origins[0])
|
|
||||||
} else {
|
|
||||||
w.Header().Set(allowOrigin, allOrigins)
|
|
||||||
}
|
|
||||||
w.Header().Set(allowMethods, methods)
|
|
||||||
w.Header().Set(allowHeaders, headers)
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
package rest
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCorsHandlerWithOrigins(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
origins []string
|
|
||||||
expect string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "allow all origins",
|
|
||||||
expect: allOrigins,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "allow one origin",
|
|
||||||
origins: []string{"local"},
|
|
||||||
expect: "local",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "allow many origins",
|
|
||||||
origins: []string{"local", "remote"},
|
|
||||||
expect: "local",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler := CorsHandler(test.origins...)
|
|
||||||
handler.ServeHTTP(w, nil)
|
|
||||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
|
||||||
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
64
rest/internal/cors/handlers.go
Normal file
64
rest/internal/cors/handlers.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package cors
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
const (
|
||||||
|
allowOrigin = "Access-Control-Allow-Origin"
|
||||||
|
allOrigins = "*"
|
||||||
|
allowMethods = "Access-Control-Allow-Methods"
|
||||||
|
allowHeaders = "Access-Control-Allow-Headers"
|
||||||
|
allowCredentials = "Access-Control-Allow-Credentials"
|
||||||
|
exposeHeaders = "Access-Control-Expose-Headers"
|
||||||
|
allowHeadersVal = "Content-Type, Origin, X-CSRF-Token, Authorization, AccessToken, Token, Range"
|
||||||
|
exposeHeadersVal = "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers"
|
||||||
|
methods = "GET, HEAD, POST, PATCH, PUT, DELETE"
|
||||||
|
allowTrue = "true"
|
||||||
|
maxAgeHeader = "Access-Control-Max-Age"
|
||||||
|
maxAgeHeaderVal = "86400"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler handles cross domain not allowed requests.
|
||||||
|
// At most one origin can be specified, other origins are ignored if given, default to be *.
|
||||||
|
func Handler(origin ...string) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
setHeader(w, getOrigin(origin))
|
||||||
|
|
||||||
|
if r.Method != http.MethodOptions {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns a middleware that adds CORS headers to the response.
|
||||||
|
func Middleware(origin ...string) func(http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
setHeader(w, getOrigin(origin))
|
||||||
|
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
} else {
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOrigin(origins []string) string {
|
||||||
|
if len(origins) > 0 {
|
||||||
|
return origins[0]
|
||||||
|
} else {
|
||||||
|
return allOrigins
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setHeader(w http.ResponseWriter, origin string) {
|
||||||
|
w.Header().Set(allowOrigin, origin)
|
||||||
|
w.Header().Set(allowMethods, methods)
|
||||||
|
w.Header().Set(allowHeaders, allowHeadersVal)
|
||||||
|
w.Header().Set(exposeHeaders, exposeHeadersVal)
|
||||||
|
w.Header().Set(allowCredentials, allowTrue)
|
||||||
|
w.Header().Set(maxAgeHeader, maxAgeHeaderVal)
|
||||||
|
}
|
||||||
76
rest/internal/cors/handlers_test.go
Normal file
76
rest/internal/cors/handlers_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package cors
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCorsHandlerWithOrigins(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
origins []string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allow all origins",
|
||||||
|
expect: allOrigins,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow one origin",
|
||||||
|
origins: []string{"local"},
|
||||||
|
expect: "local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allow many origins",
|
||||||
|
origins: []string{"local", "remote"},
|
||||||
|
expect: "local",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
methods := []string{
|
||||||
|
http.MethodOptions,
|
||||||
|
http.MethodGet,
|
||||||
|
http.MethodPost,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
for _, method := range methods {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name+"-handler", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := Handler(test.origins...)
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if method == http.MethodOptions {
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Result().StatusCode)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
for _, method := range methods {
|
||||||
|
test := test
|
||||||
|
t.Run(test.name+"-middleware", func(t *testing.T) {
|
||||||
|
r := httptest.NewRequest(method, "http://localhost", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler := Middleware(test.origins...)(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
handler.ServeHTTP(w, r)
|
||||||
|
if method == http.MethodOptions {
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, http.StatusOK, w.Result().StatusCode)
|
||||||
|
}
|
||||||
|
assert.Equal(t, test.expect, w.Header().Get(allowOrigin))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,21 +10,18 @@ import (
|
|||||||
"github.com/tal-tech/go-zero/core/logx"
|
"github.com/tal-tech/go-zero/core/logx"
|
||||||
"github.com/tal-tech/go-zero/rest/handler"
|
"github.com/tal-tech/go-zero/rest/handler"
|
||||||
"github.com/tal-tech/go-zero/rest/httpx"
|
"github.com/tal-tech/go-zero/rest/httpx"
|
||||||
|
"github.com/tal-tech/go-zero/rest/internal/cors"
|
||||||
"github.com/tal-tech/go-zero/rest/router"
|
"github.com/tal-tech/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
runOptions struct {
|
|
||||||
start func(*engine) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunOption defines the method to customize a Server.
|
// RunOption defines the method to customize a Server.
|
||||||
RunOption func(*Server)
|
RunOption func(*Server)
|
||||||
|
|
||||||
// A Server is a http server.
|
// A Server is a http server.
|
||||||
Server struct {
|
Server struct {
|
||||||
ngin *engine
|
ngin *engine
|
||||||
opts runOptions
|
router httpx.Router
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,12 +45,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
ngin: newEngine(c),
|
ngin: newEngine(c),
|
||||||
opts: runOptions{
|
router: router.NewRouter(),
|
||||||
start: func(ng *engine) error {
|
|
||||||
return ng.Start()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
@@ -71,7 +64,7 @@ func (s *Server) AddRoutes(rs []Route, opts ...RouteOption) {
|
|||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&r)
|
opt(&r)
|
||||||
}
|
}
|
||||||
s.ngin.AddRoutes(r)
|
s.ngin.addRoutes(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRoute adds given route into the Server.
|
// AddRoute adds given route into the Server.
|
||||||
@@ -83,7 +76,7 @@ func (s *Server) AddRoute(r Route, opts ...RouteOption) {
|
|||||||
// Graceful shutdown is enabled by default.
|
// Graceful shutdown is enabled by default.
|
||||||
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||||
func (s *Server) Start() {
|
func (s *Server) Start() {
|
||||||
handleError(s.opts.start(s.ngin))
|
handleError(s.ngin.start(s.router))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the Server.
|
// Stop stops the Server.
|
||||||
@@ -103,6 +96,14 @@ func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithCors returns a func to enable CORS for given origin, or default to all origins (*).
|
||||||
|
func WithCors(origin ...string) RunOption {
|
||||||
|
return func(server *Server) {
|
||||||
|
server.router.SetNotAllowedHandler(cors.Handler(origin...))
|
||||||
|
server.Use(cors.Middleware(origin...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithJwt returns a func to enable jwt authentication in given route.
|
// WithJwt returns a func to enable jwt authentication in given route.
|
||||||
func WithJwt(secret string) RouteOption {
|
func WithJwt(secret string) RouteOption {
|
||||||
return func(r *featuredRoutes) {
|
return func(r *featuredRoutes) {
|
||||||
@@ -151,16 +152,16 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route {
|
|||||||
|
|
||||||
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
|
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
|
||||||
func WithNotFoundHandler(handler http.Handler) RunOption {
|
func WithNotFoundHandler(handler http.Handler) RunOption {
|
||||||
rt := router.NewRouter()
|
return func(server *Server) {
|
||||||
rt.SetNotFoundHandler(handler)
|
server.router.SetNotFoundHandler(handler)
|
||||||
return WithRouter(rt)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
|
// WithNotAllowedHandler returns a RunOption with not allowed handler set to given handler.
|
||||||
func WithNotAllowedHandler(handler http.Handler) RunOption {
|
func WithNotAllowedHandler(handler http.Handler) RunOption {
|
||||||
rt := router.NewRouter()
|
return func(server *Server) {
|
||||||
rt.SetNotAllowedHandler(handler)
|
server.router.SetNotAllowedHandler(handler)
|
||||||
return WithRouter(rt)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithPrefix adds group as a prefix to the route paths.
|
// WithPrefix adds group as a prefix to the route paths.
|
||||||
@@ -189,9 +190,7 @@ func WithPriority() RouteOption {
|
|||||||
// WithRouter returns a RunOption that make server run with given router.
|
// WithRouter returns a RunOption that make server run with given router.
|
||||||
func WithRouter(router httpx.Router) RunOption {
|
func WithRouter(router httpx.Router) RunOption {
|
||||||
return func(server *Server) {
|
return func(server *Server) {
|
||||||
server.opts.start = func(ng *engine) error {
|
server.router = router
|
||||||
return ng.StartWithRouter(router)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,14 +221,14 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
|
|||||||
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
||||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||||
return func(srv *Server) {
|
return func(srv *Server) {
|
||||||
srv.ngin.SetUnauthorizedCallback(callback)
|
srv.ngin.setUnauthorizedCallback(callback)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
|
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
|
||||||
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
||||||
return func(srv *Server) {
|
return func(srv *Server) {
|
||||||
srv.ngin.SetUnsignedCallback(callback)
|
srv.ngin.setUnsignedCallback(callback)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,11 +22,6 @@ Port: 54321
|
|||||||
`
|
`
|
||||||
var cnf RestConf
|
var cnf RestConf
|
||||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
failStart := func(server *Server) {
|
|
||||||
server.opts.start = func(e *engine) error {
|
|
||||||
return http.ErrServerClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
c RestConf
|
c RestConf
|
||||||
@@ -35,38 +30,40 @@ Port: 54321
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
c: RestConf{},
|
c: RestConf{},
|
||||||
opts: []RunOption{failStart},
|
opts: []RunOption{WithRouter(mockedRouter{}), WithCors()},
|
||||||
fail: true,
|
fail: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
c: cnf,
|
c: cnf,
|
||||||
opts: []RunOption{failStart},
|
opts: []RunOption{WithRouter(mockedRouter{})},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
c: cnf,
|
c: cnf,
|
||||||
opts: []RunOption{WithNotAllowedHandler(nil), failStart},
|
opts: []RunOption{WithRouter(mockedRouter{}), WithNotAllowedHandler(nil)},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
c: cnf,
|
c: cnf,
|
||||||
opts: []RunOption{WithNotFoundHandler(nil), failStart},
|
opts: []RunOption{WithNotFoundHandler(nil), WithRouter(mockedRouter{})},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
c: cnf,
|
c: cnf,
|
||||||
opts: []RunOption{WithUnauthorizedCallback(nil), failStart},
|
opts: []RunOption{WithUnauthorizedCallback(nil), WithRouter(mockedRouter{})},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
c: cnf,
|
c: cnf,
|
||||||
opts: []RunOption{WithUnsignedCallback(nil), failStart},
|
opts: []RunOption{WithUnsignedCallback(nil), WithRouter(mockedRouter{})},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
srv, err := NewServer(test.c, test.opts...)
|
var srv *Server
|
||||||
|
var err error
|
||||||
if test.fail {
|
if test.fail {
|
||||||
|
_, err = NewServer(test.c, test.opts...)
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
continue
|
continue
|
||||||
|
} else {
|
||||||
|
srv = MustNewServer(test.c, test.opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
|
srv.Use(ToMiddleware(func(next http.Handler) http.Handler {
|
||||||
@@ -80,8 +77,21 @@ Port: 54321
|
|||||||
Handler: nil,
|
Handler: nil,
|
||||||
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
|
}, WithJwt("thesecret"), WithSignature(SignatureConf{}),
|
||||||
WithJwtTransition("preivous", "thenewone"))
|
WithJwtTransition("preivous", "thenewone"))
|
||||||
srv.Start()
|
|
||||||
srv.Stop()
|
func() {
|
||||||
|
defer func() {
|
||||||
|
p := recover()
|
||||||
|
switch v := p.(type) {
|
||||||
|
case error:
|
||||||
|
assert.Equal(t, "foo", v.Error())
|
||||||
|
default:
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
srv.Start()
|
||||||
|
srv.Stop()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +190,9 @@ func TestMultiMiddlewares(t *testing.T) {
|
|||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
ToMiddleware(func(next http.Handler) http.Handler {
|
||||||
|
return next
|
||||||
|
}),
|
||||||
}, Route{
|
}, Route{
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
Path: "/first/:name/:year",
|
Path: "/first/:name/:year",
|
||||||
@@ -282,3 +295,18 @@ Port: 54321
|
|||||||
assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
|
assert.Equal(t, srv.ngin.tlsConfig, testCase.res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithCors(t *testing.T) {
|
||||||
|
const configYaml = `
|
||||||
|
Name: foo
|
||||||
|
Port: 54321
|
||||||
|
`
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||||
|
rt := router.NewRouter()
|
||||||
|
srv, err := NewServer(cnf, WithRouter(rt))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
opt := WithCors("local")
|
||||||
|
opt(srv)
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,12 +27,12 @@ import (
|
|||||||
{{.importPackages}}
|
{{.importPackages}}
|
||||||
)
|
)
|
||||||
|
|
||||||
func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
|
func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
|
||||||
{{.routesAdditions}}
|
{{.routesAdditions}}
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
routesAdditionTemplate = `
|
routesAdditionTemplate = `
|
||||||
engine.AddRoutes(
|
server.AddRoutes(
|
||||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}}
|
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}}
|
||||||
)
|
)
|
||||||
`
|
`
|
||||||
|
|||||||
Reference in New Issue
Block a user