Goctl rpc patch (#117)

* remove mock generation

* add: proto project import

* update document

* remove mock generation

* add: proto project import

* update document

* remove NL

* update document

* optimize code

* add test

* add test
This commit is contained in:
Keson
2020-10-10 16:19:46 +08:00
committed by GitHub
parent c32759d735
commit 0a9c427443
26 changed files with 1394 additions and 230 deletions

View File

@@ -0,0 +1,46 @@
package parser
import (
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
func Transfer(proto, target string, externalImport []*Import, console console.Console) (*PbAst, error) {
messageM := make(map[string]lang.PlaceholderType)
enumM := make(map[string]*Enum)
protoAst, err := parseProto(proto, messageM, enumM)
if err != nil {
return nil, err
}
for _, item := range externalImport {
err = checkImport(item.OriginalProtoPath)
if err != nil {
return nil, err
}
innerAst, err := parseProto(item.OriginalProtoPath, protoAst.Message, protoAst.Enum)
if err != nil {
return nil, err
}
for k, v := range innerAst.Message {
protoAst.Message[k] = v
}
for k, v := range innerAst.Enum {
protoAst.Enum[k] = v
}
}
protoAst.Import = externalImport
protoAst.PbSrc = filepath.Join(target, strings.TrimSuffix(filepath.Base(proto), ".proto")+".pb.go")
return transfer(protoAst, console)
}
func transfer(proto *Proto, console console.Console) (*PbAst, error) {
parser := MustNewAstParser(proto, console)
parse, err := parser.Parse()
if err != nil {
return nil, err
}
return parse, nil
}

View File

@@ -6,6 +6,7 @@ import (
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"sort"
"strings"
@@ -18,8 +19,9 @@ import (
const (
flagStar = "*"
flagDot = "."
suffixServer = "Server"
referenceContext = "context."
referenceContext = "context"
unknownPrefix = "XXX_"
ignoreJsonTagExpression = `json:"-"`
)
@@ -34,19 +36,23 @@ var (
}`
fieldTemplate = `{{if .hasDoc}}{{.doc}}
{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
objectM = make(map[string]*Struct)
)
type (
astParser struct {
golang []byte
filterStruct map[string]lang.PlaceholderType
filterEnum map[string]*Enum
console.Console
fileSet *token.FileSet
proto *Proto
}
Field struct {
Name stringx.String
TypeName string
Type Type
JsonTag string
Document []string
Comment []string
@@ -57,13 +63,33 @@ type (
Comment []string
Field []*Field
}
ConstLit struct {
Name stringx.String
Document []string
Comment []string
Lit []*Lit
}
Lit struct {
Key string
Value int
}
Type struct {
// eg:context.Context
Expression string
// eg: *context.Context
StarExpression string
// Invoke Type Expression
InvokeTypeExpression string
// eg:context
Package string
// eg:Context
Name string
}
Func struct {
Name stringx.String
InType string
InTypeName string // remove *Context,such as LoginRequest、UserRequest
OutTypeName string // remove *Context
OutType string
Document []string
Name stringx.String
ParameterIn Type
ParameterOut Type
Document []string
}
RpcService struct {
Name stringx.String
@@ -71,54 +97,98 @@ type (
}
// parsing for rpc
PbAst struct {
Package string
// external reference
Imports map[string]string
Strcuts map[string]*Struct
// rpc server's functions,not all functions
Service []*RpcService
ContainsAny bool
Imports map[string]string
Structure map[string]*Struct
Service []*RpcService
*Proto
}
)
func NewAstParser(golang []byte, filterStruct map[string]lang.PlaceholderType, log console.Console) *astParser {
func MustNewAstParser(proto *Proto, log console.Console) *astParser {
return &astParser{
golang: golang,
filterStruct: filterStruct,
filterStruct: proto.Message,
filterEnum: proto.Enum,
Console: log,
fileSet: token.NewFileSet(),
proto: proto,
}
}
func (a *astParser) Parse() (*PbAst, error) {
fSet := a.fileSet
f, err := parser.ParseFile(fSet, "", a.golang, parser.ParseComments)
var pbAst PbAst
pbAst.ContainsAny = a.proto.ContainsAny
pbAst.Proto = a.proto
pbAst.Structure = make(map[string]*Struct)
pbAst.Imports = make(map[string]string)
structure, imports, services, err := a.parse(a.proto.PbSrc)
if err != nil {
return nil, err
}
commentMap := ast.NewCommentMap(fSet, f, f.Comments)
f.Comments = commentMap.Filter(f).Comments()
var pbAst PbAst
pbAst.Package = a.mustGetIndentName(f.Name)
imports := make(map[string]string)
for _, item := range f.Imports {
if item == nil {
continue
}
if item.Path == nil {
continue
}
key := a.mustGetIndentName(item.Name)
value := item.Path.Value
imports[key] = value
dependencyStructure, err := a.parseExternalDependency()
if err != nil {
return nil, err
}
structs, funcs := a.mustScope(f.Scope)
pbAst.Imports = imports
pbAst.Strcuts = structs
pbAst.Service = funcs
for k, v := range structure {
pbAst.Structure[k] = v
}
for k, v := range dependencyStructure {
pbAst.Structure[k] = v
}
for key, path := range imports {
pbAst.Imports[key] = path
}
pbAst.Service = append(pbAst.Service, services...)
return &pbAst, nil
}
func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcService) {
func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) {
structure = make(map[string]*Struct)
imports = make(map[string]string)
data, err := ioutil.ReadFile(pbSrc)
if err != nil {
retErr = err
return
}
fSet := a.fileSet
f, err := parser.ParseFile(fSet, "", data, parser.ParseComments)
if err != nil {
retErr = err
return
}
commentMap := ast.NewCommentMap(fSet, f, f.Comments)
f.Comments = commentMap.Filter(f).Comments()
strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name))
for k, v := range strucs {
if v == nil {
continue
}
structure[k] = v
}
importList := f.Imports
for _, item := range importList {
name := a.mustGetIndentName(item.Name)
if item.Path != nil {
imports[name] = item.Path.Value
}
}
services = append(services, function...)
return
}
func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
m := make(map[string]*Struct)
for _, impo := range a.proto.Import {
ret, _, _, err := a.parse(impo.OriginalPbPath)
if err != nil {
return nil, err
}
for k, v := range ret {
m[k] = v
}
}
return m, nil
}
func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
if scope == nil {
return nil, nil
}
@@ -140,7 +210,7 @@ func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcServi
switch v := tp.(type) {
case *ast.StructType:
st, err := a.parseObject(name, v)
st, err := a.parseObject(name, v, sourcePackage)
a.Must(err)
structs[st.Name.Lower()] = st
@@ -148,7 +218,7 @@ func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcServi
if !strings.HasSuffix(name, suffixServer) {
continue
}
list := a.mustServerFunctions(v)
list := a.mustServerFunctions(v, sourcePackage)
serviceList = append(serviceList, &RpcService{
Name: stringx.From(strings.TrimSuffix(name, suffixServer)),
Funcs: list,
@@ -163,7 +233,7 @@ func (a *astParser) mustScope(scope *ast.Scope) (map[string]*Struct, []*RpcServi
return targetStruct, serviceList
}
func (a *astParser) mustServerFunctions(v *ast.InterfaceType) []*Func {
func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
funcs := make([]*Func, 0)
methodObject := v.Methods
if methodObject == nil {
@@ -187,31 +257,27 @@ func (a *astParser) mustServerFunctions(v *ast.InterfaceType) []*Func {
}
params := v.Params
if params != nil {
inList, err := a.parseFields(params.List, true)
inList, err := a.parseFields(params.List, true, sourcePackage)
a.Must(err)
for _, data := range inList {
if strings.HasPrefix(data.TypeName, referenceContext) {
if data.Type.Package == referenceContext {
continue
}
// currently,does not support external references
item.InTypeName = data.TypeName
item.InType = strings.TrimPrefix(data.TypeName, flagStar)
item.ParameterIn = data.Type
break
}
}
results := v.Results
if results != nil {
outList, err := a.parseFields(results.List, true)
outList, err := a.parseFields(results.List, true, sourcePackage)
a.Must(err)
for _, data := range outList {
if strings.HasPrefix(data.TypeName, referenceContext) {
if data.Type.Package == referenceContext {
continue
}
// currently,does not support external references
item.OutTypeName = data.TypeName
item.OutType = strings.TrimPrefix(data.TypeName, flagStar)
item.ParameterOut = data.Type
break
}
}
@@ -220,7 +286,67 @@ func (a *astParser) mustServerFunctions(v *ast.InterfaceType) []*Func {
return funcs
}
func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct, error) {
func (a *astParser) getFieldType(v string, sourcePackage string) Type {
var pkg, name, expression, starExpression, invokeTypeExpression string
if strings.Contains(v, ".") {
starExpression = v
if strings.Contains(v, "*") {
leftIndex := strings.Index(v, "*")
rightIndex := strings.Index(v, ".")
if leftIndex >= 0 {
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
invokeTypeExpression = v[rightIndex+1:]
}
} else {
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
leftIndex := strings.Index(v, "]")
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[rightIndex+1:]
}
}
} else {
expression = strings.TrimPrefix(v, flagStar)
switch v {
case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
"bool", "string", "bytes":
invokeTypeExpression = v
break
default:
name = expression
invokeTypeExpression = v
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
} else {
starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
invokeTypeExpression = v
}
}
}
expression = strings.TrimPrefix(starExpression, flagStar)
index := strings.LastIndex(expression, flagDot)
if index > 0 {
pkg = expression[0:index]
name = expression[index+1:]
} else {
pkg = sourcePackage
}
return Type{
Expression: expression,
StarExpression: starExpression,
InvokeTypeExpression: invokeTypeExpression,
Package: pkg,
Name: name,
}
}
func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
if data, ok := objectM[structName]; ok {
return data, nil
}
@@ -237,7 +363,7 @@ func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct,
}
fieldList := fields.List
members, err := a.parseFields(fieldList, false)
members, err := a.parseFields(fieldList, false, sourcePackage)
if err != nil {
return nil, err
}
@@ -245,7 +371,7 @@ func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct,
for _, m := range members {
var field Field
field.Name = m.Name
field.TypeName = m.TypeName
field.Type = m.Type
field.JsonTag = m.JsonTag
field.Document = m.Document
field.Comment = m.Comment
@@ -255,7 +381,7 @@ func (a *astParser) parseObject(structName string, tp *ast.StructType) (*Struct,
return &st, nil
}
func (a *astParser) parseFields(fields []*ast.Field, onlyType bool) ([]*Field, error) {
func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
ret := make([]*Field, 0)
for _, field := range fields {
var item Field
@@ -278,7 +404,7 @@ func (a *astParser) parseFields(fields []*ast.Field, onlyType bool) ([]*Field, e
return nil, err
}
item.TypeName = typeName
item.Type = a.getFieldType(typeName, sourcePackage)
if onlyType {
ret = append(ret, &item)
continue
@@ -414,10 +540,30 @@ func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{})
return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...))
}
func (f *Func) GetDoc() string {
return strings.Join(f.Document, util.NL)
}
func (f *Func) HaveDoc() bool {
return len(f.Document) > 0
}
func (a *PbAst) GenEnumCode() (string, error) {
var element []string
for _, item := range a.Enum {
code, err := item.GenEnumCode()
if err != nil {
return "", err
}
element = append(element, code)
}
return strings.Join(element, util.NL), nil
}
func (a *PbAst) GenTypesCode() (string, error) {
types := make([]string, 0)
sts := make([]*Struct, 0)
for _, item := range a.Strcuts {
for _, item := range a.Structure {
sts = append(sts, item)
}
sort.Slice(sts, func(i, j int) bool {
@@ -434,8 +580,17 @@ func (a *PbAst) GenTypesCode() (string, error) {
}
types = append(types, structCode)
}
types = append(types, a.genAnyCode())
for _, item := range a.Enum {
typeCode, err := item.GenEnumTypeCode()
if err != nil {
return "", err
}
types = append(types, typeCode)
}
buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
"types": strings.Join(types, "\n\n"),
"types": strings.Join(types, util.NL+util.NL),
})
if err != nil {
return "", err
@@ -444,6 +599,13 @@ func (a *PbAst) GenTypesCode() (string, error) {
return buffer.String(), nil
}
func (a *PbAst) genAnyCode() string {
if !a.ContainsAny {
return ""
}
return anyTypeTemplate
}
func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
fields := make([]string, 0)
for _, f := range s.Field {
@@ -451,10 +613,10 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
if len(f.Comment) > 0 {
comment = f.Comment[0]
}
doc = strings.Join(f.Document, "\n")
doc = strings.Join(f.Document, util.NL)
buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
"name": f.Name.Title(),
"type": f.TypeName,
"type": f.Type.InvokeTypeExpression,
"tag": f.JsonTag,
"hasDoc": len(f.Document) > 0,
"doc": doc,
@@ -470,7 +632,7 @@ func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
"type": containsTypeStatement,
"name": s.Name.Title(),
"fields": strings.Join(fields, "\n"),
"fields": strings.Join(fields, util.NL),
})
if err != nil {
return "", err

View File

@@ -0,0 +1,294 @@
package parser
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/emicklei/proto"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
AnyImport = "google/protobuf/any.proto"
)
var (
enumTypeTemplate = `{{.name}} int32`
enumTemplate = `const (
{{.element}}
)`
enumFiledTemplate = `{{.key}} {{.name}} = {{.value}}`
)
type (
MessageField struct {
Type string
Name stringx.String
}
Message struct {
Name stringx.String
Element []*MessageField
*proto.Message
}
Enum struct {
Name stringx.String
Element []*EnumField
*proto.Enum
}
EnumField struct {
Key string
Value int
}
Proto struct {
Package string
Import []*Import
PbSrc string
ContainsAny bool
Message map[string]lang.PlaceholderType
Enum map[string]*Enum
}
Import struct {
ProtoImportName string
PbImportName string
OriginalDir string
OriginalProtoPath string
OriginalPbPath string
BridgeImport string
exists bool
//xx.proto
protoName string
// xx.pb.go
pbName string
}
)
func checkImport(src string) error {
r, err := os.Open(src)
if err != nil {
return err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return err
}
var base = filepath.Base(src)
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if err != nil {
return
}
err = fmt.Errorf("%v:%v the external proto cannot import other proto files", base, i.Position.Line)
}))
if err != nil {
return err
}
return nil
}
func ParseImport(src string) ([]*Import, bool, error) {
bridgeImportM := make(map[string]string)
r, err := os.Open(src)
if err != nil {
return nil, false, err
}
defer r.Close()
workDir := filepath.Dir(src)
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, false, err
}
protoImportSet := collection.NewSet()
var containsAny bool
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if i.Filename == AnyImport {
containsAny = true
return
}
protoImportSet.AddStr(i.Filename)
if i.Comment != nil {
lines := i.Comment.Lines
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "@") {
continue
}
line = strings.TrimPrefix(line, "@")
bridgeImportM[i.Filename] = line
}
}
}))
var importList []*Import
for _, item := range protoImportSet.KeysStr() {
pb := strings.TrimSuffix(filepath.Base(item), filepath.Ext(item)) + ".pb.go"
var pbImportName, brideImport string
if v, ok := bridgeImportM[item]; ok {
pbImportName = v
brideImport = "M" + item + "=" + v
} else {
pbImportName = item
}
var impo = Import{
ProtoImportName: item,
PbImportName: pbImportName,
BridgeImport: brideImport,
}
protoSource := filepath.Join(workDir, item)
pbSource := filepath.Join(filepath.Dir(protoSource), pb)
if util.FileExists(protoSource) && util.FileExists(pbSource) {
impo.OriginalProtoPath = protoSource
impo.OriginalPbPath = pbSource
impo.OriginalDir = filepath.Dir(protoSource)
impo.exists = true
impo.protoName = filepath.Base(item)
impo.pbName = pb
} else {
return nil, false, fmt.Errorf("「%v」: import must be found in the relative directory of 「%v」", item, filepath.Base(src))
}
importList = append(importList, &impo)
}
return importList, containsAny, nil
}
func parseProto(src string, messageM map[string]lang.PlaceholderType, enumM map[string]*Enum) (*Proto, error) {
if !filepath.IsAbs(src) {
return nil, fmt.Errorf("expected absolute path,but found: %v", src)
}
r, err := os.Open(src)
if err != nil {
return nil, err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, err
}
// xx.proto
fileBase := filepath.Base(src)
var resp Proto
proto.Walk(parseRet, proto.WithPackage(func(p *proto.Package) {
if err != nil {
return
}
if len(resp.Package) != 0 {
err = fmt.Errorf("%v:%v duplicate package「%v」", fileBase, p.Position.Line, p.Name)
}
if len(p.Name) == 0 {
err = errors.New("package not found")
}
resp.Package = p.Name
}), proto.WithMessage(func(message *proto.Message) {
if err != nil {
return
}
for _, item := range message.Elements {
switch item.(type) {
case *proto.NormalField, *proto.MapField, *proto.Comment:
continue
default:
err = fmt.Errorf("%v: unsupport inline declaration", fileBase)
return
}
}
name := stringx.From(message.Name)
if _, ok := messageM[name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate message 「%v」", fileBase, message.Position.Line, message.Name)
return
}
messageM[name.Lower()] = lang.Placeholder
}), proto.WithEnum(func(enum *proto.Enum) {
if err != nil {
return
}
var node Enum
node.Enum = enum
node.Name = stringx.From(enum.Name)
for _, item := range enum.Elements {
v, ok := item.(*proto.EnumField)
if !ok {
continue
}
node.Element = append(node.Element, &EnumField{
Key: v.Name,
Value: v.Integer,
})
}
if _, ok := enumM[node.Name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate enum 「%v」", fileBase, node.Position.Line, node.Name.Source())
return
}
lower := stringx.From(enum.Name).Lower()
enumM[lower] = &node
}))
if err != nil {
return nil, err
}
resp.Message = messageM
resp.Enum = enumM
return &resp, nil
}
func (e *Enum) GenEnumCode() (string, error) {
var element []string
for _, item := range e.Element {
code, err := item.GenEnumFieldCode(e.Name.Source())
if err != nil {
return "", err
}
element = append(element, code)
}
buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
"element": strings.Join(element, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *Enum) GenEnumTypeCode() (string, error) {
buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
"name": e.Name.Source(),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
"key": e.Key,
"name": parentName,
"value": e.Value,
})
if err != nil {
return "", err
}
return buffer.String(), nil
}