refactor: simplify tls config in rest (#1181)
This commit is contained in:
1
go.mod
1
go.mod
@@ -54,4 +54,5 @@ require (
|
||||
k8s.io/api v0.20.10
|
||||
k8s.io/apimachinery v0.20.10
|
||||
k8s.io/client-go v0.20.10
|
||||
k8s.io/utils v0.0.0-20201110183641-67b214c5f920
|
||||
)
|
||||
|
||||
@@ -35,7 +35,7 @@ type (
|
||||
KeyFile string `json:",optional"`
|
||||
Verbose bool `json:",optional"`
|
||||
MaxConns int `json:",default=10000"`
|
||||
MaxBytes int64 `json:",default=1048576,range=[0:33554432]"`
|
||||
MaxBytes int64 `json:",default=1048576"`
|
||||
// milliseconds
|
||||
Timeout int64 `json:",default=3000"`
|
||||
CpuThreshold int64 `json:",default=900,range=[0:1000]"`
|
||||
|
||||
103
rest/engine.go
103
rest/engine.go
@@ -47,58 +47,63 @@ func newEngine(c RestConf) *engine {
|
||||
return srv
|
||||
}
|
||||
|
||||
func (s *engine) AddRoutes(r featuredRoutes) {
|
||||
s.routes = append(s.routes, r)
|
||||
func (ng *engine) AddRoutes(r featuredRoutes) {
|
||||
ng.routes = append(ng.routes, r)
|
||||
}
|
||||
|
||||
func (s *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||
s.unauthorizedCallback = callback
|
||||
func (ng *engine) SetUnauthorizedCallback(callback handler.UnauthorizedCallback) {
|
||||
ng.unauthorizedCallback = callback
|
||||
}
|
||||
|
||||
func (s *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
|
||||
s.unsignedCallback = callback
|
||||
func (ng *engine) SetUnsignedCallback(callback handler.UnsignedCallback) {
|
||||
ng.unsignedCallback = callback
|
||||
}
|
||||
|
||||
func (s *engine) Start() error {
|
||||
return s.StartWithRouter(router.NewRouter())
|
||||
func (ng *engine) Start() error {
|
||||
return ng.StartWithRouter(router.NewRouter())
|
||||
}
|
||||
|
||||
func (s *engine) StartWithRouter(router httpx.Router) error {
|
||||
if err := s.bindRoutes(router); err != nil {
|
||||
func (ng *engine) StartWithRouter(router httpx.Router) error {
|
||||
if err := ng.bindRoutes(router); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(s.conf.CertFile) == 0 && len(s.conf.KeyFile) == 0 {
|
||||
return internal.StartHttp(s.conf.Host, s.conf.Port, router)
|
||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router)
|
||||
}
|
||||
|
||||
return internal.StartHttps(s.conf.Host, s.conf.Port, s.conf.CertFile, s.conf.KeyFile, s.tlsConfig, router)
|
||||
return internal.StartHttps(ng.conf.Host, ng.conf.Port, ng.conf.CertFile,
|
||||
ng.conf.KeyFile, router, func(srv *http.Server) {
|
||||
if ng.tlsConfig != nil {
|
||||
srv.TLSConfig = ng.tlsConfig
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||
func (ng *engine) appendAuthHandler(fr featuredRoutes, chain alice.Chain,
|
||||
verifier func(alice.Chain) alice.Chain) alice.Chain {
|
||||
if fr.jwt.enabled {
|
||||
if len(fr.jwt.prevSecret) == 0 {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
||||
} else {
|
||||
chain = chain.Append(handler.Authorize(fr.jwt.secret,
|
||||
handler.WithPrevSecret(fr.jwt.prevSecret),
|
||||
handler.WithUnauthorizedCallback(s.unauthorizedCallback)))
|
||||
handler.WithUnauthorizedCallback(ng.unauthorizedCallback)))
|
||||
}
|
||||
}
|
||||
|
||||
return verifier(chain)
|
||||
}
|
||||
|
||||
func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
verifier, err := s.signatureVerifier(fr.signature)
|
||||
func (ng *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metrics *stat.Metrics) error {
|
||||
verifier, err := ng.signatureVerifier(fr.signature)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, route := range fr.routes {
|
||||
if err := s.bindRoute(fr, router, metrics, route, verifier); err != nil {
|
||||
if err := ng.bindRoute(fr, router, metrics, route, verifier); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -106,24 +111,24 @@ func (s *engine) bindFeaturedRoutes(router httpx.Router, fr featuredRoutes, metr
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||
func (ng *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat.Metrics,
|
||||
route Route, verifier func(chain alice.Chain) alice.Chain) error {
|
||||
chain := alice.New(
|
||||
handler.TracingHandler(s.conf.Name, route.Path),
|
||||
s.getLogHandler(),
|
||||
handler.TracingHandler(ng.conf.Name, route.Path),
|
||||
ng.getLogHandler(),
|
||||
handler.PrometheusHandler(route.Path),
|
||||
handler.MaxConns(s.conf.MaxConns),
|
||||
handler.MaxConns(ng.conf.MaxConns),
|
||||
handler.BreakerHandler(route.Method, route.Path, metrics),
|
||||
handler.SheddingHandler(s.getShedder(fr.priority), metrics),
|
||||
handler.TimeoutHandler(time.Duration(s.conf.Timeout)*time.Millisecond),
|
||||
handler.SheddingHandler(ng.getShedder(fr.priority), metrics),
|
||||
handler.TimeoutHandler(time.Duration(ng.conf.Timeout)*time.Millisecond),
|
||||
handler.RecoverHandler,
|
||||
handler.MetricHandler(metrics),
|
||||
handler.MaxBytesHandler(s.conf.MaxBytes),
|
||||
handler.MaxBytesHandler(ng.conf.MaxBytes),
|
||||
handler.GunzipHandler,
|
||||
)
|
||||
chain = s.appendAuthHandler(fr, chain, verifier)
|
||||
chain = ng.appendAuthHandler(fr, chain, verifier)
|
||||
|
||||
for _, middleware := range s.middlewares {
|
||||
for _, middleware := range ng.middlewares {
|
||||
chain = chain.Append(convertMiddleware(middleware))
|
||||
}
|
||||
handle := chain.ThenFunc(route.Handler)
|
||||
@@ -131,11 +136,11 @@ func (s *engine) bindRoute(fr featuredRoutes, router httpx.Router, metrics *stat
|
||||
return router.Handle(route.Method, route.Path, handle)
|
||||
}
|
||||
|
||||
func (s *engine) bindRoutes(router httpx.Router) error {
|
||||
metrics := s.createMetrics()
|
||||
func (ng *engine) bindRoutes(router httpx.Router) error {
|
||||
metrics := ng.createMetrics()
|
||||
|
||||
for _, fr := range s.routes {
|
||||
if err := s.bindFeaturedRoutes(router, fr, metrics); err != nil {
|
||||
for _, fr := range ng.routes {
|
||||
if err := ng.bindFeaturedRoutes(router, fr, metrics); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -143,35 +148,39 @@ func (s *engine) bindRoutes(router httpx.Router) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *engine) createMetrics() *stat.Metrics {
|
||||
func (ng *engine) createMetrics() *stat.Metrics {
|
||||
var metrics *stat.Metrics
|
||||
|
||||
if len(s.conf.Name) > 0 {
|
||||
metrics = stat.NewMetrics(s.conf.Name)
|
||||
if len(ng.conf.Name) > 0 {
|
||||
metrics = stat.NewMetrics(ng.conf.Name)
|
||||
} else {
|
||||
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", s.conf.Host, s.conf.Port))
|
||||
metrics = stat.NewMetrics(fmt.Sprintf("%s:%d", ng.conf.Host, ng.conf.Port))
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
func (s *engine) getLogHandler() func(http.Handler) http.Handler {
|
||||
if s.conf.Verbose {
|
||||
func (ng *engine) getLogHandler() func(http.Handler) http.Handler {
|
||||
if ng.conf.Verbose {
|
||||
return handler.DetailedLogHandler
|
||||
}
|
||||
|
||||
return handler.LogHandler
|
||||
}
|
||||
|
||||
func (s *engine) getShedder(priority bool) load.Shedder {
|
||||
if priority && s.priorityShedder != nil {
|
||||
return s.priorityShedder
|
||||
func (ng *engine) getShedder(priority bool) load.Shedder {
|
||||
if priority && ng.priorityShedder != nil {
|
||||
return ng.priorityShedder
|
||||
}
|
||||
|
||||
return s.shedder
|
||||
return ng.shedder
|
||||
}
|
||||
|
||||
func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
||||
func (ng *engine) setTlsConfig(cfg *tls.Config) {
|
||||
ng.tlsConfig = cfg
|
||||
}
|
||||
|
||||
func (ng *engine) signatureVerifier(signature signatureSetting) (func(chain alice.Chain) alice.Chain, error) {
|
||||
if !signature.enabled {
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
return chain
|
||||
@@ -201,9 +210,9 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
|
||||
}
|
||||
|
||||
return func(chain alice.Chain) alice.Chain {
|
||||
if s.unsignedCallback != nil {
|
||||
if ng.unsignedCallback != nil {
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
decrypters, signature.Expiry, signature.Strict, s.unsignedCallback))
|
||||
decrypters, signature.Expiry, signature.Strict, ng.unsignedCallback))
|
||||
}
|
||||
|
||||
return chain.Append(handler.ContentSecurityHandler(
|
||||
@@ -211,8 +220,8 @@ func (s *engine) signatureVerifier(signature signatureSetting) (func(chain alice
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *engine) use(middleware Middleware) {
|
||||
s.middlewares = append(s.middlewares, middleware)
|
||||
func (ng *engine) use(middleware Middleware) {
|
||||
ng.middlewares = append(ng.middlewares, middleware)
|
||||
}
|
||||
|
||||
func convertMiddleware(ware Middleware) func(http.Handler) http.Handler {
|
||||
|
||||
@@ -2,38 +2,46 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/proc"
|
||||
)
|
||||
|
||||
// StartOption defines the method to customize http.Server.
|
||||
type StartOption func(srv *http.Server)
|
||||
|
||||
// StartHttp starts a http server.
|
||||
func StartHttp(host string, port int, handler http.Handler) error {
|
||||
return start(host, port, handler, nil, func(srv *http.Server) error {
|
||||
func StartHttp(host string, port int, handler http.Handler, opts ...StartOption) error {
|
||||
return start(host, port, handler, func(srv *http.Server) error {
|
||||
return srv.ListenAndServe()
|
||||
})
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
// StartHttps starts a https server.
|
||||
func StartHttps(host string, port int, certFile, keyFile string, tlsConfig *tls.Config, handler http.Handler) error {
|
||||
return start(host, port, handler, tlsConfig, func(srv *http.Server) error {
|
||||
func StartHttps(host string, port int, certFile, keyFile string, handler http.Handler,
|
||||
opts ...StartOption) error {
|
||||
return start(host, port, handler, func(srv *http.Server) error {
|
||||
// certFile and keyFile are set in buildHttpsServer
|
||||
return srv.ListenAndServeTLS(certFile, keyFile)
|
||||
})
|
||||
}, opts...)
|
||||
}
|
||||
|
||||
func start(host string, port int, handler http.Handler, tlsConfig *tls.Config, run func(srv *http.Server) error) (err error) {
|
||||
func start(host string, port int, handler http.Handler, run func(srv *http.Server) error,
|
||||
opts ...StartOption) (err error) {
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", host, port),
|
||||
Handler: handler,
|
||||
}
|
||||
if tlsConfig != nil {
|
||||
server.TLSConfig = tlsConfig
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
waitForCalled := proc.AddWrapUpListener(func() {
|
||||
server.Shutdown(context.Background())
|
||||
if e := server.Shutdown(context.Background()); err != nil {
|
||||
logx.Error(e)
|
||||
}
|
||||
})
|
||||
defer func() {
|
||||
if err == http.ErrServerClosed {
|
||||
|
||||
@@ -48,8 +48,8 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) {
|
||||
server := &Server{
|
||||
ngin: newEngine(c),
|
||||
opts: runOptions{
|
||||
start: func(srv *engine) error {
|
||||
return srv.Start()
|
||||
start: func(ng *engine) error {
|
||||
return ng.Start()
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -171,8 +171,8 @@ func WithPriority() RouteOption {
|
||||
// WithRouter returns a RunOption that make server run with given router.
|
||||
func WithRouter(router httpx.Router) RunOption {
|
||||
return func(server *Server) {
|
||||
server.opts.start = func(srv *engine) error {
|
||||
return srv.StartWithRouter(router)
|
||||
server.opts.start = func(ng *engine) error {
|
||||
return ng.StartWithRouter(router)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,26 +187,24 @@ func WithSignature(signature SignatureConf) RouteOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.SetUnauthorizedCallback(callback)
|
||||
// WithTLSConfig returns a RunOption that with given tls config.
|
||||
func WithTLSConfig(cfg *tls.Config) RunOption {
|
||||
return func(srv *Server) {
|
||||
srv.ngin.setTlsConfig(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig returns a RunOption that with given tls config.
|
||||
func WithTLSConfig(cipherSuites []uint16) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.tlsConfig = &tls.Config{
|
||||
CipherSuites: cipherSuites,
|
||||
}
|
||||
// WithUnauthorizedCallback returns a RunOption that with given unauthorized callback set.
|
||||
func WithUnauthorizedCallback(callback handler.UnauthorizedCallback) RunOption {
|
||||
return func(srv *Server) {
|
||||
srv.ngin.SetUnauthorizedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnsignedCallback returns a RunOption that with given unsigned callback set.
|
||||
func WithUnsignedCallback(callback handler.UnsignedCallback) RunOption {
|
||||
return func(engine *Server) {
|
||||
engine.ngin.SetUnsignedCallback(callback)
|
||||
return func(srv *Server) {
|
||||
srv.ngin.SetUnsignedCallback(callback)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -227,8 +227,10 @@ Port: 54321
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadConfigFromYamlBytes([]byte(configYaml), &cnf))
|
||||
|
||||
testConfig := []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
testConfig := &tls.Config{
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
@@ -239,7 +241,7 @@ Port: 54321
|
||||
{
|
||||
c: cnf,
|
||||
opts: []RunOption{WithTLSConfig(testConfig)},
|
||||
res: &tls.Config{CipherSuites: testConfig},
|
||||
res: testConfig,
|
||||
},
|
||||
{
|
||||
c: cnf,
|
||||
|
||||
Reference in New Issue
Block a user