refine goctl rpc generator

This commit is contained in:
kevin
2020-08-28 21:22:35 +08:00
parent db16115037
commit 72132ce399
20 changed files with 193 additions and 154 deletions

View File

@@ -0,0 +1,89 @@
package gogen
import (
"github.com/tal-tech/go-zero/tools/goctl/rpc/ctx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
const (
dirTarget = "dirTarget"
dirConfig = "config"
dirEtc = "etc"
dirSvc = "svc"
dirShared = "shared"
dirHandler = "handler"
dirLogic = "logic"
dirPb = "pb"
dirInternal = "internal"
fileConfig = "config.go"
fileServiceContext = "servicecontext.go"
)
type (
defaultRpcGenerator struct {
dirM map[string]string
Ctx *ctx.RpcContext
ast *parser.PbAst
}
)
func NewDefaultRpcGenerator(ctx *ctx.RpcContext) *defaultRpcGenerator {
return &defaultRpcGenerator{
Ctx: ctx,
}
}
func (g *defaultRpcGenerator) Generate() (err error) {
g.Ctx.Info("generating code...")
defer func() {
if err == nil {
g.Ctx.Success("Done.")
}
}()
err = g.createDir()
if err != nil {
return
}
err = g.genEtc()
if err != nil {
return
}
err = g.genPb()
if err != nil {
return
}
err = g.genConfig()
if err != nil {
return
}
err = g.genSvc()
if err != nil {
return
}
err = g.genLogic()
if err != nil {
return
}
err = g.genRemoteHandler()
if err != nil {
return
}
err = g.genMain()
if err != nil {
return
}
err = g.genShared()
if err != nil {
return
}
return nil
}

View File

@@ -0,0 +1,27 @@
package gogen
import (
"io/ioutil"
"os"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const configTemplate = `package config
import "github.com/tal-tech/go-zero/rpcx"
type Config struct {
rpcx.RpcServerConf
}
`
func (g *defaultRpcGenerator) genConfig() error {
configPath := g.dirM[dirConfig]
fileName := filepath.Join(configPath, fileConfig)
if util.FileExists(fileName) {
return nil
}
return ioutil.WriteFile(fileName, []byte(configTemplate), os.ModePerm)
}

View File

@@ -0,0 +1,45 @@
package gogen
import (
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
// target
// ├── etc
// ├── internal
// │   ├── config
// │   ├── handler
// │   ├── logic
// │   ├── pb
// │   └── svc
func (g *defaultRpcGenerator) createDir() error {
ctx := g.Ctx
m := make(map[string]string)
m[dirTarget] = ctx.TargetDir
m[dirEtc] = filepath.Join(ctx.TargetDir, dirEtc)
m[dirInternal] = filepath.Join(ctx.TargetDir, dirInternal)
m[dirConfig] = filepath.Join(ctx.TargetDir, dirInternal, dirConfig)
m[dirHandler] = filepath.Join(ctx.TargetDir, dirInternal, dirHandler)
m[dirLogic] = filepath.Join(ctx.TargetDir, dirInternal, dirLogic)
m[dirPb] = filepath.Join(ctx.TargetDir, dirPb)
m[dirSvc] = filepath.Join(ctx.TargetDir, dirInternal, dirSvc)
m[dirShared] = g.Ctx.SharedDir
for _, d := range m {
err := util.MkdirIfNotExist(d)
if err != nil {
return err
}
}
g.dirM = m
return nil
}
func (g *defaultRpcGenerator) mustGetPackage(dir string) string {
target := g.dirM[dir]
projectPath := g.Ctx.ProjectPath
relativePath := strings.TrimPrefix(target, projectPath)
return g.Ctx.Module + relativePath
}

View File

@@ -0,0 +1,30 @@
package gogen
import (
"fmt"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const etcTemplate = `Name: {{.serviceName}}.rpc
Log:
Mode: console
ListenOn: 127.0.0.1:8080
Etcd:
Hosts:
- 127.0.0.1:6379
Key: {{.serviceName}}.rpc
`
func (g *defaultRpcGenerator) genEtc() error {
etdDir := g.dirM[dirEtc]
fileName := filepath.Join(etdDir, fmt.Sprintf("%v.yaml", g.Ctx.ServiceName.Lower()))
if util.FileExists(fileName) {
return nil
}
return util.With("etc").Parse(etcTemplate).SaveTo(map[string]interface{}{
"serviceName": g.Ctx.ServiceName.Lower(),
}, fileName, false)
}

View File

@@ -0,0 +1,107 @@
package gogen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
remoteTemplate = `{{.head}}
package handler
import {{.imports}}
type {{.types}}
{{.newFuncs}}
`
functionTemplate = `{{.head}}
package handler
import (
"context"
{{.imports}}
)
type {{.server}}Server struct{}
{{if .hasComment}}{{.comment}}{{end}}
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
l := logic.New{{.logicName}}(ctx,s.svcCtx)
return l.{{.method}}(in)
}
`
typeFmt = `%sServer struct {
svcCtx *svc.ServiceContext
}`
newFuncFmt = `func New%sServer(svcCtx *svc.ServiceContext) *%sServer {
return &%sServer{
svcCtx: svcCtx,
}
}`
)
func (g *defaultRpcGenerator) genRemoteHandler() error {
handlerPath := g.dirM[dirHandler]
serverGo := fmt.Sprintf("%vhandler.go", g.Ctx.ServiceName.Lower())
fileName := filepath.Join(handlerPath, serverGo)
file := g.ast
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
types := make([]string, 0)
newFuncs := make([]string, 0)
head := util.GetHead(g.Ctx.ProtoSource)
for _, service := range file.Service {
types = append(types, fmt.Sprintf(typeFmt, service.Name.Title()))
newFuncs = append(newFuncs, fmt.Sprintf(newFuncFmt, service.Name.Title(), service.Name.Title(), service.Name.Title()))
}
err := util.With("server").GoFmt(true).Parse(remoteTemplate).SaveTo(map[string]interface{}{
"head": head,
"types": strings.Join(types, "\n"),
"newFuncs": strings.Join(newFuncs, "\n"),
"imports": svcImport,
}, fileName, true)
if err != nil {
return err
}
return g.genFunctions()
}
func (g *defaultRpcGenerator) genFunctions() error {
handlerPath := g.dirM[dirHandler]
file := g.ast
pkg := file.Package
head := util.GetHead(g.Ctx.ProtoSource)
handlerImports := make([]string, 0)
pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
handlerImports = append(handlerImports, pbImport, fmt.Sprintf(`"%v"`, g.mustGetPackage(dirLogic)))
for _, service := range file.Service {
for _, method := range service.Funcs {
handlerName := fmt.Sprintf("%shandler.go", method.Name.Lower())
filename := filepath.Join(handlerPath, handlerName)
// override
err := util.With("func").GoFmt(true).Parse(functionTemplate).SaveTo(map[string]interface{}{
"head": head,
"server": service.Name.Title(),
"imports": strings.Join(handlerImports, "\r\n"),
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"package": pkg,
"request": method.InType,
"response": method.OutType,
"hasComment": len(method.Document),
"comment": strings.Join(method.Document, "\r\n"),
}, filename, true)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -0,0 +1,93 @@
package gogen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
logicTemplate = `package logic
import (
"context"
{{.imports}}
"github.com/tal-tech/go-zero/core/logx"
)
type {{.logicName}} struct {
ctx context.Context
logx.Logger
}
func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logicName}} {
return &{{.logicName}}{
ctx: ctx,
Logger: logx.WithContext(ctx),
}
}
{{.functions}}
`
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
func (l *{{.logicName}}) {{.method}} (in *{{.package}}.{{.request}}) (*{{.package}}.{{.response}}, error) {
var resp {{.package}}.{{.response}}
// todo: add your logic here and delete this line
return &resp,nil
}
`
)
func (g *defaultRpcGenerator) genLogic() error {
logicPath := g.dirM[dirLogic]
protoPkg := g.ast.Package
service := g.ast.Service
for _, item := range service {
for _, method := range item.Funcs {
logicName := fmt.Sprintf("%slogic.go", method.Name.Lower())
filename := filepath.Join(logicPath, logicName)
functions, err := genLogicFunction(protoPkg, method)
if err != nil {
return err
}
imports := collection.NewSet()
pbImport := fmt.Sprintf(`%v "%v"`, protoPkg, g.mustGetPackage(dirPb))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
imports.AddStr(pbImport, svcImport)
err = util.With("logic").GoFmt(true).Parse(logicTemplate).SaveTo(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"functions": functions,
"imports": strings.Join(imports.KeysStr(), "\r\n"),
}, filename, false)
if err != nil {
return err
}
}
}
return nil
}
func genLogicFunction(packageName string, method *parser.Func) (string, error) {
var functions = make([]string, 0)
buffer, err := util.With("fun").Parse(logicFunctionTemplate).Execute(map[string]interface{}{
"logicName": fmt.Sprintf("%sLogic", method.Name.Title()),
"method": method.Name.Title(),
"package": packageName,
"request": method.InType,
"response": method.OutType,
"hasComment": len(method.Document) > 0,
"comment": strings.Join(method.Document, "\r\n"),
})
if err != nil {
return "", err
}
functions = append(functions, buffer.String())
return strings.Join(functions, "\n"), nil
}

View File

@@ -0,0 +1,83 @@
package gogen
import (
"fmt"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const mainTemplate = `{{.head}}
package main
import (
"flag"
"fmt"
"log"
{{.imports}}
"github.com/tal-tech/go-zero/core/conf"
"github.com/tal-tech/go-zero/rpcx"
"google.golang.org/grpc"
)
var configFile = flag.String("f", "etc/{{.serviceName}}.yaml", "the config file")
func main() {
flag.Parse()
var c config.Config
conf.MustLoad(*configFile, &c)
ctx := svc.NewServiceContext(c)
{{.srv}}
s, err := rpcx.NewServer(c.RpcServerConf, func(grpcServer *grpc.Server) {
{{.registers}}
})
if err != nil {
log.Fatal(err)
}
fmt.Printf("Starting rpc server at %s...\n", c.ListenOn)
s.Start()
}
`
func (g *defaultRpcGenerator) genMain() error {
mainPath := g.dirM[dirTarget]
file := g.ast
pkg := file.Package
fileName := filepath.Join(mainPath, fmt.Sprintf("%v.go", g.Ctx.ServiceName.Lower()))
imports := make([]string, 0)
pbImport := fmt.Sprintf(`%v "%v"`, pkg, g.mustGetPackage(dirPb))
svcImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirSvc))
remoteImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirHandler))
configImport := fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig))
imports = append(imports, configImport, pbImport, remoteImport, svcImport)
srv, registers := g.genServer(pkg, file.Service)
head := util.GetHead(g.Ctx.ProtoSource)
return util.With("main").GoFmt(true).Parse(mainTemplate).SaveTo(map[string]interface{}{
"head": head,
"package": pkg,
"serviceName": g.Ctx.ServiceName.Lower(),
"srv": srv,
"registers": registers,
"imports": strings.Join(imports, "\r\n"),
}, fileName, true)
}
func (g *defaultRpcGenerator) genServer(pkg string, list []*parser.RpcService) (string, string) {
list1 := make([]string, 0)
list2 := make([]string, 0)
for _, item := range list {
name := item.Name.UnTitle()
list1 = append(list1, fmt.Sprintf("%sSrv := handler.New%sServer(ctx)", name, item.Name.Title()))
list2 = append(list2, fmt.Sprintf("%s.Register%sServer(grpcServer, %sSrv)", pkg, item.Name.Title(), name))
}
return strings.Join(list1, "\n"), strings.Join(list2, "\n")
}

View File

@@ -0,0 +1,87 @@
package gogen
import (
"errors"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"github.com/dsymonds/gotoc/parser"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
astParser "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
func (g *defaultRpcGenerator) genPb() error {
importPath, filename := filepath.Split(g.Ctx.ProtoFileSrc)
tree, err := parser.ParseFiles([]string{filename}, []string{importPath})
if err != nil {
return err
}
if len(tree.Files) == 0 {
return errors.New("proto ast parse failed")
}
file := tree.Files[0]
if len(file.Package) == 0 {
return errors.New("expected package, but nothing found")
}
targetStruct := make(map[string]lang.PlaceholderType)
for _, item := range file.Messages {
if len(item.Messages) > 0 {
return fmt.Errorf(`line %v: unexpected inner message near: "%v""`, item.Messages[0].Position.Line, item.Messages[0].Name)
}
name := stringx.From(item.Name)
if _, ok := targetStruct[name.Lower()]; ok {
return fmt.Errorf("line %v: duplicate %v", item.Position.Line, name)
}
targetStruct[name.Lower()] = lang.Placeholder
}
pbPath := g.dirM[dirPb]
protoFileName := filepath.Base(g.Ctx.ProtoFileSrc)
err = g.protocGenGo(pbPath)
if err != nil {
return err
}
pbGo := strings.TrimSuffix(protoFileName, ".proto") + ".pb.go"
pbFile := filepath.Join(pbPath, pbGo)
bts, err := ioutil.ReadFile(pbFile)
if err != nil {
return err
}
aspParser := astParser.NewAstParser(bts, targetStruct, g.Ctx.Console)
ast, err := aspParser.Parse()
if err != nil {
return err
}
if len(ast.Service) == 0 {
return fmt.Errorf("service not found")
}
g.ast = ast
return nil
}
func (g *defaultRpcGenerator) protocGenGo(target string) error {
src := filepath.Dir(g.Ctx.ProtoFileSrc)
sh := fmt.Sprintf(`export PATH=%s:$PATH
protoc -I=%s --go_out=plugins=grpc:%s %s`, filepath.Join(g.Ctx.GoPath, "bin"), src, target, g.Ctx.ProtoFileSrc)
stdout, err := execx.Run(sh)
if err != nil {
return err
}
if len(stdout) > 0 {
g.Ctx.Info(stdout)
}
return nil
}

View File

@@ -0,0 +1,218 @@
package gogen
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const (
sharedTemplateText = `{{.head}}
//go:generate mockgen -destination ./{{.name}}model_mock.go -package {{.filePackage}} -source $GOFILE
package {{.filePackage}}
import (
"context"
{{.package}}
"github.com/tal-tech/go-zero/core/jsonx"
"github.com/tal-tech/go-zero/rpcx"
)
type (
{{.serviceName}}Model interface {
{{.interface}}
}
default{{.serviceName}}Model struct {
cli rpcx.Client
}
)
func New{{.serviceName}}Model(cli rpcx.Client) {{.serviceName}}Model {
return &default{{.serviceName}}Model{
cli: cli,
}
}
{{.functions}}
`
sharedTemplateTypes = `{{.head}}
package {{.filePackage}}
import "errors"
var errJsonConvert = errors.New("json convert error")
{{.types}}
`
sharedInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) {{if .hasResponse}}(*{{.pbResponse}},{{end}} error{{if .hasResponse}}){{end}}`
sharedFunctionTemplate = `
{{if .hasComment}}{{.comment}}{{end}}
func (m *default{{.rpcServiceName}}Model) {{.method}}(ctx context.Context,in *{{.pbRequest}}) {{if .hasResponse}}(*{{.pbResponse}},{{end}} error{{if .hasResponse}}){{end}} {
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
var request {{.package}}.{{.pbRequest}}
bts, err := jsonx.Marshal(in)
if err != nil {
return {{if .hasResponse}}nil, {{end}}errJsonConvert
}
err = jsonx.Unmarshal(bts, &request)
if err != nil {
return {{if .hasResponse}}nil, {{end}}errJsonConvert
}
{{if .hasResponse}}resp, err := {{else}}_, err = {{end}}client.{{.method}}(ctx, &request)
{{if .hasResponse}}if err != nil{
return nil, err
}
var ret {{.pbResponse}}
bts, err = jsonx.Marshal(resp)
if err != nil{
return nil, errJsonConvert
}
err = jsonx.Unmarshal(bts, &ret)
if err != nil{
return nil, errJsonConvert
}
return &ret, nil{{else}}if err != nil {
return err
}
return nil{{end}}
}
`
)
func (g *defaultRpcGenerator) genShared() error {
sharePackage := filepath.Base(g.Ctx.SharedDir)
file := g.ast
typeCode, err := file.GenTypesCode()
if err != nil {
return err
}
pbPkg := file.Package
remotePackage := fmt.Sprintf(`%v "%v"`, pbPkg, g.mustGetPackage(dirPb))
filename := filepath.Join(g.Ctx.SharedDir, "types.go")
head := util.GetHead(g.Ctx.ProtoSource)
err = util.With("types").GoFmt(true).Parse(sharedTemplateTypes).SaveTo(map[string]interface{}{
"head": head,
"filePackage": sharePackage,
"pbPkg": pbPkg,
"serviceName": g.Ctx.ServiceName.Title(),
"lowerStartServiceName": g.Ctx.ServiceName.UnTitle(),
"types": typeCode,
}, filename, true)
for _, service := range file.Service {
filename := filepath.Join(g.Ctx.SharedDir, fmt.Sprintf("%smodel.go", service.Name.Lower()))
functions, err := g.getFuncs(service)
if err != nil {
return err
}
iFunctions, err := g.getInterfaceFuncs(service)
if err != nil {
return err
}
mockFile := filepath.Join(g.Ctx.SharedDir, fmt.Sprintf("%smodel_mock.go", service.Name.Lower()))
os.Remove(mockFile)
err = util.With("shared").GoFmt(true).Parse(sharedTemplateText).SaveTo(map[string]interface{}{
"name": service.Name.Lower(),
"head": head,
"filePackage": sharePackage,
"pbPkg": pbPkg,
"package": remotePackage,
"serviceName": service.Name.Title(),
"functions": strings.Join(functions, "\n"),
"interface": strings.Join(iFunctions, "\n"),
}, filename, true)
if err != nil {
return err
}
}
// if mockgen is already installed, it will generate code of gomock for shared files
_, err = exec.LookPath("mockgen")
if err != nil {
g.Ctx.Warning("warning:mockgen is not found")
} else {
execx.Run(fmt.Sprintf("cd %s \ngo generate", g.Ctx.SharedDir))
}
return nil
}
func (g *defaultRpcGenerator) getFuncs(service *parser.RpcService) ([]string, error) {
file := g.ast
pkgName := file.Package
functions := make([]string, 0)
for _, method := range service.Funcs {
data, found := file.Strcuts[strings.ToLower(method.OutType)]
if found {
found = len(data.Field) > 0
}
var comment string
if len(method.Document) > 0 {
comment = method.Document[0]
}
buffer, err := util.With("sharedFn").Parse(sharedFunctionTemplate).Execute(map[string]interface{}{
"rpcServiceName": service.Name.Title(),
"method": method.Name.Title(),
"package": pkgName,
"pbRequest": method.InType,
"pbResponse": method.OutType,
"hasResponse": found,
"hasComment": len(method.Document) > 0,
"comment": comment,
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}
func (g *defaultRpcGenerator) getInterfaceFuncs(service *parser.RpcService) ([]string, error) {
file := g.ast
functions := make([]string, 0)
for _, method := range service.Funcs {
data, found := file.Strcuts[strings.ToLower(method.OutType)]
if found {
found = len(data.Field) > 0
}
var comment string
if len(method.Document) > 0 {
comment = method.Document[0]
}
buffer, err := util.With("interfaceFn").Parse(sharedInterfaceFunctionTemplate).Execute(map[string]interface{}{
"hasComment": len(method.Document) > 0,
"comment": comment,
"method": method.Name.Title(),
"pbRequest": method.InType,
"pbResponse": method.OutType,
"hasResponse": found,
})
if err != nil {
return nil, err
}
functions = append(functions, buffer.String())
}
return functions, nil
}

View File

@@ -0,0 +1,31 @@
package gogen
import (
"fmt"
"path/filepath"
"github.com/tal-tech/go-zero/tools/goctl/util"
)
const svcTemplate = `package svc
import {{.imports}}
type ServiceContext struct {
c config.Config
}
func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{
c:c,
}
}
`
func (g *defaultRpcGenerator) genSvc() error {
svcPath := g.dirM[dirSvc]
fileName := filepath.Join(svcPath, fileServiceContext)
return util.With("svc").GoFmt(true).Parse(svcTemplate).SaveTo(map[string]interface{}{
"imports": fmt.Sprintf(`"%v"`, g.mustGetPackage(dirConfig)),
}, fileName, false)
}

View File

@@ -0,0 +1,43 @@
package gogen
import (
"github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
)
const rpcTemplateText = `syntax = "proto3";
package remote;
message Request {
string username = 1;
string password = 2;
}
message Response {
string name = 1;
string gender = 2;
}
service User {
rpc Login(Request) returns(Response);
}
`
type rpcTemplate struct {
out string
console.Console
}
func NewRpcTemplate(out string, idea bool) *rpcTemplate {
return &rpcTemplate{
out: out,
Console: console.NewConsole(idea),
}
}
func (r *rpcTemplate) MustGenerate() {
err := util.With("t").Parse(rpcTemplateText).SaveTo(nil, r.out, false)
r.Must(err)
r.Success("Done.")
}