api add middleware support (#140)
* rebase upstream
* rebase
* trim no need line
* trim no need line
* trim no need line
* update doc
* remove update
* remove no need
* remove no need
* goctl add jwt support
* goctl add jwt support
* goctl add jwt support
* goctl support import
* goctl support import
* support return ()
* revert
* refactor and rename folder to group
* remove no need
* add anonymous annotation
* optimized
* rename
* rename
* update test
* api add middleware support: usage:
@server(
middleware: M1, M2
)
* api add middleware support: usage:
@server(
middleware: M1, M2
)
* simple logic
* should reverse middlewares
* optimized
* optimized
* rename
Co-authored-by: kingxt <dream4kingxt@163.com>
This commit is contained in:
@@ -103,6 +103,13 @@ func WithJwtTransition(secret, prevSecret string) RouteOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
|
||||||
|
for i := len(ms) - 1; i >= 0; i-- {
|
||||||
|
rs = WithMiddleware(ms[i], rs...)
|
||||||
|
}
|
||||||
|
return rs
|
||||||
|
}
|
||||||
|
|
||||||
func WithMiddleware(middleware Middleware, rs ...Route) []Route {
|
func WithMiddleware(middleware Middleware, rs ...Route) []Route {
|
||||||
routes := make([]Route, len(rs))
|
routes := make([]Route, len(rs))
|
||||||
|
|
||||||
|
|||||||
@@ -68,3 +68,75 @@ func TestWithMiddleware(t *testing.T) {
|
|||||||
"wan": "2020",
|
"wan": "2020",
|
||||||
}, m)
|
}, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMultiMiddleware(t *testing.T) {
|
||||||
|
m := make(map[string]string)
|
||||||
|
router := router.NewPatRouter()
|
||||||
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var v struct {
|
||||||
|
Nickname string `form:"nickname"`
|
||||||
|
Zipcode int64 `form:"zipcode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err := httpx.Parse(r, &v)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
_, err = io.WriteString(w, fmt.Sprintf("%s:%s", v.Nickname, m[v.Nickname]))
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
rs := WithMiddlewares([]Middleware{
|
||||||
|
func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var v struct {
|
||||||
|
Name string `path:"name"`
|
||||||
|
Year string `path:"year"`
|
||||||
|
}
|
||||||
|
assert.Nil(t, httpx.ParsePath(r, &v))
|
||||||
|
m[v.Name] = v.Year
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var v struct {
|
||||||
|
Name string `form:"nickname"`
|
||||||
|
Zipcode string `form:"zipcode"`
|
||||||
|
}
|
||||||
|
assert.Nil(t, httpx.ParseForm(r, &v))
|
||||||
|
assert.NotEmpty(t, m)
|
||||||
|
m[v.Name] = v.Zipcode + v.Zipcode
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}, Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/first/:name/:year",
|
||||||
|
Handler: handler,
|
||||||
|
}, Route{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/second/:name/:year",
|
||||||
|
Handler: handler,
|
||||||
|
})
|
||||||
|
|
||||||
|
urls := []string{
|
||||||
|
"http://hello.com/first/kevin/2017?nickname=whatever&zipcode=200000",
|
||||||
|
"http://hello.com/second/wan/2020?nickname=whatever&zipcode=200000",
|
||||||
|
}
|
||||||
|
for _, route := range rs {
|
||||||
|
assert.Nil(t, router.Handle(route.Method, route.Path, route.Handler))
|
||||||
|
}
|
||||||
|
for _, url := range urls {
|
||||||
|
r, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, r)
|
||||||
|
|
||||||
|
assert.Equal(t, "whatever:200000200000", rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.EqualValues(t, map[string]string{
|
||||||
|
"kevin": "2017",
|
||||||
|
"wan": "2020",
|
||||||
|
"whatever": "200000200000",
|
||||||
|
}, m)
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ func RegisterHandlers(engine *rest.Server, serverCtx *svc.ServiceContext) {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
routesAdditionTemplate = `
|
routesAdditionTemplate = `
|
||||||
engine.AddRoutes([]rest.Route{
|
engine.AddRoutes(
|
||||||
{{.routes}}
|
{{.routes}}
|
||||||
}{{.jwt}}{{.signature}})
|
{{.jwt}}{{.signature}})
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -52,6 +52,7 @@ type (
|
|||||||
jwtEnabled bool
|
jwtEnabled bool
|
||||||
signatureEnabled bool
|
signatureEnabled bool
|
||||||
authName string
|
authName string
|
||||||
|
middleware []string
|
||||||
}
|
}
|
||||||
route struct {
|
route struct {
|
||||||
method string
|
method string
|
||||||
@@ -87,8 +88,22 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
|
|||||||
if g.signatureEnabled {
|
if g.signatureEnabled {
|
||||||
signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName)
|
signature = fmt.Sprintf(", rest.WithSignature(serverCtx.Config.%s.Signature)", g.authName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var routes string
|
||||||
|
if len(g.middleware) > 0 {
|
||||||
|
var params = g.middleware
|
||||||
|
for i := range params {
|
||||||
|
params[i] = "serverCtx." + params[i]
|
||||||
|
}
|
||||||
|
var middlewareStr = strings.Join(params, ", ")
|
||||||
|
routes = fmt.Sprintf("rest.WithMultiMiddleware([]rest.Middleware{ %s }, []rest.Route{\n %s \n}),",
|
||||||
|
middlewareStr, strings.TrimSpace(gbuilder.String()))
|
||||||
|
} else {
|
||||||
|
routes = fmt.Sprintf("[]rest.Route{\n %s \n},", strings.TrimSpace(gbuilder.String()))
|
||||||
|
}
|
||||||
|
|
||||||
if err := gt.Execute(&builder, map[string]string{
|
if err := gt.Execute(&builder, map[string]string{
|
||||||
"routes": strings.TrimSpace(gbuilder.String()),
|
"routes": routes,
|
||||||
"jwt": jwt,
|
"jwt": jwt,
|
||||||
"signature": signature,
|
"signature": signature,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -185,6 +200,11 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
|||||||
groupedRoutes.authName = value
|
groupedRoutes.authName = value
|
||||||
groupedRoutes.jwtEnabled = true
|
groupedRoutes.jwtEnabled = true
|
||||||
}
|
}
|
||||||
|
if value, ok := apiutil.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
|
||||||
|
for _, item := range strings.Split(value, ",") {
|
||||||
|
groupedRoutes.middleware = append(groupedRoutes.middleware, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
routes = append(routes, groupedRoutes)
|
routes = append(routes, groupedRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,16 +9,20 @@ import (
|
|||||||
"github.com/tal-tech/go-zero/tools/goctl/api/util"
|
"github.com/tal-tech/go-zero/tools/goctl/api/util"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/templatex"
|
"github.com/tal-tech/go-zero/tools/goctl/templatex"
|
||||||
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/vars"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
contextFilename = "servicecontext.go"
|
contextFilename = "servicecontext.go"
|
||||||
contextTemplate = `package svc
|
contextTemplate = `package svc
|
||||||
|
|
||||||
import {{.configImport}}
|
import (
|
||||||
|
{{.configImport}}
|
||||||
|
)
|
||||||
|
|
||||||
type ServiceContext struct {
|
type ServiceContext struct {
|
||||||
Config {{.config}}
|
Config {{.config}}
|
||||||
|
{{.middleware}}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceContext(c {{.config}}) *ServiceContext {
|
func NewServiceContext(c {{.config}}) *ServiceContext {
|
||||||
@@ -53,12 +57,22 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var middlewareStr string
|
||||||
|
for _, item := range getMiddleware(api) {
|
||||||
|
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)
|
||||||
|
}
|
||||||
|
|
||||||
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
|
var configImport = "\"" + ctlutil.JoinPackages(parentPkg, configDir) + "\""
|
||||||
|
if len(middlewareStr) > 0 {
|
||||||
|
configImport += fmt.Sprintf("\n\"%s/rest\"", vars.ProjectOpenSourceUrl)
|
||||||
|
}
|
||||||
|
|
||||||
t := template.Must(template.New("contextTemplate").Parse(text))
|
t := template.Must(template.New("contextTemplate").Parse(text))
|
||||||
buffer := new(bytes.Buffer)
|
buffer := new(bytes.Buffer)
|
||||||
err = t.Execute(buffer, map[string]string{
|
err = t.Execute(buffer, map[string]string{
|
||||||
"configImport": configImport,
|
"configImport": configImport,
|
||||||
"config": "config.Config",
|
"config": "config.Config",
|
||||||
|
"middleware": middlewareStr,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -66,6 +66,18 @@ func getAuths(api *spec.ApiSpec) []string {
|
|||||||
return authNames.KeysStr()
|
return authNames.KeysStr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getMiddleware(api *spec.ApiSpec) []string {
|
||||||
|
result := collection.NewSet()
|
||||||
|
for _, g := range api.Service.Groups {
|
||||||
|
if value, ok := util.GetAnnotationValue(g.Annotations, "server", "middleware"); ok {
|
||||||
|
for _, item := range strings.Split(value, ",") {
|
||||||
|
result.Add(strings.TrimSpace(item))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.KeysStr()
|
||||||
|
}
|
||||||
|
|
||||||
func formatCode(code string) string {
|
func formatCode(code string) string {
|
||||||
ret, err := goformat.Source([]byte(code))
|
ret, err := goformat.Source([]byte(code))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -119,6 +119,24 @@ service A-api {
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
const apiHasMiddleware = `
|
||||||
|
type Request struct {
|
||||||
|
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
Message string ` + "`" + `json:"message"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
@server(
|
||||||
|
middleware: TokenValidate
|
||||||
|
)
|
||||||
|
service A-api {
|
||||||
|
@handler GreetHandler
|
||||||
|
get /greet/from/:name(Request) returns (Response)
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
func TestParser(t *testing.T) {
|
func TestParser(t *testing.T) {
|
||||||
filename := "greet.api"
|
filename := "greet.api"
|
||||||
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
|
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
|
||||||
@@ -198,3 +216,16 @@ func TestAnonymousAnnotation(t *testing.T) {
|
|||||||
assert.Equal(t, len(api.Service.Routes), 1)
|
assert.Equal(t, len(api.Service.Routes), 1)
|
||||||
assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
|
assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApiHasMiddleware(t *testing.T) {
|
||||||
|
filename := "greet.api"
|
||||||
|
err := ioutil.WriteFile(filename, []byte(apiHasMiddleware), os.ModePerm)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer os.Remove(filename)
|
||||||
|
|
||||||
|
parser, err := NewParser(filename)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
_, err = parser.Parse()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user