* fixes #987 * chore: fix test failure * chore: add comments
This commit is contained in:
@@ -3,8 +3,10 @@ package gogen
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"path"
|
"path"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
|
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
@@ -64,12 +66,8 @@ func genLogicByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group,
|
|||||||
var requestString string
|
var requestString string
|
||||||
if len(route.ResponseTypeName()) > 0 {
|
if len(route.ResponseTypeName()) > 0 {
|
||||||
resp := responseGoTypeName(route, typesPacket)
|
resp := responseGoTypeName(route, typesPacket)
|
||||||
responseString = "(" + resp + ", error)"
|
responseString = "(resp " + resp + ", err error)"
|
||||||
if strings.HasPrefix(resp, "*") {
|
returnString = "return"
|
||||||
returnString = fmt.Sprintf("return &%s{}, nil", strings.TrimPrefix(resp, "*"))
|
|
||||||
} else {
|
|
||||||
returnString = fmt.Sprintf("return %s{}, nil", resp)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
responseString = "error"
|
responseString = "error"
|
||||||
returnString = "return nil"
|
returnString = "return nil"
|
||||||
@@ -116,9 +114,47 @@ func genLogicImports(route spec.Route, parentPkg string) string {
|
|||||||
var imports []string
|
var imports []string
|
||||||
imports = append(imports, `"context"`+"\n")
|
imports = append(imports, `"context"`+"\n")
|
||||||
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir)))
|
imports = append(imports, fmt.Sprintf("\"%s\"", ctlutil.JoinPackages(parentPkg, contextDir)))
|
||||||
if len(route.ResponseTypeName()) > 0 || len(route.RequestTypeName()) > 0 {
|
if shallImportTypesPackage(route) {
|
||||||
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
|
imports = append(imports, fmt.Sprintf("\"%s\"\n", ctlutil.JoinPackages(parentPkg, typesDir)))
|
||||||
}
|
}
|
||||||
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
|
imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
|
||||||
return strings.Join(imports, "\n\t")
|
return strings.Join(imports, "\n\t")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func onlyPrimitiveTypes(val string) bool {
|
||||||
|
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||||
|
return r == '[' || r == ']' || r == ' '
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, field := range fields {
|
||||||
|
if field == "map" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// ignore array dimension number, like [5]int
|
||||||
|
if _, err := strconv.Atoi(field); err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !api.IsBasicType(field) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func shallImportTypesPackage(route spec.Route) bool {
|
||||||
|
if len(route.RequestTypeName()) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
respTypeName := route.ResponseTypeName()
|
||||||
|
if len(respTypeName) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if onlyPrimitiveTypes(respTypeName) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|||||||
@@ -267,11 +267,8 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) interface{} {
|
|||||||
}
|
}
|
||||||
case *Literal:
|
case *Literal:
|
||||||
lit := dataType.Literal.Text()
|
lit := dataType.Literal.Text()
|
||||||
if api.IsGolangKeyWord(dataType.Literal.Text()) {
|
if api.IsGolangKeyWord(lit) {
|
||||||
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", dataType.Literal.Text()))
|
v.panic(dataType.Literal, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", lit))
|
||||||
}
|
|
||||||
if api.IsBasicType(lit) {
|
|
||||||
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ func TestRoute(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
_, err = parser.Accept(fn, ` post /foo/bar returns (int)`)
|
_, err = parser.Accept(fn, ` post /foo/bar returns (int)`)
|
||||||
assert.Error(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
_, err = parser.Accept(fn, ` post /foo/bar returns (*int)`)
|
_, err = parser.Accept(fn, ` post /foo/bar returns (*int)`)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|||||||
Reference in New Issue
Block a user