add user middleware chain function (#1913)
* add user middleware chain function * fix staticcheck SA4006 * chang code Implementation style Co-authored-by: kemq1 <kemq1@spdb.com.cn>
This commit is contained in:
@@ -33,6 +33,7 @@ type engine struct {
|
|||||||
shedder load.Shedder
|
shedder load.Shedder
|
||||||
priorityShedder load.Shedder
|
priorityShedder load.Shedder
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
|
chain *alice.Chain
|
||||||
}
|
}
|
||||||
|
|
||||||
func newEngine(c RestConf) *engine {
|
func newEngine(c RestConf) *engine {
|
||||||
@@ -85,19 +86,25 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
|
|||||||
|
|
||||||
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||||
chain := alice.New(
|
var chain alice.Chain
|
||||||
handler.TracingHandler(ng.conf.Name, route.Path),
|
if ng.chain == nil {
|
||||||
ng.getLogHandler(),
|
chain = alice.New(
|
||||||
handler.PrometheusHandler(route.Path),
|
handler.TracingHandler(ng.conf.Name, route.Path),
|
||||||
handler.MaxConns(ng.conf.MaxConns),
|
ng.getLogHandler(),
|
||||||
handler.BreakerHandler(route.Method, route.Path, metrics),
|
handler.PrometheusHandler(route.Path),
|
||||||
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
|
handler.MaxConns(ng.conf.MaxConns),
|
||||||
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
|
handler.BreakerHandler(route.Method, route.Path, metrics),
|
||||||
handler.RecoverHandler,
|
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
|
||||||
handler.MetricHandler(metrics),
|
handler.TimeoutHandler(ng.checkedTimeout(fr.timeout)),
|
||||||
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
|
handler.RecoverHandler,
|
||||||
handler.GunzipHandler,
|
handler.MetricHandler(metrics),
|
||||||
)
|
handler.MaxBytesHandler(ng.checkedMaxBytes(fr.maxBytes)),
|
||||||
|
handler.GunzipHandler,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
chain = *ng.chain
|
||||||
|
}
|
||||||
|
|
||||||
chain = ng.appendAuthHandler(fr, chain, verifier)
|
chain = ng.appendAuthHandler(fr, chain, verifier)
|
||||||
|
|
||||||
for _, middleware := range ng.middlewares {
|
for _, middleware := range ng.middlewares {
|
||||||
@@ -206,6 +213,10 @@ func (ng *engine) setTlsConfig(cfg *tls.Config) {
|
|||||||
ng.tlsConfig = cfg
|
ng.tlsConfig = cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ng *engine) setChainConfig(chain *alice.Chain) {
|
||||||
|
ng.chain = chain
|
||||||
|
}
|
||||||
|
|
||||||
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
func (ng *engine) setUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||||
ng.unauthorizedCallback = callback
|
ng.unauthorizedCallback = callback
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -229,6 +229,44 @@ func TestEngine_checkedMaxBytes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEngine_checkedChain(t *testing.T) {
|
||||||
|
var called int32
|
||||||
|
middleware1 := func() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
middleware2 := func() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
server := MustNewServer(RestConf{}, WithChain(middleware1(), middleware2()))
|
||||||
|
server.router = chainRouter{}
|
||||||
|
server.AddRoutes(
|
||||||
|
[]Route{
|
||||||
|
{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
atomic.AddInt32(&called, 1)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
server.ngin.bindRoutes(chainRouter{})
|
||||||
|
assert.Equal(t, int32(5), atomic.LoadInt32(&called))
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_notFoundHandler(t *testing.T) {
|
func TestEngine_notFoundHandler(t *testing.T) {
|
||||||
logx.Disable()
|
logx.Disable()
|
||||||
|
|
||||||
@@ -343,3 +381,19 @@ func (m mockedRouter) SetNotFoundHandler(_ http.Handler) {
|
|||||||
|
|
||||||
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
func (m mockedRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type chainRouter struct{}
|
||||||
|
|
||||||
|
func (c chainRouter) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainRouter) Handle(_, _ string, handler http.Handler) error {
|
||||||
|
handler.ServeHTTP(nil, nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainRouter) SetNotFoundHandler(_ http.Handler) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c chainRouter) SetNotAllowedHandler(_ http.Handler) {
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/justinas/alice"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/rest/handler"
|
"github.com/zeromicro/go-zero/rest/handler"
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
@@ -242,6 +243,17 @@ func WithTLSConfig(cfg *tls.Config) RunOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithChain returns a RunOption that with given chain config.
|
||||||
|
func WithChain(middlewares ...func(http.Handler) http.Handler) RunOption {
|
||||||
|
return func(svr *Server) {
|
||||||
|
chain := alice.New()
|
||||||
|
for _, middleware := range middlewares {
|
||||||
|
chain = chain.Append(middleware)
|
||||||
|
}
|
||||||
|
svr.ngin.setChainConfig(&chain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
||||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||||
return func(svr *Server) {
|
return func(svr *Server) {
|
||||||
|
|||||||
Reference in New Issue
Block a user