Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
This commit is contained in:
@@ -29,7 +29,6 @@ func TestRpcGenerate(t *testing.T) {
|
|||||||
projectName := stringx.Rand()
|
projectName := stringx.Rand()
|
||||||
g := NewRPCGenerator(dispatcher, cfg)
|
g := NewRPCGenerator(dispatcher, cfg)
|
||||||
|
|
||||||
// case go path
|
|
||||||
src := filepath.Join(build.Default.GOPATH, "src")
|
src := filepath.Join(build.Default.GOPATH, "src")
|
||||||
_, err = os.Stat(src)
|
_, err = os.Stat(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -41,45 +40,52 @@ func TestRpcGenerate(t *testing.T) {
|
|||||||
defer func() {
|
defer func() {
|
||||||
_ = os.RemoveAll(srcDir)
|
_ = os.RemoveAll(srcDir)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
common, err := filepath.Abs(".")
|
common, err := filepath.Abs(".")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
// case go path
|
||||||
assert.Nil(t, err)
|
t.Run("GOPATH", func(t *testing.T) {
|
||||||
_, err = execx.Run("go test "+projectName, projectDir)
|
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||||
if err != nil {
|
assert.Nil(t, err)
|
||||||
assert.True(t, func() bool {
|
_, err = execx.Run("go test "+projectName, projectDir)
|
||||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
if err != nil {
|
||||||
}())
|
assert.True(t, func() bool {
|
||||||
}
|
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||||
|
}())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// case go mod
|
// case go mod
|
||||||
workDir := t.TempDir()
|
t.Run("GOMOD", func(t *testing.T) {
|
||||||
name := filepath.Base(workDir)
|
workDir := t.TempDir()
|
||||||
_, err = execx.Run("go mod init "+name, workDir)
|
name := filepath.Base(workDir)
|
||||||
if err != nil {
|
_, err = execx.Run("go mod init "+name, workDir)
|
||||||
logx.Error(err)
|
if err != nil {
|
||||||
return
|
logx.Error(err)
|
||||||
}
|
return
|
||||||
|
}
|
||||||
|
|
||||||
projectDir = filepath.Join(workDir, projectName)
|
projectDir = filepath.Join(workDir, projectName)
|
||||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
_, err = execx.Run("go test "+projectName, projectDir)
|
_, err = execx.Run("go test "+projectName, projectDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
assert.True(t, func() bool {
|
assert.True(t, func() bool {
|
||||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||||
}())
|
}())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
// case not in go mod and go path
|
// case not in go mod and go path
|
||||||
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
t.Run("OTHER", func(t *testing.T) {
|
||||||
assert.Nil(t, err)
|
err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base")
|
||||||
_, err = execx.Run("go test "+projectName, projectDir)
|
assert.Nil(t, err)
|
||||||
if err != nil {
|
_, err = execx.Run("go test "+projectName, projectDir)
|
||||||
assert.True(t, func() bool {
|
if err != nil {
|
||||||
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
assert.True(t, func() bool {
|
||||||
}())
|
return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package")
|
||||||
}
|
}())
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,13 +49,13 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} {
|
|||||||
`
|
`
|
||||||
|
|
||||||
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
|
callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}}
|
||||||
{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)`
|
{{end}}{{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error)`
|
||||||
|
|
||||||
callFunctionTemplate = `
|
callFunctionTemplate = `
|
||||||
{{if .hasComment}}{{.comment}}{{end}}
|
{{if .hasComment}}{{.comment}}{{end}}
|
||||||
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) {
|
func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) {
|
||||||
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
|
client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn())
|
||||||
return client.{{.method}}(ctx, in)
|
return client.{{.method}}(ctx,{{if .hasReq}} in{{end}})
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
@@ -78,7 +78,7 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
iFunctions, err := g.getInterfaceFuncs(service)
|
iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -115,6 +115,14 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
|||||||
}
|
}
|
||||||
|
|
||||||
comment := parser.GetComment(rpc.Doc())
|
comment := parser.GetComment(rpc.Doc())
|
||||||
|
var streamServer string
|
||||||
|
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient")
|
||||||
|
} else if rpc.StreamsRequest {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient")
|
||||||
|
} else {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient")
|
||||||
|
}
|
||||||
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{
|
||||||
"serviceName": stringx.From(service.Name).ToCamel(),
|
"serviceName": stringx.From(service.Name).ToCamel(),
|
||||||
"rpcServiceName": parser.CamelCase(service.Name),
|
"rpcServiceName": parser.CamelCase(service.Name),
|
||||||
@@ -124,6 +132,9 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
|||||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||||
"hasComment": len(comment) > 0,
|
"hasComment": len(comment) > 0,
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
|
"hasReq": !rpc.StreamsRequest,
|
||||||
|
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||||
|
"streamBody": streamServer,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -134,7 +145,7 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service)
|
|||||||
return functions, nil
|
return functions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) {
|
func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Service) ([]string, error) {
|
||||||
functions := make([]string, 0)
|
functions := make([]string, 0)
|
||||||
|
|
||||||
for _, rpc := range service.RPC {
|
for _, rpc := range service.RPC {
|
||||||
@@ -144,13 +155,25 @@ func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
comment := parser.GetComment(rpc.Doc())
|
comment := parser.GetComment(rpc.Doc())
|
||||||
|
var streamServer string
|
||||||
|
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient")
|
||||||
|
} else if rpc.StreamsRequest {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient")
|
||||||
|
} else {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient")
|
||||||
|
}
|
||||||
|
|
||||||
buffer, err := util.With("interfaceFn").Parse(text).Execute(
|
buffer, err := util.With("interfaceFn").Parse(text).Execute(
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"hasComment": len(comment) > 0,
|
"hasComment": len(comment) > 0,
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
"method": parser.CamelCase(rpc.Name),
|
"method": parser.CamelCase(rpc.Name),
|
||||||
|
"hasReq": !rpc.StreamsRequest,
|
||||||
"pbRequest": parser.CamelCase(rpc.RequestType),
|
"pbRequest": parser.CamelCase(rpc.RequestType),
|
||||||
|
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||||
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
"pbResponse": parser.CamelCase(rpc.ReturnsType),
|
||||||
|
"streamBody": streamServer,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -40,10 +40,10 @@ func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logic
|
|||||||
{{.functions}}
|
{{.functions}}
|
||||||
`
|
`
|
||||||
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
|
logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}}
|
||||||
func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
|
func (l *{{.logicName}}) {{.method}} ({{if .hasReq}}in {{.request}}{{if .stream}},stream {{.streamBody}}{{end}}{{else}}stream {{.streamBody}}{{end}}) ({{if .hasReply}}{{.response}},{{end}} error) {
|
||||||
// todo: add your logic here and delete this line
|
// todo: add your logic here and delete this line
|
||||||
|
|
||||||
return &{{.responseType}}{}, nil
|
return {{if .hasReply}}&{{.responseType}}{},{{end}} nil
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
@@ -51,6 +51,7 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
|
|||||||
// GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto.
|
// GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto.
|
||||||
func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
|
||||||
dir := ctx.GetLogic()
|
dir := ctx.GetLogic()
|
||||||
|
service := proto.Service.Service.Name
|
||||||
for _, rpc := range proto.Service.RPC {
|
for _, rpc := range proto.Service.RPC {
|
||||||
logicFilename, err := format.FileNamingFormat(cfg.NamingFormat, rpc.Name+"_logic")
|
logicFilename, err := format.FileNamingFormat(cfg.NamingFormat, rpc.Name+"_logic")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -58,7 +59,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
|
|||||||
}
|
}
|
||||||
|
|
||||||
filename := filepath.Join(dir.Filename, logicFilename+".go")
|
filename := filepath.Join(dir.Filename, logicFilename+".go")
|
||||||
functions, err := g.genLogicFunction(proto.PbPackage, rpc)
|
functions, err := g.genLogicFunction(service, proto.PbPackage, rpc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -82,7 +83,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) {
|
func (g *DefaultGenerator) genLogicFunction(serviceName string, goPackage string, rpc *parser.RPC) (string, error) {
|
||||||
functions := make([]string, 0)
|
functions := make([]string, 0)
|
||||||
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
|
text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -91,12 +92,24 @@ func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (
|
|||||||
|
|
||||||
logicName := stringx.From(rpc.Name + "_logic").ToCamel()
|
logicName := stringx.From(rpc.Name + "_logic").ToCamel()
|
||||||
comment := parser.GetComment(rpc.Doc())
|
comment := parser.GetComment(rpc.Doc())
|
||||||
|
var streamServer string
|
||||||
|
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_StreamServer")
|
||||||
|
} else if rpc.StreamsRequest {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ClientStreamServer")
|
||||||
|
} else {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ServerStreamServer")
|
||||||
|
}
|
||||||
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
|
buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{
|
||||||
"logicName": logicName,
|
"logicName": logicName,
|
||||||
"method": parser.CamelCase(rpc.Name),
|
"method": parser.CamelCase(rpc.Name),
|
||||||
|
"hasReq": !rpc.StreamsRequest,
|
||||||
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
|
"request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)),
|
||||||
|
"hasReply": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||||
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
||||||
"responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
"responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
||||||
|
"stream": rpc.StreamsRequest || rpc.StreamsReturns,
|
||||||
|
"streamBody": streamServer,
|
||||||
"hasComment": len(comment) > 0,
|
"hasComment": len(comment) > 0,
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server {
|
|||||||
`
|
`
|
||||||
functionTemplate = `
|
functionTemplate = `
|
||||||
{{if .hasComment}}{{.comment}}{{end}}
|
{{if .hasComment}}{{.comment}}{{end}}
|
||||||
func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) {
|
func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{if .hasReq}} in {{.request}}{{end}}{{else}}{{if .hasReq}} in {{.request}},{{end}}stream {{.streamBody}}{{end}}) ({{if .notStream}}{{.response}},{{end}}error) {
|
||||||
l := logic.New{{.logicName}}(ctx,s.svcCtx)
|
l := logic.New{{.logicName}}({{if .notStream}}ctx,{{else}}stream.Context(),{{end}}s.svcCtx)
|
||||||
return l.{{.method}}(in)
|
return l.{{.method}}({{if .hasReq}}in{{if .stream}} ,stream{{end}}{{else}}{{if .stream}}stream{{end}}{{end}})
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
)
|
)
|
||||||
@@ -91,6 +91,15 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service
|
|||||||
}
|
}
|
||||||
|
|
||||||
comment := parser.GetComment(rpc.Doc())
|
comment := parser.GetComment(rpc.Doc())
|
||||||
|
var streamServer string
|
||||||
|
if rpc.StreamsRequest && rpc.StreamsReturns {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamServer")
|
||||||
|
} else if rpc.StreamsRequest {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamServer")
|
||||||
|
} else {
|
||||||
|
streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamServer")
|
||||||
|
}
|
||||||
|
|
||||||
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
|
buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{
|
||||||
"server": stringx.From(service.Name).ToCamel(),
|
"server": stringx.From(service.Name).ToCamel(),
|
||||||
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
|
"logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()),
|
||||||
@@ -99,6 +108,10 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service
|
|||||||
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
"response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)),
|
||||||
"hasComment": len(comment) > 0,
|
"hasComment": len(comment) > 0,
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
|
"hasReq": !rpc.StreamsRequest,
|
||||||
|
"stream": rpc.StreamsRequest || rpc.StreamsReturns,
|
||||||
|
"notStream": !rpc.StreamsRequest && !rpc.StreamsReturns,
|
||||||
|
"streamBody": streamServer,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -59,4 +59,10 @@ service Test_Service {
|
|||||||
rpc MapService (MapReq) returns (CommonReply);
|
rpc MapService (MapReq) returns (CommonReply);
|
||||||
// case repeated
|
// case repeated
|
||||||
rpc RepeatedService (RepeatedReq) returns (CommonReply);
|
rpc RepeatedService (RepeatedReq) returns (CommonReply);
|
||||||
|
// server stream
|
||||||
|
rpc ServerStream (Req) returns (stream Reply);
|
||||||
|
// client stream
|
||||||
|
rpc ClientStream (stream Req) returns (Reply);
|
||||||
|
// stream
|
||||||
|
rpc Stream(stream Req) returns (stream Reply);
|
||||||
}
|
}
|
||||||
16
tools/goctl/rpc/parser/stream.proto
Normal file
16
tools/goctl/rpc/parser/stream.proto
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package test;
|
||||||
|
|
||||||
|
message Req{
|
||||||
|
string input = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Reply{
|
||||||
|
string output = 1;
|
||||||
|
}
|
||||||
|
service TestService{
|
||||||
|
rpc ServerStream (Req) returns (stream Reply);
|
||||||
|
rpc ClientStream (stream Req) returns (Reply);
|
||||||
|
rpc Stream (stream Req) returns (stream Reply);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user