Goctl rpc patch (#117)

* remove mock generation

* add: proto project import

* update document

* remove mock generation

* add: proto project import

* update document

* remove NL

* update document

* optimize code

* add test

* add test
This commit is contained in:
Keson
2020-10-10 16:19:46 +08:00
committed by GitHub
parent c32759d735
commit 0a9c427443
26 changed files with 1394 additions and 230 deletions

View File

@@ -1,6 +1,7 @@
package gen
import (
"github.com/logrusorgru/aurora"
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
@@ -31,10 +32,11 @@ func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
}
func (g *defaultRpcGenerator) Generate() (err error) {
g.Ctx.Info("generating code...")
g.Ctx.Info(aurora.Blue("-> goctl rpc reference documents: ").String() + "「https://github.com/tal-tech/go-zero/blob/master/doc/goctl-rpc.md」")
g.Ctx.Warning("-> generating rpc code ...")
defer func() {
if err == nil {
g.Ctx.Success("Done.")
g.Ctx.MarkDone()
}
}()
err = g.createDir()

View File

@@ -2,17 +2,16 @@ package gen
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
typesFilename = "types.go"
callTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}_mock.go -package {{.filePackage}} -source $GOFILE
@@ -54,14 +53,17 @@ import "errors"
var errJsonConvert = errors.New("json convert error")
{{.const}}
{{.types}}
`
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
callFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
var request {{.package}}.{{.pbRequest}}
func (m *default{{.rpcServiceName}}) {{.method}}(ctx context.Context,in *{{.pbRequestName}}) (*{{.pbResponse}}, error) {
var request {{.pbRequest}}
bts, err := jsonx.Marshal(in)
if err != nil {
return nil, errJsonConvert
@@ -108,21 +110,23 @@ func (g *defaultRpcGenerator) genCall() error {
return err
}
constLit, err := file.GenEnumCode()
if err != nil {
return err
}
service := file.Service[0]
callPath := filepath.Join(g.dirM[dirTarget], service.Name.Lower())
if err = util.MkdirIfNotExist(callPath); err != nil {
return err
}
pbPkg := file.Package
remotePackage := fmt.Sprintf(`%v "%v"`, pbPkg, g.mustGetPackage(dirPb))
filename := filepath.Join(callPath, "types.go")
filename := filepath.Join(callPath, typesFilename)
head := util.GetHead(g.Ctx.ProtoSource)
err = util.With("types").GoFmt(true).Parse(callTemplateTypes).SaveTo(map[string]interface{}{
"head": head,
"const": constLit,
"filePackage": service.Name.Lower(),
"pbPkg": pbPkg,
"serviceName": g.Ctx.ServiceName.Title(),
"lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
"types": typeCode,
@@ -131,10 +135,8 @@ func (g *defaultRpcGenerator) genCall() error {
return err
}
_, err = exec.LookPath("mockgen")
mockGenInstalled := err == nil
filename = filepath.Join(callPath, fmt.Sprintf("%s.go", service.Name.Lower()))
functions, err := g.getFuncs(service)
functions, importList, err := g.genFunction(service)
if err != nil {
return err
}
@@ -144,72 +146,56 @@ func (g *defaultRpcGenerator) genCall() error {
return err
}
mockFile := filepath.Join(callPath, fmt.Sprintf("%s_mock.go", service.Name.Lower()))
_ = os.Remove(mockFile)
err = util.With("shared").GoFmt(true).Parse(callTemplateText).SaveTo(map[string]interface{}{
"name": service.Name.Lower(),
"head": head,
"filePackage": service.Name.Lower(),
"pbPkg": pbPkg,
"package": remotePackage,
"package": strings.Join(importList, util.NL),
"serviceName": service.Name.Title(),
"functions": strings.Join(functions, "\n"),
"interface": strings.Join(iFunctions, "\n"),
"functions": strings.Join(functions, util.NL),
"interface": strings.Join(iFunctions, util.NL),
}, filename, true)
if err != nil {
return err
}
// if mockgen is already installed, it will generate code of gomock for shared files
// Deprecated: it will be removed
if mockGenInstalled && g.Ctx.IsInGoEnv {
_, _ = execx.Run(fmt.Sprintf("go generate %s", filename), "")
}
return nil
return err
}
func (g *defaultRpcGenerator) getFuncs(service *parser.RpcService) ([]string, error) {
func (g *defaultRpcGenerator) genFunction(service *parser.RpcService) ([]string, []string, error) {
file := g.ast
pkgName := file.Package
functions := make([]string, 0)
imports := collection.NewSet()
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkgName, g.mustGetPackage(dirPb)))
for _, method := range service.Funcs {
var comment string
if len(method.Document) > 0 {
comment = method.Document[0]
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
buffer, err := util.With("sharedFn").Parse(callFunctionTemplate).Execute(map[string]interface{}{
"rpcServiceName": service.Name.Title(),
"method": method.Name.Title(),
"package": pkgName,
"pbRequest": method.InType,
"pbResponse": method.OutType,
"hasComment": len(method.Document) > 0,
"comment": comment,
"pbRequestName": method.ParameterIn.Name,
"pbRequest": method.ParameterIn.Expression,
"pbResponse": method.ParameterOut.Name,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return nil, err
return nil, nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
return functions, imports.KeysStr(), nil
}
func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
functions := make([]string, 0)
for _, method := range service.Funcs {
var comment string
if len(method.Document) > 0 {
comment = method.Document[0]
}
buffer, err := util.With("interfaceFn").Parse(callInterfaceFunctionTemplate).Execute(
map[string]interface{}{
"hasComment": len(method.Document) > 0,
"comment": comment,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
"method": method.Name.Title(),
"pbRequest": method.InType,
"pbResponse": method.OutType,
"pbRequest": method.ParameterIn.Name,
"pbResponse": method.ParameterOut.Name,
})
if err != nil {
return nil, err

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/vars"
)
// target
@@ -43,9 +44,9 @@ func (g *defaultRpcGenerator) mustGetPackage(dir string) string {
relativePath := strings.TrimPrefix(target, projectPath)
os := runtime.GOOS
switch os {
case "windows":
case vars.OsWindows:
relativePath = filepath.ToSlash(relativePath)
case "darwin", "linux":
case vars.OsMac, vars.OsLinux:
default:
g.Ctx.Fatalln("unexpected os: %s", os)
}

View File

@@ -37,10 +37,10 @@ func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logic
{{.functions}}
`
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
// todo: add your logic here and delete this line
return &{{.package}}.{{.response}}{}, nil
return &{{.responseType}}{}, nil
}
`
)
@@ -53,18 +53,18 @@ func (g *defaultRpcGenerator) genLogic() error {
for _, method := range item.Funcs {
logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
filename := filepath.Join(logicPath, logicName)
functions, err := genLogicFunction(protoPkg, method)
functions, importList, err := g.genLogicFunction(protoPkg, method)
if err != nil {
return err
}
imports := collection.NewSet()
pbImport := fmt.Sprintf(`%v "%v"`, protoPkg, g.mustGetPackage(dirPb))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports.AddStr(pbImport, svcImport)
imports.AddStr(svcImport)
imports.AddStr(importList...)
err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"functions": functions,
"imports": strings.Join(imports.KeysStr(), "\n"),
"imports": strings.Join(imports.KeysStr(), util.NL),
}, filename, false)
if err != nil {
return err
@@ -74,20 +74,26 @@ func (g *defaultRpcGenerator) genLogic() error {
return nil
}
func genLogicFunction(packageName string, method *parser.Func) (string, error) {
func (g *defaultRpcGenerator) genLogicFunction(packageName string, method *parser.Func) (string, []string, error) {
var functions = make([]string, 0)
var imports = collection.NewSet()
if method.ParameterIn.Package == packageName || method.ParameterOut.Package == packageName {
imports.AddStr(fmt.Sprintf(`%v "%v"`, packageName, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"package": packageName,
"request": method.InType,
"response": method.OutType,
"hasComment": len(method.Document) > 0,
"comment": strings.Join(method.Document, "\n"),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"request": method.ParameterIn.StarExpression,
"response": method.ParameterOut.StarExpression,
"responseType": method.ParameterOut.Expression,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return "", err
return "", nil, err
}
functions = append(functions, buffer.String())
return strings.Join(functions, "\n"), nil
return strings.Join(functions, util.NL), imports.KeysStr(), nil
}

View File

@@ -65,7 +65,7 @@ func (g *defaultRpcGenerator) genMain() error {
"serviceName": g.Ctx.ServiceName.Lower(),
"srv": srv,
"registers": registers,
"imports": strings.Join(imports, "\n"),
"imports": strings.Join(imports, util.NL),
}, fileName, true)
}
@@ -77,5 +77,5 @@ func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (
list1 = append(list1, fmt.Sprintf("%sSrv := server.New%sServer(ctx)", name, item.Name.Title()))
list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name))
}
return strings.Join(list1, "\n"), strings.Join(list2, "\n")
return strings.Join(list1, util.NL), strings.Join(list2, util.NL)
}

View File

@@ -1,68 +1,37 @@
package gen
import (
"errors"
"bytes"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/dsymonds/gotoc/parser"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
astParser "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
const (
protocCmd = "protoc"
grpcPluginCmd = "--go_out=plugins=grpc"
)
func (g *defaultRpcGenerator) genPb() error {
importPath, filename := filepath.Split(g.Ctx.ProtoFileSrc)
tree, err := parser.ParseFiles([]string{filename}, []string{importPath})
if err != nil {
return err
}
if len(tree.Files) == 0 {
return errors.New("proto ast parse failed")
}
file := tree.Files[0]
if len(file.Package) == 0 {
return errors.New("expected package, but nothing found")
}
targetStruct := make(map[string]lang.PlaceholderType)
for _, item := range file.Messages {
if len(item.Messages) > 0 {
return fmt.Errorf(`line %v: unexpected inner message near: "%v""`, item.Messages[0].Position.Line, item.Messages[0].Name)
}
name := stringx.From(item.Name)
if _, ok := targetStruct[name.Lower()]; ok {
return fmt.Errorf("line %v: duplicate %v", item.Position.Line, name)
}
targetStruct[name.Lower()] = lang.Placeholder
}
pbPath := g.dirM[dirPb]
protoFileName := filepath.Base(g.Ctx.ProtoFileSrc)
err = g.protocGenGo(pbPath)
imports, containsAny, err := parser.ParseImport(g.Ctx.ProtoFileSrc)
if err != nil {
return err
}
pbGo := strings.TrimSuffix(protoFileName, ".proto") + ".pb.go"
pbFile := filepath.Join(pbPath, pbGo)
bts, err := ioutil.ReadFile(pbFile)
err = g.protocGenGo(pbPath, imports)
if err != nil {
return err
}
aspParser := astParser.NewAstParser(bts, targetStruct, g.Ctx.Console)
ast, err := aspParser.Parse()
ast, err := parser.Transfer(g.Ctx.ProtoFileSrc, pbPath, imports, g.Ctx.Console)
if err != nil {
return err
}
ast.ContainsAny = containsAny
if len(ast.Service) == 0 {
return fmt.Errorf("service not found")
@@ -71,10 +40,35 @@ func (g *defaultRpcGenerator) genPb() error {
return nil
}
func (g *defaultRpcGenerator) protocGenGo(target string) error {
src := filepath.Dir(g.Ctx.ProtoFileSrc)
sh := fmt.Sprintf(`protoc -I=%s --go_out=plugins=grpc:%s %s`, src, target, g.Ctx.ProtoFileSrc)
stdout, err := execx.Run(sh, "")
func (g *defaultRpcGenerator) protocGenGo(target string, imports []*parser.Import) error {
dir := filepath.Dir(g.Ctx.ProtoFileSrc)
// cmd join,see the document of proto generating class @https://developers.google.com/protocol-buffers/docs/proto3#generating
// template: protoc -I=${import_path} -I=${other_import_path} -I=${...} --go_out=plugins=grpc,M${pb_package_kv}, M${...} :${target_dir}
// eg: protoc -I=${GOPATH}/src -I=. example.proto --go_out=plugins=grpc,Mbase/base.proto=github.com/go-zero/base.proto:.
// note: the external import out of the project which are found in ${GOPATH}/src so far.
buffer := new(bytes.Buffer)
buffer.WriteString(protocCmd + " ")
targetImportFiltered := collection.NewSet()
for _, item := range imports {
buffer.WriteString(fmt.Sprintf("-I=%s ", item.OriginalDir))
if len(item.BridgeImport) == 0 {
continue
}
targetImportFiltered.AddStr(item.BridgeImport)
}
buffer.WriteString("-I=${GOPATH}/src ")
buffer.WriteString(fmt.Sprintf("-I=%s %s ", dir, g.Ctx.ProtoFileSrc))
buffer.WriteString(grpcPluginCmd)
if targetImportFiltered.Count() > 0 {
buffer.WriteString(fmt.Sprintf(",%v", strings.Join(targetImportFiltered.KeysStr(), ",")))
}
buffer.WriteString(":" + target)
g.Ctx.Debug("-> " + buffer.String())
stdout, err := execx.Run(buffer.String(), "")
if err != nil {
return err
}

View File

@@ -5,6 +5,7 @@ import (
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
@@ -32,7 +33,7 @@ func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
`
functionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
@@ -45,29 +46,26 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{
func (g *defaultRpcGenerator) genHandler() error {
serverPath := g.dirM[dirServer]
file := g.ast
pkg := file.Package
pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
logicImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports := []string{
pbImport,
logicImport,
svcImport,
}
imports := collection.NewSet()
imports.AddStr(logicImport, svcImport)
head := util.GetHead(g.Ctx.ProtoSource)
for _, service := range file.Service {
filename := fmt.Sprintf("%vserver.go", service.Name.Lower())
serverFile := filepath.Join(serverPath, filename)
funcList, err := g.genFunctions(service)
funcList, importList, err := g.genFunctions(service)
if err != nil {
return err
}
imports.AddStr(importList...)
err = util.With("server").GoFmt(true).Parse(serverTemplate).SaveTo(map[string]interface{}{
"head": head,
"types": fmt.Sprintf(typeFmt, service.Name.Title()),
"server": service.Name.Title(),
"imports": strings.Join(imports, "\n\t"),
"funcs": strings.Join(funcList, "\n"),
"imports": strings.Join(imports.KeysStr(), util.NL),
"funcs": strings.Join(funcList, util.NL),
}, serverFile, true)
if err != nil {
return err
@@ -76,25 +74,31 @@ func (g *defaultRpcGenerator) genHandler() error {
return nil
}
func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, error) {
func (g *defaultRpcGenerator) genFunctions(service *parser.RpcService) ([]string, []string, error) {
file := g.ast
pkg := file.Package
var functionList []string
imports := collection.NewSet()
for _, method := range service.Funcs {
if method.ParameterIn.Package == pkg || method.ParameterOut.Package == pkg {
imports.AddStr(fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb)))
}
imports.AddStr(g.ast.Imports[method.ParameterIn.Package])
imports.AddStr(g.ast.Imports[method.ParameterOut.Package])
buffer, err := util.With("func").Parse(functionTemplate).Execute(map[string]interface{}{
"server": service.Name.Title(),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"package": pkg,
"request": method.InType,
"response": method.OutType,
"hasComment": len(method.Document),
"comment": strings.Join(method.Document, "\n"),
"request": method.ParameterIn.StarExpression,
"response": method.ParameterOut.StarExpression,
"hasComment": method.HaveDoc(),
"comment": method.GetDoc(),
})
if err != nil {
return nil, err
return nil, nil, err
}
functionList = append(functionList, buffer.String())
}
return functionList, nil
return functionList, imports.KeysStr(), nil
}

View File

@@ -39,6 +39,8 @@ func NewRpcTemplate(out string, idea bool) *rpcTemplate {
}
func (r *rpcTemplate) MustGenerate(showState bool) {
r.Info("查看rpc生成请移步至「https://github.com/tal-tech/go-zero/blob/master/doc/goctl-rpc.md」")
r.Info("generating template...")
protoFilename := filepath.Base(r.out)
serviceName := stringx.From(strings.TrimSuffix(protoFilename, filepath.Ext(protoFilename)))
err := util.With("t").Parse(rpcTemplateText).SaveTo(map[string]string{