From db87fd3239df415c4136e4b5888842fab4ca73e9 Mon Sep 17 00:00:00 2001 From: anqiansong Date: Fri, 16 Jul 2021 22:54:07 +0800 Subject: [PATCH] To generate grpc stream, fix issue #616 (#815) Co-authored-by: anqiansong --- tools/goctl/rpc/generator/gen_test.go | 74 ++++++++++++++------------ tools/goctl/rpc/generator/gencall.go | 33 ++++++++++-- tools/goctl/rpc/generator/genlogic.go | 21 ++++++-- tools/goctl/rpc/generator/genserver.go | 19 +++++-- tools/goctl/rpc/generator/test.proto | 6 +++ tools/goctl/rpc/parser/stream.proto | 16 ++++++ 6 files changed, 123 insertions(+), 46 deletions(-) create mode 100644 tools/goctl/rpc/parser/stream.proto diff --git a/tools/goctl/rpc/generator/gen_test.go b/tools/goctl/rpc/generator/gen_test.go index b6f6f0d5..12bb6bdc 100644 --- a/tools/goctl/rpc/generator/gen_test.go +++ b/tools/goctl/rpc/generator/gen_test.go @@ -29,7 +29,6 @@ func TestRpcGenerate(t *testing.T) { projectName := stringx.Rand() g := NewRPCGenerator(dispatcher, cfg) - // case go path src := filepath.Join(build.Default.GOPATH, "src") _, err = os.Stat(src) if err != nil { @@ -41,45 +40,52 @@ func TestRpcGenerate(t *testing.T) { defer func() { _ = os.RemoveAll(srcDir) }() - common, err := filepath.Abs(".") assert.Nil(t, err) - err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base") - assert.Nil(t, err) - _, err = execx.Run("go test "+projectName, projectDir) - if err != nil { - assert.True(t, func() bool { - return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") - }()) - } + // case go path + t.Run("GOPATH", func(t *testing.T) { + err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base") + assert.Nil(t, err) + _, err = execx.Run("go test "+projectName, projectDir) + if err != nil { + assert.True(t, func() bool { + return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") + }()) + } + }) // case go mod - workDir := t.TempDir() - name := filepath.Base(workDir) - _, err = execx.Run("go mod init "+name, workDir) - if err != nil { - logx.Error(err) - return - } + t.Run("GOMOD", func(t *testing.T) { + workDir := t.TempDir() + name := filepath.Base(workDir) + _, err = execx.Run("go mod init "+name, workDir) + if err != nil { + logx.Error(err) + return + } - projectDir = filepath.Join(workDir, projectName) - err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base") - assert.Nil(t, err) - _, err = execx.Run("go test "+projectName, projectDir) - if err != nil { - assert.True(t, func() bool { - return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") - }()) - } + projectDir = filepath.Join(workDir, projectName) + err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base") + assert.Nil(t, err) + _, err = execx.Run("go test "+projectName, projectDir) + if err != nil { + assert.True(t, func() bool { + return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") + }()) + } + + }) // case not in go mod and go path - err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base") - assert.Nil(t, err) - _, err = execx.Run("go test "+projectName, projectDir) - if err != nil { - assert.True(t, func() bool { - return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") - }()) - } + t.Run("OTHER", func(t *testing.T) { + err = g.Generate("./test.proto", projectDir, []string{common, src}, "Mbase/common.proto=./base") + assert.Nil(t, err) + _, err = execx.Run("go test "+projectName, projectDir) + if err != nil { + assert.True(t, func() bool { + return strings.Contains(err.Error(), "not in GOROOT") || strings.Contains(err.Error(), "cannot find package") + }()) + } + }) } diff --git a/tools/goctl/rpc/generator/gencall.go b/tools/goctl/rpc/generator/gencall.go index ec42a436..a2ae50ba 100644 --- a/tools/goctl/rpc/generator/gencall.go +++ b/tools/goctl/rpc/generator/gencall.go @@ -49,13 +49,13 @@ func New{{.serviceName}}(cli zrpc.Client) {{.serviceName}} { ` callInterfaceFunctionTemplate = `{{if .hasComment}}{{.comment}} -{{end}}{{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}},error)` +{{end}}{{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error)` callFunctionTemplate = ` {{if .hasComment}}{{.comment}}{{end}} -func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbRequest}}) (*{{.pbResponse}}, error) { +func (m *default{{.serviceName}}) {{.method}}(ctx context.Context{{if .hasReq}},in *{{.pbRequest}}{{end}}) ({{if .notStream}}*{{.pbResponse}}, {{else}}{{.streamBody}},{{end}} error) { client := {{.package}}.New{{.rpcServiceName}}Client(m.cli.Conn()) - return client.{{.method}}(ctx, in) + return client.{{.method}}(ctx,{{if .hasReq}} in{{end}}) } ` ) @@ -78,7 +78,7 @@ func (g *DefaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf return err } - iFunctions, err := g.getInterfaceFuncs(service) + iFunctions, err := g.getInterfaceFuncs(proto.PbPackage, service) if err != nil { return err } @@ -115,6 +115,14 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service) } comment := parser.GetComment(rpc.Doc()) + var streamServer string + if rpc.StreamsRequest && rpc.StreamsReturns { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient") + } else if rpc.StreamsRequest { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient") + } else { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient") + } buffer, err := util.With("sharedFn").Parse(text).Execute(map[string]interface{}{ "serviceName": stringx.From(service.Name).ToCamel(), "rpcServiceName": parser.CamelCase(service.Name), @@ -124,6 +132,9 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service) "pbResponse": parser.CamelCase(rpc.ReturnsType), "hasComment": len(comment) > 0, "comment": comment, + "hasReq": !rpc.StreamsRequest, + "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns, + "streamBody": streamServer, }) if err != nil { return nil, err @@ -134,7 +145,7 @@ func (g *DefaultGenerator) genFunction(goPackage string, service parser.Service) return functions, nil } -func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, error) { +func (g *DefaultGenerator) getInterfaceFuncs(goPackage string, service parser.Service) ([]string, error) { functions := make([]string, 0) for _, rpc := range service.RPC { @@ -144,13 +155,25 @@ func (g *DefaultGenerator) getInterfaceFuncs(service parser.Service) ([]string, } comment := parser.GetComment(rpc.Doc()) + var streamServer string + if rpc.StreamsRequest && rpc.StreamsReturns { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamClient") + } else if rpc.StreamsRequest { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamClient") + } else { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamClient") + } + buffer, err := util.With("interfaceFn").Parse(text).Execute( map[string]interface{}{ "hasComment": len(comment) > 0, "comment": comment, "method": parser.CamelCase(rpc.Name), + "hasReq": !rpc.StreamsRequest, "pbRequest": parser.CamelCase(rpc.RequestType), + "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns, "pbResponse": parser.CamelCase(rpc.ReturnsType), + "streamBody": streamServer, }) if err != nil { return nil, err diff --git a/tools/goctl/rpc/generator/genlogic.go b/tools/goctl/rpc/generator/genlogic.go index 62eb9111..fe230cde 100644 --- a/tools/goctl/rpc/generator/genlogic.go +++ b/tools/goctl/rpc/generator/genlogic.go @@ -40,10 +40,10 @@ func New{{.logicName}}(ctx context.Context,svcCtx *svc.ServiceContext) *{{.logic {{.functions}} ` logicFunctionTemplate = `{{if .hasComment}}{{.comment}}{{end}} -func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) { +func (l *{{.logicName}}) {{.method}} ({{if .hasReq}}in {{.request}}{{if .stream}},stream {{.streamBody}}{{end}}{{else}}stream {{.streamBody}}{{end}}) ({{if .hasReply}}{{.response}},{{end}} error) { // todo: add your logic here and delete this line - return &{{.responseType}}{}, nil + return {{if .hasReply}}&{{.responseType}}{},{{end}} nil } ` ) @@ -51,6 +51,7 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) { // GenLogic generates the logic file of the rpc service, which corresponds to the RPC definition items in proto. func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error { dir := ctx.GetLogic() + service := proto.Service.Service.Name for _, rpc := range proto.Service.RPC { logicFilename, err := format.FileNamingFormat(cfg.NamingFormat, rpc.Name+"_logic") if err != nil { @@ -58,7 +59,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con } filename := filepath.Join(dir.Filename, logicFilename+".go") - functions, err := g.genLogicFunction(proto.PbPackage, rpc) + functions, err := g.genLogicFunction(service, proto.PbPackage, rpc) if err != nil { return err } @@ -82,7 +83,7 @@ func (g *DefaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *con return nil } -func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) (string, error) { +func (g *DefaultGenerator) genLogicFunction(serviceName string, goPackage string, rpc *parser.RPC) (string, error) { functions := make([]string, 0) text, err := util.LoadTemplate(category, logicFuncTemplateFileFile, logicFunctionTemplate) if err != nil { @@ -91,12 +92,24 @@ func (g *DefaultGenerator) genLogicFunction(goPackage string, rpc *parser.RPC) ( logicName := stringx.From(rpc.Name + "_logic").ToCamel() comment := parser.GetComment(rpc.Doc()) + var streamServer string + if rpc.StreamsRequest && rpc.StreamsReturns { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_StreamServer") + } else if rpc.StreamsRequest { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ClientStreamServer") + } else { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(serviceName)+"_ServerStreamServer") + } buffer, err := util.With("fun").Parse(text).Execute(map[string]interface{}{ "logicName": logicName, "method": parser.CamelCase(rpc.Name), + "hasReq": !rpc.StreamsRequest, "request": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.RequestType)), + "hasReply": !rpc.StreamsRequest && !rpc.StreamsReturns, "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), "responseType": fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), + "stream": rpc.StreamsRequest || rpc.StreamsReturns, + "streamBody": streamServer, "hasComment": len(comment) > 0, "comment": comment, }) diff --git a/tools/goctl/rpc/generator/genserver.go b/tools/goctl/rpc/generator/genserver.go index 61f2dad9..6221600d 100644 --- a/tools/goctl/rpc/generator/genserver.go +++ b/tools/goctl/rpc/generator/genserver.go @@ -38,9 +38,9 @@ func New{{.server}}Server(svcCtx *svc.ServiceContext) *{{.server}}Server { ` functionTemplate = ` {{if .hasComment}}{{.comment}}{{end}} -func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) ({{.response}}, error) { - l := logic.New{{.logicName}}(ctx,s.svcCtx) - return l.{{.method}}(in) +func (s *{{.server}}Server) {{.method}} ({{if .notStream}}ctx context.Context,{{if .hasReq}} in {{.request}}{{end}}{{else}}{{if .hasReq}} in {{.request}},{{end}}stream {{.streamBody}}{{end}}) ({{if .notStream}}{{.response}},{{end}}error) { + l := logic.New{{.logicName}}({{if .notStream}}ctx,{{else}}stream.Context(),{{end}}s.svcCtx) + return l.{{.method}}({{if .hasReq}}in{{if .stream}} ,stream{{end}}{{else}}{{if .stream}}stream{{end}}{{end}}) } ` ) @@ -91,6 +91,15 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service } comment := parser.GetComment(rpc.Doc()) + var streamServer string + if rpc.StreamsRequest && rpc.StreamsReturns { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_StreamServer") + } else if rpc.StreamsRequest { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ClientStreamServer") + } else { + streamServer = fmt.Sprintf("%s.%s", goPackage, parser.CamelCase(service.Name)+"_ServerStreamServer") + } + buffer, err := util.With("func").Parse(text).Execute(map[string]interface{}{ "server": stringx.From(service.Name).ToCamel(), "logicName": fmt.Sprintf("%sLogic", stringx.From(rpc.Name).ToCamel()), @@ -99,6 +108,10 @@ func (g *DefaultGenerator) genFunctions(goPackage string, service parser.Service "response": fmt.Sprintf("*%s.%s", goPackage, parser.CamelCase(rpc.ReturnsType)), "hasComment": len(comment) > 0, "comment": comment, + "hasReq": !rpc.StreamsRequest, + "stream": rpc.StreamsRequest || rpc.StreamsReturns, + "notStream": !rpc.StreamsRequest && !rpc.StreamsReturns, + "streamBody": streamServer, }) if err != nil { return nil, err diff --git a/tools/goctl/rpc/generator/test.proto b/tools/goctl/rpc/generator/test.proto index 7e7d383f..c2746107 100644 --- a/tools/goctl/rpc/generator/test.proto +++ b/tools/goctl/rpc/generator/test.proto @@ -59,4 +59,10 @@ service Test_Service { rpc MapService (MapReq) returns (CommonReply); // case repeated rpc RepeatedService (RepeatedReq) returns (CommonReply); + // server stream + rpc ServerStream (Req) returns (stream Reply); + // client stream + rpc ClientStream (stream Req) returns (Reply); + // stream + rpc Stream(stream Req) returns (stream Reply); } \ No newline at end of file diff --git a/tools/goctl/rpc/parser/stream.proto b/tools/goctl/rpc/parser/stream.proto new file mode 100644 index 00000000..0db0f104 --- /dev/null +++ b/tools/goctl/rpc/parser/stream.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package test; + +message Req{ + string input = 1; +} + +message Reply{ + string output = 1; +} +service TestService{ + rpc ServerStream (Req) returns (stream Reply); + rpc ClientStream (stream Req) returns (Reply); + rpc Stream (stream Req) returns (stream Reply); +}