reactor rpc (#179)

* reactor rpc generation

* update flag

* update command

* update command

* update unit test

* delete test file

* optimize code

* update doc

* update gen pb

* rename target dir

* update mysql data type convert rule

* add done flag

* optimize req/reply parameter

* optimize req/reply parameter

* remove waste code

* remove duplicate parameter

* format code

* format code

* optimize naming

* reactor rpcv2 to rpc

* remove new line

* format code

* rename underline to snake

* reactor getParentPackage

* remove debug log

* reactor background
This commit is contained in:
Keson
2020-11-05 14:12:47 +08:00
committed by GitHub
parent c9ec22d5f4
commit e76f44a35b
95 changed files with 2708 additions and 3301 deletions

View File

@@ -0,0 +1,10 @@
package parser
import "github.com/emicklei/proto"
func GetComment(comment *proto.Comment) string {
if comment == nil {
return ""
}
return comment.Message()
}

View File

@@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type Import struct {
*proto.Import
}

View File

@@ -0,0 +1,7 @@
package parser
import pr "github.com/emicklei/proto"
type Message struct {
*pr.Message
}

View File

@@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type Option struct {
*proto.Option
}

View File

@@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type Package struct {
*proto.Package
}

View File

@@ -1,46 +1,170 @@
package parser
import (
"errors"
"fmt"
"go/token"
"os"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/emicklei/proto"
)
func Transfer(proto, target string, externalImport []*Import, console console.Console) (*PbAst, error) {
messageM := make(map[string]lang.PlaceholderType)
enumM := make(map[string]*Enum)
protoAst, err := parseProto(proto, messageM, enumM)
if err != nil {
return nil, err
}
for _, item := range externalImport {
err = checkImport(item.OriginalProtoPath)
if err != nil {
return nil, err
}
innerAst, err := parseProto(item.OriginalProtoPath, protoAst.Message, protoAst.Enum)
if err != nil {
return nil, err
}
for k, v := range innerAst.Message {
protoAst.Message[k] = v
}
for k, v := range innerAst.Enum {
protoAst.Enum[k] = v
}
}
protoAst.Import = externalImport
protoAst.PbSrc = filepath.Join(target, strings.TrimSuffix(filepath.Base(proto), ".proto")+".pb.go")
return transfer(protoAst, console)
type (
defaultProtoParser struct{}
)
func NewDefaultProtoParser() *defaultProtoParser {
return &defaultProtoParser{}
}
func transfer(proto *Proto, console console.Console) (*PbAst, error) {
parser := MustNewAstParser(proto, console)
parse, err := parser.Parse()
func (p *defaultProtoParser) Parse(src string) (Proto, error) {
var ret Proto
abs, err := filepath.Abs(src)
if err != nil {
return nil, err
return Proto{}, err
}
return parse, nil
r, err := os.Open(abs)
if err != nil {
return ret, err
}
defer r.Close()
parser := proto.NewParser(r)
set, err := parser.Parse()
if err != nil {
return ret, err
}
var serviceList []Service
proto.Walk(
set,
proto.WithImport(func(i *proto.Import) {
ret.Import = append(ret.Import, Import{Import: i})
}),
proto.WithMessage(func(message *proto.Message) {
ret.Message = append(ret.Message, Message{Message: message})
}),
proto.WithPackage(func(p *proto.Package) {
ret.Package = Package{Package: p}
}),
proto.WithService(func(service *proto.Service) {
serv := Service{Service: service}
elements := service.Elements
for _, el := range elements {
v, _ := el.(*proto.RPC)
if v == nil {
continue
}
serv.RPC = append(serv.RPC, &RPC{RPC: v})
}
serviceList = append(serviceList, serv)
}),
proto.WithOption(func(option *proto.Option) {
if option.Name == "go_package" {
ret.GoPackage = option.Constant.Source
}
}),
)
if len(serviceList) == 0 {
return ret, errors.New("rpc service not found")
}
if len(serviceList) > 1 {
return ret, errors.New("only one service expected")
}
service := serviceList[0]
name := filepath.Base(abs)
for _, rpc := range service.RPC {
if strings.Contains(rpc.RequestType, ".") {
return ret, fmt.Errorf("line %v:%v, request type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
}
if strings.Contains(rpc.ReturnsType, ".") {
return ret, fmt.Errorf("line %v:%v, returns type must defined in %s", rpc.Position.Line, rpc.Position.Column, name)
}
}
if len(ret.GoPackage) == 0 {
ret.GoPackage = ret.Package.Name
}
ret.PbPackage = GoSanitized(filepath.Base(ret.GoPackage))
ret.Src = abs
ret.Name = name
ret.Service = service
return ret, nil
}
// see google.golang.org/protobuf@v1.25.0/internal/strs/strings.go:71
func GoSanitized(s string) string {
// Sanitize the input to the set of valid characters,
// which must be '_' or be in the Unicode L or N categories.
s = strings.Map(func(r rune) rune {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
return r
}
return '_'
}, s)
// Prepend '_' in the event of a Go keyword conflict or if
// the identifier is invalid (does not start in the Unicode L category).
r, _ := utf8.DecodeRuneInString(s)
if token.Lookup(s).IsKeyword() || !unicode.IsLetter(r) {
return "_" + s
}
return s
}
// copy from github.com/golang/protobuf@v1.4.2/protoc-gen-go/generator/generator.go:2648
func CamelCase(s string) string {
if s == "" {
return ""
}
t := make([]byte, 0, 32)
i := 0
if s[0] == '_' {
// Need a capital letter; drop the '_'.
t = append(t, 'X')
i++
}
// Invariant: if the next letter is lower case, it must be converted
// to upper case.
// That is, we process a word at a time, where words are marked by _ or
// upper case letter. Digits are treated as words.
for ; i < len(s); i++ {
c := s[i]
if c == '_' && i+1 < len(s) && isASCIILower(s[i+1]) {
continue // Skip the underscore in s.
}
if isASCIIDigit(c) {
t = append(t, c)
continue
}
// Assume we have a letter now - if not, it's a bogus identifier.
// The next word is a sequence of characters that must start upper case.
if isASCIILower(c) {
c ^= ' ' // Make it a capital letter.
}
t = append(t, c) // Guaranteed not lower case.
// Accept lower case sequence that follows.
for i+1 < len(s) && isASCIILower(s[i+1]) {
i++
t = append(t, s[i])
}
}
return string(t)
}
func isASCIILower(c byte) bool {
return 'a' <= c && c <= 'z'
}
// Is c an ASCII digit?
func isASCIIDigit(c byte) bool {
return '0' <= c && c <= '9'
}

View File

@@ -0,0 +1,78 @@
package parser
import (
"sort"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestDefaultProtoParse(t *testing.T) {
p := NewDefaultProtoParser()
data, err := p.Parse("./test.proto")
assert.Nil(t, err)
assert.Equal(t, "base.proto", func() string {
ip := data.Import[0]
return ip.Filename
}())
assert.Equal(t, "test", data.Package.Name)
assert.Equal(t, true, data.GoPackage == "go")
assert.Equal(t, true, data.PbPackage == "_go")
assert.Equal(t, []string{"TestMessage", "TestReply", "TestReq"}, func() []string {
var list []string
for _, item := range data.Message {
list = append(list, item.Name)
}
sort.Strings(list)
return list
}())
assert.Equal(t, true, func() bool {
s := data.Service
if s.Name != "TestService" {
return false
}
rpcOne := s.RPC[0]
return rpcOne.Name == "TestRpcOne" && rpcOne.RequestType == "TestReq" && rpcOne.ReturnsType == "TestReply"
}())
}
func TestDefaultProtoParseCaseInvalidRequestType(t *testing.T) {
p := NewDefaultProtoParser()
_, err := p.Parse("./test_invalid_request.proto")
assert.True(t, true, func() bool {
return strings.Contains(err.Error(), "request type must defined in")
}())
}
func TestDefaultProtoParseCaseInvalidResponseType(t *testing.T) {
p := NewDefaultProtoParser()
_, err := p.Parse("./test_invalid_response.proto")
assert.True(t, true, func() bool {
return strings.Contains(err.Error(), "response type must defined in")
}())
}
func TestDefaultProtoParseError(t *testing.T) {
p := NewDefaultProtoParser()
_, err := p.Parse("./nil.proto")
assert.NotNil(t, err)
}
func TestDefaultProtoParse_Option(t *testing.T) {
p := NewDefaultProtoParser()
data, err := p.Parse("./test_option.proto")
assert.Nil(t, err)
assert.Equal(t, "github.com/tal-tech/go-zero", data.GoPackage)
assert.Equal(t, "go_zero", data.PbPackage)
}
func TestDefaultProtoParse_Option2(t *testing.T) {
p := NewDefaultProtoParser()
data, err := p.Parse("./test_option2.proto")
assert.Nil(t, err)
assert.Equal(t, "stream", data.GoPackage)
assert.Equal(t, "stream", data.PbPackage)
}

View File

@@ -1,643 +0,0 @@
package parser
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/ioutil"
"sort"
"strings"
"github.com/tal-tech/go-zero/core/lang"
sx "github.com/tal-tech/go-zero/core/stringx"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
flagStar = "*"
flagDot = "."
suffixServer = "Server"
referenceContext = "context"
unknownPrefix = "XXX_"
ignoreJsonTagExpression = `json:"-"`
)
var (
errorParseError = errors.New("pb parse error")
typeTemplate = `type (
{{.types}}
)`
structTemplate = `{{if .type}}type {{end}}{{.name}} struct {
{{.fields}}
}`
fieldTemplate = `{{if .hasDoc}}{{.doc}}
{{end}}{{.name}} {{.type}} {{.tag}}{{if .hasComment}}{{.comment}}{{end}}`
anyTypeTemplate = "Any struct {\n\tTypeUrl string `json:\"typeUrl\"`\n\tValue []byte `json:\"value\"`\n}"
objectM = make(map[string]*Struct)
)
type (
astParser struct {
filterStruct map[string]lang.PlaceholderType
filterEnum map[string]*Enum
console.Console
fileSet *token.FileSet
proto *Proto
}
Field struct {
Name stringx.String
Type Type
JsonTag string
Document []string
Comment []string
}
Struct struct {
Name stringx.String
Document []string
Comment []string
Field []*Field
}
ConstLit struct {
Name stringx.String
Document []string
Comment []string
Lit []*Lit
}
Lit struct {
Key string
Value int
}
Type struct {
// eg:context.Context
Expression string
// eg: *context.Context
StarExpression string
// Invoke Type Expression
InvokeTypeExpression string
// eg:context
Package string
// eg:Context
Name string
}
Func struct {
Name stringx.String
ParameterIn Type
ParameterOut Type
Document []string
}
RpcService struct {
Name stringx.String
Funcs []*Func
}
// parsing for rpc
PbAst struct {
// deprecated: containsAny will be removed in the feature
ContainsAny bool
Imports map[string]string
Structure map[string]*Struct
Service []*RpcService
*Proto
}
)
func MustNewAstParser(proto *Proto, log console.Console) *astParser {
return &astParser{
filterStruct: proto.Message,
filterEnum: proto.Enum,
Console: log,
fileSet: token.NewFileSet(),
proto: proto,
}
}
func (a *astParser) Parse() (*PbAst, error) {
var pbAst PbAst
pbAst.ContainsAny = a.proto.ContainsAny
pbAst.Proto = a.proto
pbAst.Structure = make(map[string]*Struct)
pbAst.Imports = make(map[string]string)
structure, imports, services, err := a.parse(a.proto.PbSrc)
if err != nil {
return nil, err
}
dependencyStructure, err := a.parseExternalDependency()
if err != nil {
return nil, err
}
for k, v := range structure {
pbAst.Structure[k] = v
}
for k, v := range dependencyStructure {
pbAst.Structure[k] = v
}
for key, path := range imports {
pbAst.Imports[key] = path
}
pbAst.Service = append(pbAst.Service, services...)
return &pbAst, nil
}
func (a *astParser) parse(pbSrc string) (structure map[string]*Struct, imports map[string]string, services []*RpcService, retErr error) {
structure = make(map[string]*Struct)
imports = make(map[string]string)
data, err := ioutil.ReadFile(pbSrc)
if err != nil {
retErr = err
return
}
fSet := a.fileSet
f, err := parser.ParseFile(fSet, "", data, parser.ParseComments)
if err != nil {
retErr = err
return
}
commentMap := ast.NewCommentMap(fSet, f, f.Comments)
f.Comments = commentMap.Filter(f).Comments()
strucs, function := a.mustScope(f.Scope, a.mustGetIndentName(f.Name))
for k, v := range strucs {
if v == nil {
continue
}
structure[k] = v
}
importList := f.Imports
for _, item := range importList {
name := a.mustGetIndentName(item.Name)
if item.Path != nil {
imports[name] = item.Path.Value
}
}
services = append(services, function...)
return
}
func (a *astParser) parseExternalDependency() (map[string]*Struct, error) {
m := make(map[string]*Struct)
for _, impo := range a.proto.Import {
ret, _, _, err := a.parse(impo.OriginalPbPath)
if err != nil {
return nil, err
}
for k, v := range ret {
m[k] = v
}
}
return m, nil
}
func (a *astParser) mustScope(scope *ast.Scope, sourcePackage string) (map[string]*Struct, []*RpcService) {
if scope == nil {
return nil, nil
}
objects := scope.Objects
structs := make(map[string]*Struct)
serviceList := make([]*RpcService, 0)
for name, obj := range objects {
decl := obj.Decl
if decl == nil {
continue
}
typeSpec, ok := decl.(*ast.TypeSpec)
if !ok {
continue
}
tp := typeSpec.Type
switch v := tp.(type) {
case *ast.StructType:
st, err := a.parseObject(name, v, sourcePackage)
a.Must(err)
structs[st.Name.Lower()] = st
case *ast.InterfaceType:
if !strings.HasSuffix(name, suffixServer) {
continue
}
list := a.mustServerFunctions(v, sourcePackage)
serviceList = append(serviceList, &RpcService{
Name: stringx.From(strings.TrimSuffix(name, suffixServer)),
Funcs: list,
})
}
}
targetStruct := make(map[string]*Struct)
for st := range a.filterStruct {
lower := strings.ToLower(st)
targetStruct[lower] = structs[lower]
}
return targetStruct, serviceList
}
func (a *astParser) mustServerFunctions(v *ast.InterfaceType, sourcePackage string) []*Func {
funcs := make([]*Func, 0)
methodObject := v.Methods
if methodObject == nil {
return nil
}
for _, method := range methodObject.List {
var item Func
name := a.mustGetIndentName(method.Names[0])
doc := a.parseCommentOrDoc(method.Doc)
item.Name = stringx.From(name)
item.Document = doc
types := method.Type
if types == nil {
funcs = append(funcs, &item)
continue
}
v, ok := types.(*ast.FuncType)
if !ok {
continue
}
params := v.Params
if params != nil {
inList, err := a.parseFields(params.List, true, sourcePackage)
a.Must(err)
for _, data := range inList {
if data.Type.Package == referenceContext {
continue
}
item.ParameterIn = data.Type
break
}
}
results := v.Results
if results != nil {
outList, err := a.parseFields(results.List, true, sourcePackage)
a.Must(err)
for _, data := range outList {
if data.Type.Package == referenceContext {
continue
}
item.ParameterOut = data.Type
break
}
}
funcs = append(funcs, &item)
}
return funcs
}
func (a *astParser) getFieldType(v string, sourcePackage string) Type {
var pkg, name, expression, starExpression, invokeTypeExpression string
if strings.Contains(v, ".") {
starExpression = v
if strings.Contains(v, "*") {
leftIndex := strings.Index(v, "*")
rightIndex := strings.Index(v, ".")
if leftIndex >= 0 {
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
invokeTypeExpression = v[rightIndex+1:]
}
} else {
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
leftIndex := strings.Index(v, "]")
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[0:leftIndex+1] + v[rightIndex+1:]
} else {
rightIndex := strings.Index(v, ".")
invokeTypeExpression = v[rightIndex+1:]
}
}
} else {
expression = strings.TrimPrefix(v, flagStar)
switch v {
case "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32", "fixed64", "sfixed32", "sfixed64",
"bool", "string", "bytes":
invokeTypeExpression = v
break
default:
name = expression
invokeTypeExpression = v
if strings.HasPrefix(v, "map[") || strings.HasPrefix(v, "[]") {
starExpression = strings.ReplaceAll(v, flagStar, flagStar+sourcePackage+".")
} else {
starExpression = fmt.Sprintf("*%v.%v", sourcePackage, name)
invokeTypeExpression = v
}
}
}
expression = strings.TrimPrefix(starExpression, flagStar)
index := strings.LastIndex(expression, flagDot)
if index > 0 {
pkg = expression[0:index]
name = expression[index+1:]
} else {
pkg = sourcePackage
}
return Type{
Expression: expression,
StarExpression: starExpression,
InvokeTypeExpression: invokeTypeExpression,
Package: pkg,
Name: name,
}
}
func (a *astParser) parseObject(structName string, tp *ast.StructType, sourcePackage string) (*Struct, error) {
if data, ok := objectM[structName]; ok {
return data, nil
}
var st Struct
st.Name = stringx.From(structName)
if tp == nil {
return &st, nil
}
fields := tp.Fields
if fields == nil {
objectM[structName] = &st
return &st, nil
}
fieldList := fields.List
members, err := a.parseFields(fieldList, false, sourcePackage)
if err != nil {
return nil, err
}
for _, m := range members {
var field Field
field.Name = m.Name
field.Type = m.Type
field.JsonTag = m.JsonTag
field.Document = m.Document
field.Comment = m.Comment
st.Field = append(st.Field, &field)
}
objectM[structName] = &st
return &st, nil
}
func (a *astParser) parseFields(fields []*ast.Field, onlyType bool, sourcePackage string) ([]*Field, error) {
ret := make([]*Field, 0)
for _, field := range fields {
var item Field
tag := a.parseTag(field.Tag)
if tag == "" && !onlyType {
continue
}
if tag == ignoreJsonTagExpression {
continue
}
item.JsonTag = tag
name := a.parseName(field.Names)
if strings.HasPrefix(name, unknownPrefix) {
continue
}
item.Name = stringx.From(name)
typeName, err := a.parseType(field.Type)
if err != nil {
return nil, err
}
item.Type = a.getFieldType(typeName, sourcePackage)
if onlyType {
ret = append(ret, &item)
continue
}
docs := a.parseCommentOrDoc(field.Doc)
comments := a.parseCommentOrDoc(field.Comment)
item.Document = docs
item.Comment = comments
isInline := name == ""
if isInline {
return nil, a.wrapError(field.Pos(), "unexpected inline type:%s", name)
}
ret = append(ret, &item)
}
return ret, nil
}
func (a *astParser) parseTag(basicLit *ast.BasicLit) string {
if basicLit == nil {
return ""
}
value := basicLit.Value
splits := strings.Split(value, " ")
if len(splits) == 1 {
return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[0], "`", ""))
} else {
return fmt.Sprintf("`%s`", strings.ReplaceAll(splits[1], "`", ""))
}
}
// returns
// resp1:type's string expression,like int、string、[]int64、map[string]User、*User
// resp2:error
func (a *astParser) parseType(expr ast.Expr) (string, error) {
if expr == nil {
return "", errorParseError
}
switch v := expr.(type) {
case *ast.StarExpr:
stringExpr, err := a.parseType(v.X)
if err != nil {
return "", err
}
e := fmt.Sprintf("*%s", stringExpr)
return e, nil
case *ast.Ident:
return a.mustGetIndentName(v), nil
case *ast.MapType:
keyStringExpr, err := a.parseType(v.Key)
if err != nil {
return "", err
}
valueStringExpr, err := a.parseType(v.Value)
if err != nil {
return "", err
}
e := fmt.Sprintf("map[%s]%s", keyStringExpr, valueStringExpr)
return e, nil
case *ast.ArrayType:
stringExpr, err := a.parseType(v.Elt)
if err != nil {
return "", err
}
e := fmt.Sprintf("[]%s", stringExpr)
return e, nil
case *ast.InterfaceType:
return "interface{}", nil
case *ast.SelectorExpr:
join := make([]string, 0)
xIdent, ok := v.X.(*ast.Ident)
xIndentName := a.mustGetIndentName(xIdent)
if ok {
join = append(join, xIndentName)
}
sel := v.Sel
join = append(join, a.mustGetIndentName(sel))
return strings.Join(join, "."), nil
case *ast.ChanType:
return "", a.wrapError(v.Pos(), "unexpected type 'chan'")
case *ast.FuncType:
return "", a.wrapError(v.Pos(), "unexpected type 'func'")
case *ast.StructType:
return "", a.wrapError(v.Pos(), "unexpected inline struct type")
default:
return "", a.wrapError(v.Pos(), "unexpected type '%v'", v)
}
}
func (a *astParser) parseName(names []*ast.Ident) string {
if len(names) == 0 {
return ""
}
name := names[0]
return a.mustGetIndentName(name)
}
func (a *astParser) parseCommentOrDoc(cg *ast.CommentGroup) []string {
if cg == nil {
return nil
}
comments := make([]string, 0)
for _, comment := range cg.List {
if comment == nil {
continue
}
text := strings.TrimSpace(comment.Text)
if text == "" {
continue
}
comments = append(comments, text)
}
return comments
}
func (a *astParser) mustGetIndentName(ident *ast.Ident) string {
if ident == nil {
return ""
}
return ident.Name
}
func (a *astParser) wrapError(pos token.Pos, format string, arg ...interface{}) error {
file := a.fileSet.Position(pos)
return fmt.Errorf("line %v: %s", file.Line, fmt.Sprintf(format, arg...))
}
func (f *Func) GetDoc() string {
return strings.Join(f.Document, util.NL)
}
func (f *Func) HaveDoc() bool {
return len(f.Document) > 0
}
func (a *PbAst) GenEnumCode() (string, error) {
var element []string
for _, item := range a.Enum {
code, err := item.GenEnumCode()
if err != nil {
return "", err
}
element = append(element, code)
}
return strings.Join(element, util.NL), nil
}
func (a *PbAst) GenTypesCode() (string, error) {
types := make([]string, 0)
sts := make([]*Struct, 0)
for _, item := range a.Structure {
sts = append(sts, item)
}
sort.Slice(sts, func(i, j int) bool {
return sts[i].Name.Source() < sts[j].Name.Source()
})
for _, s := range sts {
structCode, err := s.genCode(false)
if err != nil {
return "", err
}
if structCode == "" {
continue
}
types = append(types, structCode)
}
types = append(types, a.genAnyCode())
for _, item := range a.Enum {
typeCode, err := item.GenEnumTypeCode()
if err != nil {
return "", err
}
types = append(types, typeCode)
}
buffer, err := util.With("type").Parse(typeTemplate).Execute(map[string]interface{}{
"types": strings.Join(types, util.NL+util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (a *PbAst) genAnyCode() string {
if !a.ContainsAny {
return ""
}
return anyTypeTemplate
}
func (s *Struct) genCode(containsTypeStatement bool) (string, error) {
fields := make([]string, 0)
for _, f := range s.Field {
var comment, doc string
if len(f.Comment) > 0 {
comment = f.Comment[0]
}
doc = strings.Join(f.Document, util.NL)
buffer, err := util.With(sx.Rand()).Parse(fieldTemplate).Execute(map[string]interface{}{
"name": f.Name.Title(),
"type": f.Type.InvokeTypeExpression,
"tag": f.JsonTag,
"hasDoc": len(f.Document) > 0,
"doc": doc,
"hasComment": len(f.Comment) > 0,
"comment": comment,
})
if err != nil {
return "", err
}
fields = append(fields, buffer.String())
}
buffer, err := util.With("struct").Parse(structTemplate).Execute(map[string]interface{}{
"type": containsTypeStatement,
"name": s.Name.Title(),
"fields": strings.Join(fields, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}

View File

@@ -1,295 +1,12 @@
package parser
import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/emicklei/proto"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
const (
AnyImport = "google/protobuf/any.proto"
)
var (
enumTypeTemplate = `{{.name}} int32`
enumTemplate = `const (
{{.element}}
)`
enumFiledTemplate = `{{.key}} {{.name}} = {{.value}}`
)
type (
MessageField struct {
Type string
Name stringx.String
}
Message struct {
Name stringx.String
Element []*MessageField
*proto.Message
}
Enum struct {
Name stringx.String
Element []*EnumField
*proto.Enum
}
EnumField struct {
Key string
Value int
}
Proto struct {
Package string
Import []*Import
PbSrc string
// deprecated: containsAny will be removed in the feature
ContainsAny bool
Message map[string]lang.PlaceholderType
Enum map[string]*Enum
}
Import struct {
ProtoImportName string
PbImportName string
OriginalDir string
OriginalProtoPath string
OriginalPbPath string
BridgeImport string
exists bool
//xx.proto
protoName string
// xx.pb.go
pbName string
}
)
func checkImport(src string) error {
r, err := os.Open(src)
if err != nil {
return err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return err
}
var base = filepath.Base(src)
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if err != nil {
return
}
err = fmt.Errorf("%v:%v the external proto cannot import other proto files", base, i.Position.Line)
}))
if err != nil {
return err
}
return nil
}
func ParseImport(src string) ([]*Import, bool, error) {
bridgeImportM := make(map[string]string)
r, err := os.Open(src)
if err != nil {
return nil, false, err
}
defer r.Close()
workDir := filepath.Dir(src)
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, false, err
}
protoImportSet := collection.NewSet()
var containsAny bool
proto.Walk(parseRet, proto.WithImport(func(i *proto.Import) {
if i.Filename == AnyImport {
containsAny = true
return
}
protoImportSet.AddStr(i.Filename)
if i.Comment != nil {
lines := i.Comment.Lines
for _, line := range lines {
line = strings.TrimSpace(line)
if !strings.HasPrefix(line, "@") {
continue
}
line = strings.TrimPrefix(line, "@")
bridgeImportM[i.Filename] = line
}
}
}))
var importList []*Import
for _, item := range protoImportSet.KeysStr() {
pb := strings.TrimSuffix(filepath.Base(item), filepath.Ext(item)) + ".pb.go"
var pbImportName, brideImport string
if v, ok := bridgeImportM[item]; ok {
pbImportName = v
brideImport = "M" + item + "=" + v
} else {
pbImportName = item
}
var impo = Import{
ProtoImportName: item,
PbImportName: pbImportName,
BridgeImport: brideImport,
}
protoSource := filepath.Join(workDir, item)
pbSource := filepath.Join(filepath.Dir(protoSource), pb)
if util.FileExists(protoSource) && util.FileExists(pbSource) {
impo.OriginalProtoPath = protoSource
impo.OriginalPbPath = pbSource
impo.OriginalDir = filepath.Dir(protoSource)
impo.exists = true
impo.protoName = filepath.Base(item)
impo.pbName = pb
} else {
return nil, false, fmt.Errorf("「%v」: import must be found in the relative directory of 「%v」", item, filepath.Base(src))
}
importList = append(importList, &impo)
}
return importList, containsAny, nil
}
func parseProto(src string, messageM map[string]lang.PlaceholderType, enumM map[string]*Enum) (*Proto, error) {
if !filepath.IsAbs(src) {
return nil, fmt.Errorf("expected absolute path,but found: %v", src)
}
r, err := os.Open(src)
if err != nil {
return nil, err
}
defer r.Close()
parser := proto.NewParser(r)
parseRet, err := parser.Parse()
if err != nil {
return nil, err
}
// xx.proto
fileBase := filepath.Base(src)
var resp Proto
proto.Walk(parseRet, proto.WithPackage(func(p *proto.Package) {
if err != nil {
return
}
if len(resp.Package) != 0 {
err = fmt.Errorf("%v:%v duplicate package「%v」", fileBase, p.Position.Line, p.Name)
}
if len(p.Name) == 0 {
err = errors.New("package not found")
}
resp.Package = p.Name
}), proto.WithMessage(func(message *proto.Message) {
if err != nil {
return
}
for _, item := range message.Elements {
switch item.(type) {
case *proto.NormalField, *proto.MapField, *proto.Comment:
continue
default:
err = fmt.Errorf("%v: unsupport inline declaration", fileBase)
return
}
}
name := stringx.From(message.Name)
if _, ok := messageM[name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate message 「%v」", fileBase, message.Position.Line, message.Name)
return
}
messageM[name.Lower()] = lang.Placeholder
}), proto.WithEnum(func(enum *proto.Enum) {
if err != nil {
return
}
var node Enum
node.Enum = enum
node.Name = stringx.From(enum.Name)
for _, item := range enum.Elements {
v, ok := item.(*proto.EnumField)
if !ok {
continue
}
node.Element = append(node.Element, &EnumField{
Key: v.Name,
Value: v.Integer,
})
}
if _, ok := enumM[node.Name.Lower()]; ok {
err = fmt.Errorf("%v:%v duplicate enum 「%v」", fileBase, node.Position.Line, node.Name.Source())
return
}
lower := stringx.From(enum.Name).Lower()
enumM[lower] = &node
}))
if err != nil {
return nil, err
}
resp.Message = messageM
resp.Enum = enumM
return &resp, nil
}
func (e *Enum) GenEnumCode() (string, error) {
var element []string
for _, item := range e.Element {
code, err := item.GenEnumFieldCode(e.Name.Source())
if err != nil {
return "", err
}
element = append(element, code)
}
buffer, err := util.With("enum").Parse(enumTemplate).Execute(map[string]interface{}{
"element": strings.Join(element, util.NL),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *Enum) GenEnumTypeCode() (string, error) {
buffer, err := util.With("enumAlias").Parse(enumTypeTemplate).Execute(map[string]interface{}{
"name": e.Name.Source(),
})
if err != nil {
return "", err
}
return buffer.String(), nil
}
func (e *EnumField) GenEnumFieldCode(parentName string) (string, error) {
buffer, err := util.With("enumField").Parse(enumFiledTemplate).Execute(map[string]interface{}{
"key": e.Key,
"name": parentName,
"value": e.Value,
})
if err != nil {
return "", err
}
return buffer.String(), nil
type Proto struct {
Src string
Name string
Package Package
PbPackage string
GoPackage string
Import []Import
Message []Message
Service Service
}

View File

@@ -0,0 +1,7 @@
package parser
import "github.com/emicklei/proto"
type RPC struct {
*proto.RPC
}

View File

@@ -0,0 +1,8 @@
package parser
import "github.com/emicklei/proto"
type Service struct {
*proto.Service
RPC []*RPC
}

View File

@@ -0,0 +1,20 @@
syntax = "proto3";
package test;
option go_package = "go";
import "base.proto";
message TestMessage{}
message TestReq{}
message TestReply{}
enum TestEnum {
unknown = 0;
male = 1;
female = 2;
}
service TestService{
rpc TestRpcOne (TestReq)returns(TestReply);
}

View File

@@ -0,0 +1,13 @@
syntax = "proto3";
package test;
option go_package = "go";
import "base.proto";
message Reply{}
service TestService{
rpc TestRpcTwo (base.Req)returns(Reply);
}

View File

@@ -0,0 +1,13 @@
syntax = "proto3";
package test;
option go_package = "go";
import "base.proto";
message Req{}
service TestService{
rpc TestRpcTwo (Req)returns(base.Reply);
}

View File

@@ -0,0 +1,10 @@
syntax = "proto3";
package stream;
option go_package="github.com/tal-tech/go-zero";
message placeholder{}
service greet{
rpc hello(placeholder)returns(placeholder);
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package stream;
message placeholder{}
service greet{
rpc hello(placeholder)returns(placeholder);
}