refactor: guard timeout on API files (#1726)
This commit is contained in:
@@ -75,6 +75,7 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if j > i+1 {
|
if j > i+1 {
|
||||||
index, err := strconv.Atoi(query[i+1 : j])
|
index, err := strconv.Atoi(query[i+1 : j])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -119,6 +119,14 @@ func (ng *engine) bindRoutes(router httpx.Router) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
||||||
|
if bytes > 0 {
|
||||||
|
return bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
return ng.conf.MaxBytes
|
||||||
|
}
|
||||||
|
|
||||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
||||||
if timeout > 0 {
|
if timeout > 0 {
|
||||||
return timeout
|
return timeout
|
||||||
@@ -127,15 +135,6 @@ func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
|||||||
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
|
||||||
|
|
||||||
if bytes > 0 {
|
|
||||||
return bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
return ng.conf.MaxBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ng *engine) createMetrics() *stat.Metrics {
|
func (ng *engine) createMetrics() *stat.Metrics {
|
||||||
var metrics *stat.Metrics
|
var metrics *stat.Metrics
|
||||||
|
|
||||||
|
|||||||
@@ -137,6 +137,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithMaxBytes returns a RouteOption to set maxBytes with the given value.
|
||||||
|
func WithMaxBytes(maxBytes int64) RouteOption {
|
||||||
|
return func(r *featuredRoutes) {
|
||||||
|
r.maxBytes = maxBytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithMiddlewares adds given middlewares to given routes.
|
// WithMiddlewares adds given middlewares to given routes.
|
||||||
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
|
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
|
||||||
for i := len(ms) - 1; i >= 0; i-- {
|
for i := len(ms) - 1; i >= 0; i-- {
|
||||||
@@ -223,13 +230,6 @@ func WithTimeout(timeout time.Duration) RouteOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMaxBytes returns a RouteOption to set maxBytes with given value.
|
|
||||||
func WithMaxBytes(maxBytes int64) RouteOption {
|
|
||||||
return func(r *featuredRoutes) {
|
|
||||||
r.maxBytes = maxBytes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTLSConfig returns a RunOption that with given tls config.
|
// WithTLSConfig returns a RunOption that with given tls config.
|
||||||
func WithTLSConfig(cfg *tls.Config) RunOption {
|
func WithTLSConfig(cfg *tls.Config) RunOption {
|
||||||
return func(svr *Server) {
|
return func(svr *Server) {
|
||||||
|
|||||||
@@ -95,6 +95,13 @@ Port: 54321
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithMaxBytes(t *testing.T) {
|
||||||
|
const maxBytes = 1000
|
||||||
|
var fr featuredRoutes
|
||||||
|
WithMaxBytes(maxBytes)(&fr)
|
||||||
|
assert.Equal(t, int64(maxBytes), fr.maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithMiddleware(t *testing.T) {
|
func TestWithMiddleware(t *testing.T) {
|
||||||
m := make(map[string]string)
|
m := make(map[string]string)
|
||||||
rt := router.NewRouter()
|
rt := router.NewRouter()
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ const (
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"{{if .hasTimeout}}
|
||||||
|
"time"{{end}}
|
||||||
|
|
||||||
{{.importPackages}}
|
{{.importPackages}}
|
||||||
)
|
)
|
||||||
@@ -38,6 +39,7 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
|
|||||||
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
|
{{.routes}} {{.jwt}}{{.signature}} {{.prefix}} {{.timeout}}
|
||||||
)
|
)
|
||||||
`
|
`
|
||||||
|
timeoutThreshold = time.Millisecond
|
||||||
)
|
)
|
||||||
|
|
||||||
var mapping = map[string]string{
|
var mapping = map[string]string{
|
||||||
@@ -59,7 +61,6 @@ type (
|
|||||||
signatureEnabled bool
|
signatureEnabled bool
|
||||||
authName string
|
authName string
|
||||||
timeout string
|
timeout string
|
||||||
timeoutEnable bool
|
|
||||||
middlewares []string
|
middlewares []string
|
||||||
prefix string
|
prefix string
|
||||||
jwtTrans string
|
jwtTrans string
|
||||||
@@ -83,6 +84,7 @@ func genRoutes(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var hasTimeout bool
|
||||||
gt := template.Must(template.New("groupTemplate").Parse(templateText))
|
gt := template.Must(template.New("groupTemplate").Parse(templateText))
|
||||||
for _, g := range groups {
|
for _, g := range groups {
|
||||||
var gbuilder strings.Builder
|
var gbuilder strings.Builder
|
||||||
@@ -114,12 +116,19 @@ rest.WithPrefix("%s"),`, g.prefix)
|
|||||||
}
|
}
|
||||||
|
|
||||||
var timeout string
|
var timeout string
|
||||||
if g.timeoutEnable {
|
if len(g.timeout) > 0 {
|
||||||
duration, err := time.ParseDuration(g.timeout)
|
duration, err := time.ParseDuration(g.timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
return err
|
||||||
}
|
}
|
||||||
timeout = fmt.Sprintf("rest.WithTimeout(%d),", duration)
|
|
||||||
|
// why we check this, maybe some users set value 1, it's 1ns, not 1s.
|
||||||
|
if duration < timeoutThreshold {
|
||||||
|
return fmt.Errorf("timeout should not less than 1ms, now %v", duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout = fmt.Sprintf("rest.WithTimeout(%d * time.Millisecond),", duration/time.Millisecond)
|
||||||
|
hasTimeout = true
|
||||||
}
|
}
|
||||||
|
|
||||||
var routes string
|
var routes string
|
||||||
@@ -152,8 +161,8 @@ rest.WithPrefix("%s"),`, g.prefix)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
routeFilename = routeFilename + ".go"
|
|
||||||
|
|
||||||
|
routeFilename = routeFilename + ".go"
|
||||||
filename := path.Join(dir, handlerDir, routeFilename)
|
filename := path.Join(dir, handlerDir, routeFilename)
|
||||||
os.Remove(filename)
|
os.Remove(filename)
|
||||||
|
|
||||||
@@ -165,7 +174,8 @@ rest.WithPrefix("%s"),`, g.prefix)
|
|||||||
category: category,
|
category: category,
|
||||||
templateFile: routesTemplateFile,
|
templateFile: routesTemplateFile,
|
||||||
builtinTemplate: routesTemplate,
|
builtinTemplate: routesTemplate,
|
||||||
data: map[string]string{
|
data: map[string]interface{}{
|
||||||
|
"hasTimeout": hasTimeout,
|
||||||
"importPackages": genRouteImports(rootPkg, api),
|
"importPackages": genRouteImports(rootPkg, api),
|
||||||
"routesAdditions": strings.TrimSpace(builder.String()),
|
"routesAdditions": strings.TrimSpace(builder.String()),
|
||||||
},
|
},
|
||||||
@@ -184,7 +194,8 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), pathx.JoinPackages(parentPkg, handlerDir, folder)))
|
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder),
|
||||||
|
pathx.JoinPackages(parentPkg, handlerDir, folder)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
imports := importSet.KeysStr()
|
imports := importSet.KeysStr()
|
||||||
@@ -218,12 +229,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := g.GetAnnotation("timeout")
|
groupedRoutes.timeout = g.GetAnnotation("timeout")
|
||||||
|
|
||||||
if len(timeout) > 0 {
|
|
||||||
groupedRoutes.timeoutEnable = true
|
|
||||||
groupedRoutes.timeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
jwt := g.GetAnnotation("jwt")
|
jwt := g.GetAnnotation("jwt")
|
||||||
if len(jwt) > 0 {
|
if len(jwt) > 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user