chore: refactor to simplify disabling builtin middlewares (#2031)
* chore: refactor to simplify disabling builtin middlewares * chore: rename methods
This commit is contained in:
@@ -25,15 +25,15 @@ const topCpuUsage = 1000
|
|||||||
var ErrSignatureConfig = errors.New("bad config for Signature")
|
var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||||
|
|
||||||
type engine struct {
|
type engine struct {
|
||||||
conf RestConf
|
conf RestConf
|
||||||
routes []featuredRoutes
|
routes []featuredRoutes
|
||||||
unauthorizedCallback handler.UnauthorizedCallback
|
unauthorizedCallback handler.UnauthorizedCallback
|
||||||
unsignedCallback handler.UnsignedCallback
|
unsignedCallback handler.UnsignedCallback
|
||||||
middlewares []Middleware
|
disableDefaultMiddlewares bool
|
||||||
shedder load.Shedder
|
middlewares []Middleware
|
||||||
priorityShedder load.Shedder
|
shedder load.Shedder
|
||||||
tlsConfig *tls.Config
|
priorityShedder load.Shedder
|
||||||
chain *alice.Chain
|
tlsConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func newEngine(c RestConf) *engine {
|
func newEngine(c RestConf) *engine {
|
||||||
@@ -87,7 +87,7 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
|
|||||||
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||||
var chain alice.Chain
|
var chain alice.Chain
|
||||||
if ng.chain == nil {
|
if !ng.disableDefaultMiddlewares {
|
||||||
chain = alice.New(
|
chain = alice.New(
|
||||||
handler.TracingHandler(ng.conf.Name, route.Path),
|
handler.TracingHandler(ng.conf.Name, route.Path),
|
||||||
ng.getLogHandler(),
|
ng.getLogHandler(),
|
||||||
@@ -101,15 +101,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
|
|||||||
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
|
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
|
||||||
handler.GunzipHandler,
|
handler.GunzipHandler,
|
||||||
)
|
)
|
||||||
} else {
|
|
||||||
chain = *ng.chain
|
|
||||||
}
|
}
|
||||||
|
|
||||||
chain = ng.appendAuthHandler(fr, chain, verifier)
|
|
||||||
|
|
||||||
for _, middleware := range ng.middlewares {
|
for _, middleware := range ng.middlewares {
|
||||||
chain = chain.Append(convertMiddleware(middleware))
|
chain = chain.Append(convertMiddleware(middleware))
|
||||||
}
|
}
|
||||||
|
chain = ng.appendAuthHandler(fr, chain, verifier)
|
||||||
handle := chain.ThenFunc(route.Handler)
|
handle := chain.ThenFunc(route.Handler)
|
||||||
|
|
||||||
return router.Handle(route.Method, route.Path, handle)
|
return router.Handle(route.Method, route.Path, handle)
|
||||||
@@ -213,10 +210,6 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
|
|||||||
ng.tlsConfig = cfg
|
ng.tlsConfig = cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) setChainConfig(chain *alice.Chain) {
|
|
||||||
ng.chain = chain
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||||
ng.unauthorizedCallback = callback
|
ng.unauthorizedCallback = callback
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -250,7 +250,9 @@ func TestEngine_checkedChain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2()))
|
server := MustNewServer(RestConf{}, DisableDefaultMiddlewares())
|
||||||
|
server.Use(ToMiddleware(middleware1()))
|
||||||
|
server.Use(ToMiddleware(middleware2()))
|
||||||
server.router = chainRouter{}
|
server.router = chainRouter{}
|
||||||
server.AddRoutes(
|
server.AddRoutes(
|
||||||
[]Route{
|
[]Route{
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/justinas/alice"
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
@@ -96,6 +95,13 @@ func (s *Server) Use(middleware Middleware) {
|
|||||||
s.ngin.use(middleware)
|
s.ngin.use(middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DisableDefaultMiddlewares returns a RunOption that disables the builtin middlewares.
|
||||||
|
func DisableDefaultMiddlewares() RunOption {
|
||||||
|
return func(svr *Server) {
|
||||||
|
svr.ngin.disableDefaultMiddlewares = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ToMiddleware converts the given handler to a Middleware.
|
// ToMiddleware converts the given handler to a Middleware.
|
||||||
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
func ToMiddleware(handler func(next http.Handler) http.Handler) Middleware {
|
||||||
return func(handle http.HandlerFunc) http.HandlerFunc {
|
return func(handle http.HandlerFunc) http.HandlerFunc {
|
||||||
@@ -243,17 +249,6 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithChain returns a RunOption that with given chain config.
|
|
||||||
func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
|
|
||||||
return func(svr *Server) {
|
|
||||||
chain := alice.New()
|
|
||||||
for _, middleware := range middlewares {
|
|
||||||
chain = chain.Append(middleware)
|
|
||||||
}
|
|
||||||
svr.ngin.setChainConfig(&chain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
||||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||||
return func(svr *Server) {
|
return func(svr *Server) {
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/zeromicro/go-zero/core/conf"
|
"github.com/zeromicro/go-zero/core/conf"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"github.com/zeromicro/go-zero/core/service"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"github.com/zeromicro/go-zero/rest/router"
|
"github.com/zeromicro/go-zero/rest/router"
|
||||||
)
|
)
|
||||||
@@ -102,6 +103,18 @@ Port: 54321
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewServerError(t *testing.T) {
|
||||||
|
_, err := NewServer(RestConf{
|
||||||
|
ServiceConf: service.ServiceConf{
|
||||||
|
Log: logx.LogConf{
|
||||||
|
// file mode, no path specified
|
||||||
|
Mode: "file",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithMaxBytes(t *testing.T) {
|
func TestWithMaxBytes(t *testing.T) {
|
||||||
const maxBytes = 1000
|
const maxBytes = 1000
|
||||||
var fr featuredRoutes
|
var fr featuredRoutes
|
||||||
@@ -320,6 +333,7 @@ Port: 54321
|
|||||||
rt := router.NewRouter()
|
rt := router.NewRouter()
|
||||||
svr, err := NewServer(cnf, WithRouter(rt))
|
svr, err := NewServer(cnf, WithRouter(rt))
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
defer svr.Stop()
|
||||||
|
|
||||||
opt := WithCors("local")
|
opt := WithCors("local")
|
||||||
opt(svr)
|
opt(svr)
|
||||||
@@ -408,3 +422,16 @@ Port: 54321
|
|||||||
out := <-ch
|
out := <-ch
|
||||||
assert.Equal(t, expect, out)
|
assert.Equal(t, expect, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleError(t *testing.T) {
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
handleError(nil)
|
||||||
|
handleError(http.ErrServerClosed)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateSecret(t *testing.T) {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
validateSecret("short")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user