feat: rest.WithChain to replace builtin middlewares (#2033)

* feat: rest.WithChain to replace builtin middlewares

* chore: add comments

* chore: refine code
This commit is contained in:
Kevin Wan
2022-06-19 17:41:33 +08:00
committed by GitHub
parent 50f16e2892
commit 47c49de94e
6 changed files with 322 additions and 98 deletions

View File

@@ -8,10 +8,10 @@ import (
"sort"
"time"
"github.com/justinas/alice"
"github.com/zeromicro/go-zero/core/codec"
"github.com/zeromicro/go-zero/core/load"
"github.com/zeromicro/go-zero/core/stat"
"github.com/zeromicro/go-zero/rest/chain"
"github.com/zeromicro/go-zero/rest/handler"
"github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/rest/internal"
@@ -25,15 +25,15 @@ const topCpuUsage = 1000
var ErrSignatureConfig = errors.New("bad config for Signature")
type engine struct {
conf RestConf
routes []featuredRoutes
unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback
disableDefaultMiddlewares bool
middlewares []Middleware
shedder load.Shedder
priorityShedder load.Shedder
tlsConfig *tls.Config
conf RestConf
routes []featuredRoutes
unauthorizedCallback handler.UnauthorizedCallback
unsignedCallback handler.UnsignedCallback
chain chain.Chain
middlewares []Middleware
shedder load.Shedder
priorityShedder load.Shedder
tlsConfig *tls.Config
}
func newEngine(c RestConf) *engine {
@@ -53,20 +53,20 @@ func (ng *engine) addRoutes(r featuredRoutes) {
ng.routes = append(ng.routes, r)
}
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
verifier func(alice.Chain) alice.Chain) alice.Chain {
func (ng *engine) appendAuthHandler(fr featuredRoutes, chn chain.Chain,
verifier func(chain.Chain) chain.Chain) chain.Chain {
if fr.jwt.enabled {
if len(fr.jwt.prevSecret) == 0 {
chain = chain.Append(handler.Authorize(fr.jwt.secret,
chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
} else {
chain = chain.Append(handler.Authorize(fr.jwt.secret,
chn = chn.Append(handler.Authorize(fr.jwt.secret,
handler.WithPrevSecret(fr.jwt.prevSecret),
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
}
}
return verifier(chain)
return verifier(chn)
}
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
@@ -85,10 +85,10 @@ func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, met
}
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
route Route, verifier func(chain alice.Chain) alice.Chain) error {
var chain alice.Chain
if !ng.disableDefaultMiddlewares {
chain = alice.New(
route Route, verifier func(chain.Chain) chain.Chain) error {
chn := ng.chain
if chn == nil {
chn = chain.New(
handler.TracingHandler(ng.conf.Name, route.Path),
ng.getLogHandler(),
handler.PrometheusHandler(route.Path),
@@ -103,11 +103,12 @@ func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *sta
)
}
chn = ng.appendAuthHandler(fr, chn, verifier)
for _, middleware := range ng.middlewares {
chain = chain.Append(convertMiddleware(middleware))
chn = chn.Append(convertMiddleware(middleware))
}
chain = ng.appendAuthHandler(fr, chain, verifier)
handle := chain.ThenFunc(route.Handler)
handle := chn.ThenFunc(route.Handler)
return router.Handle(route.Method, route.Path, handle)
}
@@ -171,16 +172,16 @@ func (ng *engine) getShedder(priority bool) load.Shedder {
// notFoundHandler returns a middleware that handles 404 not found requests.
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
chain := alice.New(
chn := chain.New(
handler.TracingHandler(ng.conf.Name, ""),
ng.getLogHandler(),
)
var h http.Handler
if next != nil {
h = chain.Then(next)
h = chn.Then(next)
} else {
h = chain.Then(http.NotFoundHandler())
h = chn.Then(http.NotFoundHandler())
}
cw := response.NewHeaderOnceResponseWriter(w)
@@ -218,10 +219,10 @@ func (ng *engine) setUnsignedCallback(callback handler.UnsignedCallback) {
ng.unsignedCallback = callback
}
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain.Chain) chain.Chain, error) {
if !signature.enabled {
return func(chain alice.Chain) alice.Chain {
return chain
return func(chn chain.Chain) chain.Chain {
return chn
}, nil
}
@@ -230,8 +231,8 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
return nil, ErrSignatureConfig
}
return func(chain alice.Chain) alice.Chain {
return chain
return func(chn chain.Chain) chain.Chain {
return chn
}, nil
}
@@ -247,14 +248,13 @@ func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alic
decrypters[fingerprint] = decrypter
}
return func(chain alice.Chain) alice.Chain {
return func(chn chain.Chain) chain.Chain {
if ng.unsignedCallback != nil {
return chain.Append(handler.ContentSecurityHandler(
return chn.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
}
return chain.Append(handler.ContentSecurityHandler(
decrypters, signature.Expiry, signature.Strict))
return chn.Append(handler.ContentSecurityHandler(decrypters, signature.Expiry, signature.Strict))
}, nil
}