Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16bfb1b7be | ||
|
|
ef4d4968d6 | ||
|
|
7b4a5e3ec6 | ||
|
|
e6df21e0d2 | ||
|
|
0a2c2d1eca | ||
|
|
a5fb29a6f0 | ||
|
|
f8da301e57 | ||
|
|
cb9075b737 | ||
|
|
3f389a55c2 | ||
|
|
afbd565d87 | ||
|
|
d629acc2b7 | ||
|
|
f32c6a9b28 |
@@ -35,7 +35,7 @@ spec:
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd0:2379
|
||||
- http://etcd0.discov:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
@@ -107,7 +107,7 @@ spec:
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd1:2379
|
||||
- http://etcd1.discov:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
@@ -179,7 +179,7 @@ spec:
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd2:2379
|
||||
- http://etcd2.discov:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
@@ -251,7 +251,7 @@ spec:
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd3:2379
|
||||
- http://etcd3.discov:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
@@ -323,7 +323,7 @@ spec:
|
||||
- --listen-client-urls
|
||||
- http://0.0.0.0:2379
|
||||
- --advertise-client-urls
|
||||
- http://etcd4:2379
|
||||
- http://etcd4.discov:2379
|
||||
- --initial-cluster
|
||||
- etcd0=http://etcd0:2380,etcd1=http://etcd1:2380,etcd2=http://etcd2:2380,etcd3=http://etcd3:2380,etcd4=http://etcd4:2380
|
||||
- --initial-cluster-state
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"github.com/tal-tech/go-zero/core/stores/sqlx"
|
||||
)
|
||||
|
||||
const postgreDriverName = "postgres"
|
||||
const postgresDriverName = "postgres"
|
||||
|
||||
func NewPostgre(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
|
||||
return sqlx.NewSqlConn(postgreDriverName, datasource, opts...)
|
||||
func NewPostgres(datasource string, opts ...sqlx.SqlOption) sqlx.SqlConn {
|
||||
return sqlx.NewSqlConn(postgresDriverName, datasource, opts...)
|
||||
}
|
||||
|
||||
6
go.mod
6
go.mod
@@ -7,6 +7,7 @@ require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.4.1
|
||||
github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 // indirect
|
||||
github.com/alicebob/miniredis v2.5.0+incompatible
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
|
||||
github.com/dchest/siphash v1.2.1
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/emicklei/proto v1.9.0
|
||||
@@ -25,7 +26,7 @@ require (
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.14.3 // indirect
|
||||
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334
|
||||
github.com/iancoleman/strcase v0.1.2
|
||||
github.com/justinas/alice v1.2.0
|
||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
@@ -40,6 +41,7 @@ require (
|
||||
github.com/pierrec/lz4 v2.5.1+incompatible // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/prometheus/client_golang v1.5.1
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0
|
||||
github.com/stretchr/testify v1.5.1
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect
|
||||
@@ -59,7 +61,7 @@ require (
|
||||
google.golang.org/protobuf v1.25.0
|
||||
gopkg.in/cheggaaa/pb.v1 v1.0.28
|
||||
gopkg.in/h2non/gock.v1 v1.0.15
|
||||
gopkg.in/yaml.v2 v2.2.8
|
||||
gopkg.in/yaml.v2 v2.3.0
|
||||
honnef.co/go/tools v0.0.1-2020.1.4 // indirect
|
||||
sigs.k8s.io/yaml v1.2.0 // indirect
|
||||
)
|
||||
|
||||
8
go.sum
8
go.sum
@@ -42,6 +42,8 @@ github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQa
|
||||
github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
@@ -147,6 +149,8 @@ github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334 h1:VHgatEHNcBFEB7inlalqfNqw65aNkM1lGX2yt3NmbS8=
|
||||
github.com/iancoleman/strcase v0.0.0-20191112232945-16388991a334/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
|
||||
github.com/iancoleman/strcase v0.1.2 h1:gnomlvw9tnV3ITTAxzKSgTF+8kFWcU/f+TgttpXGz1U=
|
||||
github.com/iancoleman/strcase v0.1.2/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
|
||||
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
|
||||
@@ -248,6 +252,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
|
||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/shirou/gopsutil v0.0.0-20180427012116-c95755e4bcd7/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U=
|
||||
github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc=
|
||||
@@ -448,6 +454,8 @@ gopkg.in/yaml.v2 v2.2.5 h1:ymVxjfMaHvXD8RqPRmzHHsB3VvucivSkIAvJFDI5O3c=
|
||||
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
|
||||
|
||||
@@ -40,7 +40,7 @@ func genDoc(api *spec.ApiSpec, dir string, filename string) error {
|
||||
defer fp.Close()
|
||||
|
||||
var builder strings.Builder
|
||||
for index, route := range api.Service.Routes {
|
||||
for index, route := range api.Service.Routes() {
|
||||
routeComment, _ := util.GetAnnotationValue(route.Annotations, "doc", "summary")
|
||||
if len(routeComment) == 0 {
|
||||
routeComment = "N/A"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/format"
|
||||
@@ -114,7 +115,7 @@ func apiFormat(data string) (string, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
fs, err := format.Source([]byte(strings.TrimSpace(apiStruct.StructBody)))
|
||||
fs, err := format.Source([]byte(strings.TrimSpace(apiStruct.Type)))
|
||||
if err != nil {
|
||||
str := err.Error()
|
||||
lineNumber := strings.Index(str, ":")
|
||||
@@ -144,10 +145,28 @@ func apiFormat(data string) (string, error) {
|
||||
result += strings.TrimSpace(string(fs)) + "\n\n"
|
||||
}
|
||||
if len(strings.TrimSpace(apiStruct.Service)) > 0 {
|
||||
result += strings.TrimSpace(apiStruct.Service) + "\n\n"
|
||||
result += formatService(apiStruct.Service) + "\n\n"
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return strings.TrimSpace(result), nil
|
||||
}
|
||||
|
||||
func formatService(str string) string {
|
||||
var builder strings.Builder
|
||||
scanner := bufio.NewScanner(strings.NewReader(str))
|
||||
var tapCount = 0
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == ")" || line == "}" {
|
||||
tapCount -= 1
|
||||
}
|
||||
util.WriteIndent(&builder, tapCount)
|
||||
builder.WriteString(line + "\n")
|
||||
if strings.HasSuffix(line, "(") || strings.HasSuffix(line, "{") {
|
||||
tapCount += 1
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
func countRune(s string, r rune) int {
|
||||
|
||||
47
tools/goctl/api/format/format_test.go
Normal file
47
tools/goctl/api/format/format_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
notFormattedStr = `
|
||||
type Request struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
service A-api {
|
||||
@server(
|
||||
handler: GreetHandler
|
||||
)
|
||||
get /greet/from/:name(Request) returns (Response)
|
||||
}
|
||||
`
|
||||
|
||||
formattedStr = `type Request struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
service A-api {
|
||||
@server(
|
||||
handler: GreetHandler
|
||||
)
|
||||
get /greet/from/:name(Request) returns (Response)
|
||||
}`
|
||||
)
|
||||
|
||||
func TestInlineTypeNotExist(t *testing.T) {
|
||||
r, err := apiFormat(notFormattedStr)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, r, formattedStr)
|
||||
}
|
||||
@@ -28,7 +28,6 @@ var tmpDir = path.Join(os.TempDir(), "goctl")
|
||||
func GoCommand(c *cli.Context) error {
|
||||
apiFile := c.String("api")
|
||||
dir := c.String("dir")
|
||||
force := c.Bool("force")
|
||||
if len(apiFile) == 0 {
|
||||
return errors.New("missing -api")
|
||||
}
|
||||
@@ -36,10 +35,10 @@ func GoCommand(c *cli.Context) error {
|
||||
return errors.New("missing -dir")
|
||||
}
|
||||
|
||||
return DoGenProject(apiFile, dir, force)
|
||||
return DoGenProject(apiFile, dir)
|
||||
}
|
||||
|
||||
func DoGenProject(apiFile, dir string, force bool) error {
|
||||
func DoGenProject(apiFile, dir string) error {
|
||||
p, err := parser.NewParser(apiFile)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -54,9 +53,9 @@ func DoGenProject(apiFile, dir string, force bool) error {
|
||||
logx.Must(genConfig(dir, api))
|
||||
logx.Must(genMain(dir, api))
|
||||
logx.Must(genServiceContext(dir, api))
|
||||
logx.Must(genTypes(dir, api, force))
|
||||
logx.Must(genTypes(dir, api))
|
||||
logx.Must(genHandlers(dir, api))
|
||||
logx.Must(genRoutes(dir, api, force))
|
||||
logx.Must(genRoutes(dir, api))
|
||||
logx.Must(genLogic(dir, api))
|
||||
|
||||
if err := backupAndSweep(apiFile); err != nil {
|
||||
|
||||
@@ -23,7 +23,7 @@ info(
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
Name string ` + "`" + `path:"name,options=you|me"` + "`" + `
|
||||
Name string ` + "`" + `path:"name,options=you|me"` + "`" + ` // }
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
@@ -31,19 +31,20 @@ type Response struct {
|
||||
}
|
||||
|
||||
@server(
|
||||
group: greet
|
||||
// C0
|
||||
group: greet/s1
|
||||
)
|
||||
// C1
|
||||
service A-api {
|
||||
@server(
|
||||
// C2
|
||||
@server( // C3
|
||||
handler: GreetHandler
|
||||
)
|
||||
get /greet/from/:name(Request) returns (Response)
|
||||
|
||||
@server(
|
||||
handler: NoResponseHandler
|
||||
|
||||
)
|
||||
get /greet/get(Request) returns
|
||||
get /greet/from/:name(Request) returns (Response) // hello
|
||||
|
||||
// C4
|
||||
@handler NoResponseHandler // C5
|
||||
get /greet/get(Request)
|
||||
}
|
||||
`
|
||||
|
||||
@@ -291,13 +292,13 @@ func TestParser(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, len(api.Types), 2)
|
||||
assert.Equal(t, len(api.Service.Routes), 2)
|
||||
assert.Equal(t, len(api.Service.Routes()), 2)
|
||||
|
||||
assert.Equal(t, api.Service.Routes[0].Path, "/greet/from/:name")
|
||||
assert.Equal(t, api.Service.Routes[1].Path, "/greet/get")
|
||||
assert.Equal(t, api.Service.Routes()[0].Path, "/greet/from/:name")
|
||||
assert.Equal(t, api.Service.Routes()[1].Path, "/greet/get")
|
||||
|
||||
assert.Equal(t, api.Service.Routes[1].RequestType.Name, "Request")
|
||||
assert.Equal(t, api.Service.Routes[1].ResponseType.Name, "")
|
||||
assert.Equal(t, api.Service.Routes()[1].RequestType.Name, "Request")
|
||||
assert.Equal(t, api.Service.Routes()[1].ResponseType.Name, "")
|
||||
|
||||
validate(t, filename)
|
||||
}
|
||||
@@ -314,7 +315,7 @@ func TestMultiService(t *testing.T) {
|
||||
api, err := parser.Parse()
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, len(api.Service.Routes), 2)
|
||||
assert.Equal(t, len(api.Service.Routes()), 2)
|
||||
assert.Equal(t, len(api.Service.Groups), 2)
|
||||
|
||||
validate(t, filename)
|
||||
@@ -341,10 +342,7 @@ func TestInvalidApiFile(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(filename)
|
||||
|
||||
parser, err := parser.NewParser(filename)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = parser.Parse()
|
||||
_, err = parser.NewParser(filename)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
@@ -360,8 +358,8 @@ func TestAnonymousAnnotation(t *testing.T) {
|
||||
api, err := parser.Parse()
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, len(api.Service.Routes), 1)
|
||||
assert.Equal(t, api.Service.Routes[0].Annotations[0].Value, "GreetHandler")
|
||||
assert.Equal(t, len(api.Service.Routes()), 1)
|
||||
assert.Equal(t, api.Service.Routes()[0].Annotations[0].Value, "GreetHandler")
|
||||
|
||||
validate(t, filename)
|
||||
}
|
||||
@@ -501,7 +499,8 @@ func TestHasImportApi(t *testing.T) {
|
||||
|
||||
func validate(t *testing.T, api string) {
|
||||
dir := "_go"
|
||||
err := DoGenProject(api, dir, true)
|
||||
os.RemoveAll(dir)
|
||||
err := DoGenProject(api, dir)
|
||||
defer os.RemoveAll(dir)
|
||||
assert.Nil(t, err)
|
||||
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
|
||||
@@ -31,11 +31,11 @@ func genEtc(dir string, api *spec.ApiSpec) error {
|
||||
defer fp.Close()
|
||||
|
||||
service := api.Service
|
||||
host, ok := util.GetAnnotationValue(service.Annotations, "server", "host")
|
||||
host, ok := util.GetAnnotationValue(service.Groups[0].Annotations, "server", "host")
|
||||
if !ok {
|
||||
host = "0.0.0.0"
|
||||
}
|
||||
port, ok := util.GetAnnotationValue(service.Annotations, "server", "port")
|
||||
port, ok := util.GetAnnotationValue(service.Groups[0].Annotations, "server", "port")
|
||||
if !ok {
|
||||
port = strconv.Itoa(defaultPort)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package gogen
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -61,7 +62,7 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
|
||||
func genRoutes(dir string, api *spec.ApiSpec) error {
|
||||
var builder strings.Builder
|
||||
groups, err := getRoutes(api)
|
||||
if err != nil {
|
||||
@@ -121,11 +122,7 @@ func genRoutes(dir string, api *spec.ApiSpec, force bool) error {
|
||||
}
|
||||
|
||||
filename := path.Join(dir, handlerDir, routesFilename)
|
||||
if !force {
|
||||
if err := util.RemoveOrQuit(filename); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
os.Remove(filename)
|
||||
|
||||
fp, created, err := apiutil.MaybeCreateFile(dir, handlerDir, routesFilename)
|
||||
if err != nil {
|
||||
@@ -163,8 +160,7 @@ func genRouteImports(parentPkg string, api *spec.ApiSpec) string {
|
||||
continue
|
||||
}
|
||||
}
|
||||
importSet.AddStr(fmt.Sprintf("%s \"%s\"", folder,
|
||||
util.JoinPackages(parentPkg, handlerDir, folder)))
|
||||
importSet.AddStr(fmt.Sprintf("%s \"%s\"", toPrefix(folder), util.JoinPackages(parentPkg, handlerDir, folder)))
|
||||
}
|
||||
}
|
||||
imports := importSet.KeysStr()
|
||||
@@ -187,11 +183,11 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
||||
handler = getHandlerBaseName(handler) + "Handler(serverCtx)"
|
||||
folder, ok := apiutil.GetAnnotationValue(r.Annotations, "server", groupProperty)
|
||||
if ok {
|
||||
handler = folder + "." + strings.ToUpper(handler[:1]) + handler[1:]
|
||||
handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:]
|
||||
} else {
|
||||
folder, ok = apiutil.GetAnnotationValue(g.Annotations, "server", groupProperty)
|
||||
if ok {
|
||||
handler = folder + "." + strings.ToUpper(handler[:1]) + handler[1:]
|
||||
handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:]
|
||||
}
|
||||
}
|
||||
groupedRoutes.routes = append(groupedRoutes.routes, route{
|
||||
@@ -215,3 +211,7 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func toPrefix(folder string) string {
|
||||
return strings.ReplaceAll(folder, "/", "")
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"text/template"
|
||||
@@ -42,18 +43,14 @@ func BuildTypes(types []spec.Type) (string, error) {
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func genTypes(dir string, api *spec.ApiSpec, force bool) error {
|
||||
func genTypes(dir string, api *spec.ApiSpec) error {
|
||||
val, err := BuildTypes(api.Types)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := path.Join(dir, typesDir, typesFile)
|
||||
if !force {
|
||||
if err := util.RemoveOrQuit(filename); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
os.Remove(filename)
|
||||
|
||||
fp, created, err := apiutil.MaybeCreateFile(dir, typesDir, typesFile)
|
||||
if err != nil {
|
||||
|
||||
@@ -26,14 +26,8 @@ func getParentPackage(dir string) (string, error) {
|
||||
return filepath.ToSlash(filepath.Join(projectCtx.Path, strings.TrimPrefix(projectCtx.WorkDir, projectCtx.Dir))), nil
|
||||
}
|
||||
|
||||
func writeIndent(writer io.Writer, indent int) {
|
||||
for i := 0; i < indent; i++ {
|
||||
fmt.Fprint(writer, "\t")
|
||||
}
|
||||
}
|
||||
|
||||
func writeProperty(writer io.Writer, name, tp, tag, comment string, indent int) error {
|
||||
writeIndent(writer, indent)
|
||||
util.WriteIndent(writer, indent)
|
||||
var err error
|
||||
if len(comment) > 0 {
|
||||
comment = strings.TrimPrefix(comment, "//")
|
||||
|
||||
@@ -77,7 +77,7 @@ public class {{.packetName}} extends HttpRequestPacket<{{.packetName}}.{{.packet
|
||||
`
|
||||
|
||||
func genPacket(dir, packetName string, api *spec.ApiSpec) error {
|
||||
for _, route := range api.Service.Routes {
|
||||
for _, route := range api.Service.Routes() {
|
||||
if err := createWith(dir, api, route, packetName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -27,9 +27,9 @@ service {{.name}}-api {
|
||||
|
||||
func NewService(c *cli.Context) error {
|
||||
args := c.Args()
|
||||
dirName := "greet"
|
||||
if len(args) > 0 {
|
||||
dirName = args.First()
|
||||
dirName := args.First()
|
||||
if len(dirName) == 0 {
|
||||
dirName = "greet"
|
||||
}
|
||||
|
||||
abs, err := filepath.Abs(dirName)
|
||||
@@ -58,6 +58,6 @@ func NewService(c *cli.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = gogen.DoGenProject(apiFilePath, abs, true)
|
||||
err = gogen.DoGenProject(apiFilePath, abs)
|
||||
return err
|
||||
}
|
||||
|
||||
219
tools/goctl/api/parser/apifileparser.go
Normal file
219
tools/goctl/api/parser/apifileparser.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenInfo = "info"
|
||||
tokenImport = "import"
|
||||
tokenType = "type"
|
||||
tokenService = "service"
|
||||
tokenServiceAnnotation = "@server"
|
||||
)
|
||||
|
||||
type (
|
||||
ApiStruct struct {
|
||||
Info string
|
||||
Type string
|
||||
Service string
|
||||
Imports string
|
||||
serviceBeginLine int
|
||||
}
|
||||
|
||||
apiFileState interface {
|
||||
process(api *ApiStruct, token string) (apiFileState, error)
|
||||
}
|
||||
|
||||
apiRootState struct {
|
||||
*baseState
|
||||
}
|
||||
|
||||
apiInfoState struct {
|
||||
*baseState
|
||||
}
|
||||
|
||||
apiImportState struct {
|
||||
*baseState
|
||||
}
|
||||
|
||||
apiTypeState struct {
|
||||
*baseState
|
||||
}
|
||||
|
||||
apiServiceState struct {
|
||||
*baseState
|
||||
}
|
||||
)
|
||||
|
||||
func ParseApi(src string) (*ApiStruct, error) {
|
||||
var buffer = new(bytes.Buffer)
|
||||
buffer.WriteString(src)
|
||||
api := new(ApiStruct)
|
||||
var lineNumber = api.serviceBeginLine
|
||||
apiFile := baseState{r: bufio.NewReader(buffer), lineNumber: &lineNumber}
|
||||
st := apiRootState{&apiFile}
|
||||
for {
|
||||
st, err := st.process(api, "")
|
||||
if err == io.EOF {
|
||||
return api, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("near line: %d, %s", lineNumber, err.Error())
|
||||
}
|
||||
if st == nil {
|
||||
return api, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *apiRootState) process(api *ApiStruct, token string) (apiFileState, error) {
|
||||
var builder strings.Builder
|
||||
for {
|
||||
ch, err := s.readSkipComment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case isSpace(ch) || isNewline(ch) || ch == leftParenthesis:
|
||||
token := builder.String()
|
||||
token = strings.TrimSpace(token)
|
||||
if len(token) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
builder.Reset()
|
||||
switch token {
|
||||
case tokenInfo:
|
||||
info := apiInfoState{s.baseState}
|
||||
return info.process(api, token+string(ch))
|
||||
case tokenImport:
|
||||
tp := apiImportState{s.baseState}
|
||||
return tp.process(api, token+string(ch))
|
||||
case tokenType:
|
||||
ty := apiTypeState{s.baseState}
|
||||
return ty.process(api, token+string(ch))
|
||||
case tokenService:
|
||||
server := apiServiceState{s.baseState}
|
||||
return server.process(api, token+string(ch))
|
||||
case tokenServiceAnnotation:
|
||||
server := apiServiceState{s.baseState}
|
||||
return server.process(api, token+string(ch))
|
||||
default:
|
||||
if strings.HasPrefix(token, "//") {
|
||||
continue
|
||||
}
|
||||
return nil, errors.New(fmt.Sprintf("invalid token %s at line %d", token, *s.lineNumber))
|
||||
}
|
||||
default:
|
||||
builder.WriteRune(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *apiInfoState) process(api *ApiStruct, token string) (apiFileState, error) {
|
||||
for {
|
||||
line, err := s.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.Info += "\n" + token + line
|
||||
token = ""
|
||||
if strings.TrimSpace(line) == string(rightParenthesis) {
|
||||
return &apiRootState{s.baseState}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *apiImportState) process(api *ApiStruct, token string) (apiFileState, error) {
|
||||
line, err := s.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = token + line
|
||||
if len(strings.Fields(line)) != 2 {
|
||||
return nil, errors.New("import syntax error: " + line)
|
||||
}
|
||||
|
||||
api.Imports += "\n" + line
|
||||
return &apiRootState{s.baseState}, nil
|
||||
}
|
||||
|
||||
func (s *apiTypeState) process(api *ApiStruct, token string) (apiFileState, error) {
|
||||
var blockCount = 0
|
||||
for {
|
||||
line, err := s.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.Type += "\n\n" + token + line
|
||||
token = ""
|
||||
line = strings.TrimSpace(line)
|
||||
line = removeComment(line)
|
||||
if strings.HasSuffix(line, leftBrace) {
|
||||
blockCount++
|
||||
}
|
||||
if strings.HasSuffix(line, string(leftParenthesis)) {
|
||||
blockCount++
|
||||
}
|
||||
if strings.HasSuffix(line, string(rightBrace)) {
|
||||
blockCount--
|
||||
}
|
||||
if strings.HasSuffix(line, string(rightParenthesis)) {
|
||||
blockCount--
|
||||
}
|
||||
|
||||
if blockCount == 0 {
|
||||
return &apiRootState{s.baseState}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *apiServiceState) process(api *ApiStruct, token string) (apiFileState, error) {
|
||||
var blockCount = 0
|
||||
for {
|
||||
line, err := s.readLineSkipComment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = token + line
|
||||
token = ""
|
||||
api.Service += "\n" + line
|
||||
line = strings.TrimSpace(line)
|
||||
line = removeComment(line)
|
||||
if strings.HasSuffix(line, leftBrace) {
|
||||
blockCount++
|
||||
}
|
||||
if strings.HasSuffix(line, string(leftParenthesis)) {
|
||||
blockCount++
|
||||
}
|
||||
if line == string(rightBrace) {
|
||||
blockCount--
|
||||
}
|
||||
if line == string(rightParenthesis) {
|
||||
blockCount--
|
||||
}
|
||||
|
||||
if blockCount == 0 {
|
||||
return &apiRootState{s.baseState}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func removeComment(line string) string {
|
||||
var commentIdx = strings.Index(line, "//")
|
||||
if commentIdx >= 0 {
|
||||
return line[:commentIdx]
|
||||
}
|
||||
return line
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (s *baseState) parseProperties() (map[string]string, error) {
|
||||
var st = startState
|
||||
|
||||
for {
|
||||
ch, err := s.read()
|
||||
ch, err := s.readSkipComment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -164,6 +164,60 @@ func (s *baseState) read() (rune, error) {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (s *baseState) readSkipComment() (rune, error) {
|
||||
ch, err := s.read()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if isSlash(ch) {
|
||||
value, err := s.mayReadToEndOfLine()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if value > 0 {
|
||||
ch = value
|
||||
}
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (s *baseState) mayReadToEndOfLine() (rune, error) {
|
||||
ch, err := s.read()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if isSlash(ch) {
|
||||
for {
|
||||
value, err := s.read()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if isNewline(value) {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
err = s.unread()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
func (s *baseState) readLineSkipComment() (string, error) {
|
||||
line, err := s.readLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var commentIdx = strings.Index(line, "//")
|
||||
if commentIdx >= 0 {
|
||||
return line[:commentIdx], nil
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
|
||||
func (s *baseState) readLine() (string, error) {
|
||||
line, _, err := s.r.ReadLine()
|
||||
if err != nil {
|
||||
|
||||
@@ -30,7 +30,7 @@ func newEntity(state *baseState, api *spec.ApiSpec, parser entityParser) entity
|
||||
}
|
||||
|
||||
func (s *entity) process() error {
|
||||
line, err := s.state.readLine()
|
||||
line, err := s.state.readLineSkipComment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,7 +59,7 @@ func (s *entity) process() error {
|
||||
var annos []spec.Annotation
|
||||
memberLoop:
|
||||
for {
|
||||
ch, err := s.state.read()
|
||||
ch, err := s.state.readSkipComment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -70,13 +70,13 @@ memberLoop:
|
||||
case ch == at:
|
||||
annotationLoop:
|
||||
for {
|
||||
next, err := s.state.read()
|
||||
next, err := s.state.readSkipComment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case isSpace(next):
|
||||
if builder.Len() > 0 {
|
||||
if builder.Len() > 0 && annoName == "" {
|
||||
annoName = builder.String()
|
||||
builder.Reset()
|
||||
}
|
||||
@@ -84,6 +84,7 @@ memberLoop:
|
||||
if builder.Len() == 0 {
|
||||
return errors.New("invalid annotation format")
|
||||
}
|
||||
|
||||
if len(annoName) > 0 {
|
||||
value := builder.String()
|
||||
if value != string(leftParenthesis) {
|
||||
@@ -127,7 +128,7 @@ memberLoop:
|
||||
}
|
||||
|
||||
var line string
|
||||
line, err = s.state.readLine()
|
||||
line, err = s.state.readLineSkipComment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package parser
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -34,10 +35,11 @@ func NewParser(filename string) (*Parser, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, item := range strings.Split(apiStruct.Imports, "\n") {
|
||||
ip := strings.TrimSpace(item)
|
||||
if len(ip) > 0 {
|
||||
item := strings.TrimPrefix(item, "import")
|
||||
importLine := strings.TrimSpace(item)
|
||||
if len(importLine) > 0 {
|
||||
item := strings.TrimPrefix(importLine, "import")
|
||||
item = strings.TrimSpace(item)
|
||||
item = strings.TrimPrefix(item, `"`)
|
||||
item = strings.TrimSuffix(item, `"`)
|
||||
@@ -46,18 +48,33 @@ func NewParser(filename string) (*Parser, error) {
|
||||
path = filepath.Join(filepath.Dir(apiAbsPath), item)
|
||||
}
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, errors.New("import api file not exist: " + item)
|
||||
}
|
||||
|
||||
importStruct, err := ParseApi(string(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
apiStruct.StructBody += "\n" + string(content)
|
||||
|
||||
if len(importStruct.Imports) > 0 {
|
||||
return nil, errors.New("import api should not import another api file recursive")
|
||||
}
|
||||
|
||||
apiStruct.Type += "\n" + importStruct.Type
|
||||
apiStruct.Service += "\n" + importStruct.Service
|
||||
}
|
||||
}
|
||||
|
||||
if len(strings.TrimSpace(apiStruct.Service)) == 0 {
|
||||
return nil, errors.New("api has no service defined")
|
||||
}
|
||||
|
||||
var buffer = new(bytes.Buffer)
|
||||
buffer.WriteString(apiStruct.Service)
|
||||
return &Parser{
|
||||
r: bufio.NewReader(buffer),
|
||||
typeDef: apiStruct.StructBody,
|
||||
typeDef: apiStruct.Type,
|
||||
api: apiStruct,
|
||||
}, nil
|
||||
}
|
||||
@@ -69,6 +86,7 @@ func (p *Parser) Parse() (api *spec.ApiSpec, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.Types = types
|
||||
var lineNumber = p.api.serviceBeginLine
|
||||
st := newRootState(p.r, &lineNumber)
|
||||
|
||||
@@ -23,7 +23,7 @@ func (s rootState) process(api *spec.ApiSpec) (state, error) {
|
||||
var annos []spec.Annotation
|
||||
var builder strings.Builder
|
||||
for {
|
||||
ch, err := s.read()
|
||||
ch, err := s.readSkipComment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -33,6 +33,7 @@ func (s rootState) process(api *spec.ApiSpec) (state, error) {
|
||||
if builder.Len() == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
token := builder.String()
|
||||
builder.Reset()
|
||||
return s.processToken(token, annos)
|
||||
@@ -44,10 +45,11 @@ func (s rootState) process(api *spec.ApiSpec) (state, error) {
|
||||
var annoName string
|
||||
annoLoop:
|
||||
for {
|
||||
next, err := s.read()
|
||||
next, err := s.readSkipComment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case isSpace(next):
|
||||
if builder.Len() > 0 {
|
||||
@@ -58,6 +60,7 @@ func (s rootState) process(api *spec.ApiSpec) (state, error) {
|
||||
if err := s.unread(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if builder.Len() > 0 {
|
||||
annoName = builder.String()
|
||||
builder.Reset()
|
||||
@@ -66,6 +69,7 @@ func (s rootState) process(api *spec.ApiSpec) (state, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
annos = append(annos, spec.Annotation{
|
||||
Name: annoName,
|
||||
Properties: attrs,
|
||||
@@ -79,9 +83,11 @@ func (s rootState) process(api *spec.ApiSpec) (state, error) {
|
||||
if builder.Len() == 0 {
|
||||
return nil, fmt.Errorf("incorrect %q at the beginning of the line", leftParenthesis)
|
||||
}
|
||||
|
||||
if err := s.unread(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token := builder.String()
|
||||
builder.Reset()
|
||||
return s.processToken(token, annos)
|
||||
|
||||
@@ -40,9 +40,7 @@ func (s *serviceState) process(api *spec.ApiSpec) (state, error) {
|
||||
}
|
||||
|
||||
api.Service = spec.Service{
|
||||
Name: name,
|
||||
Annotations: append(api.Service.Annotations, s.annos...),
|
||||
Routes: append(api.Service.Routes, routes...),
|
||||
Name: name,
|
||||
Groups: append(api.Service.Groups, spec.Group{
|
||||
Annotations: s.annos,
|
||||
Routes: routes,
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
type typeState struct {
|
||||
*baseState
|
||||
annos []spec.Annotation
|
||||
}
|
||||
|
||||
func newTypeState(state *baseState, annos []spec.Annotation) state {
|
||||
return &typeState{
|
||||
baseState: state,
|
||||
annos: annos,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *typeState) process(api *spec.ApiSpec) (state, error) {
|
||||
var name string
|
||||
var members []spec.Member
|
||||
parser := &typeEntityParser{
|
||||
acceptName: func(n string) {
|
||||
name = n
|
||||
},
|
||||
acceptMember: func(member spec.Member) {
|
||||
members = append(members, member)
|
||||
},
|
||||
}
|
||||
ent := newEntity(s.baseState, api, parser)
|
||||
if err := ent.process(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.Types = append(api.Types, spec.Type{
|
||||
Name: name,
|
||||
Annotations: s.annos,
|
||||
Members: members,
|
||||
})
|
||||
|
||||
return newRootState(s.r, s.lineNumber), nil
|
||||
}
|
||||
|
||||
type typeEntityParser struct {
|
||||
acceptName func(name string)
|
||||
acceptMember func(member spec.Member)
|
||||
}
|
||||
|
||||
func (p *typeEntityParser) parseLine(line string, api *spec.ApiSpec, annos []spec.Annotation) error {
|
||||
index := strings.Index(line, "//")
|
||||
comment := ""
|
||||
if index >= 0 {
|
||||
comment = line[index+2:]
|
||||
line = strings.TrimSpace(line[:index])
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(fields) == 1 {
|
||||
p.acceptMember(spec.Member{
|
||||
Annotations: annos,
|
||||
Name: fields[0],
|
||||
Type: fields[0],
|
||||
IsInline: true,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
name := fields[0]
|
||||
tp := fields[1]
|
||||
var tag string
|
||||
if len(fields) > 2 {
|
||||
tag = fields[2]
|
||||
} else {
|
||||
tag = fmt.Sprintf("`json:\"%s\"`", util.Untitle(name))
|
||||
}
|
||||
|
||||
p.acceptMember(spec.Member{
|
||||
Annotations: annos,
|
||||
Name: name,
|
||||
Type: tp,
|
||||
Tag: tag,
|
||||
Comment: comment,
|
||||
IsInline: false,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *typeEntityParser) setEntityName(name string) {
|
||||
p.acceptName(name)
|
||||
}
|
||||
@@ -2,22 +2,12 @@ package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
|
||||
)
|
||||
|
||||
var emptyType spec.Type
|
||||
|
||||
type ApiStruct struct {
|
||||
Info string
|
||||
StructBody string
|
||||
Service string
|
||||
Imports string
|
||||
serviceBeginLine int
|
||||
}
|
||||
|
||||
func GetType(api *spec.ApiSpec, t string) spec.Type {
|
||||
for _, tp := range api.Types {
|
||||
if tp.Name == t {
|
||||
@@ -36,6 +26,10 @@ func isSpace(r rune) bool {
|
||||
return r == ' ' || r == '\t'
|
||||
}
|
||||
|
||||
func isSlash(r rune) bool {
|
||||
return r == '/'
|
||||
}
|
||||
|
||||
func isNewline(r rune) bool {
|
||||
return r == '\n' || r == '\r'
|
||||
}
|
||||
@@ -69,82 +63,3 @@ func skipSpaces(r *bufio.Reader) error {
|
||||
func unread(r *bufio.Reader) error {
|
||||
return r.UnreadRune()
|
||||
}
|
||||
|
||||
func ParseApi(api string) (*ApiStruct, error) {
|
||||
var result ApiStruct
|
||||
scanner := bufio.NewScanner(strings.NewReader(api))
|
||||
var parseInfo = false
|
||||
var parseImport = false
|
||||
var parseType = false
|
||||
var parseService = false
|
||||
var segment string
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == "info(" {
|
||||
parseInfo = true
|
||||
}
|
||||
if line == ")" && parseInfo {
|
||||
parseInfo = false
|
||||
result.Info = segment + ")"
|
||||
segment = ""
|
||||
continue
|
||||
}
|
||||
|
||||
if isImportBeginLine(line) {
|
||||
parseImport = true
|
||||
}
|
||||
if parseImport && (isTypeBeginLine(line) || isServiceBeginLine(line)) {
|
||||
parseImport = false
|
||||
result.Imports = segment
|
||||
segment = line + "\n"
|
||||
continue
|
||||
}
|
||||
|
||||
if isTypeBeginLine(line) {
|
||||
parseType = true
|
||||
}
|
||||
if isServiceBeginLine(line) {
|
||||
parseService = true
|
||||
if parseType {
|
||||
parseType = false
|
||||
result.StructBody = segment
|
||||
segment = line + "\n"
|
||||
continue
|
||||
}
|
||||
}
|
||||
segment += scanner.Text() + "\n"
|
||||
}
|
||||
|
||||
if !parseService {
|
||||
return nil, errors.New("no service defined")
|
||||
}
|
||||
result.Service = segment
|
||||
result.serviceBeginLine = lineBeginOfService(api)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func isImportBeginLine(line string) bool {
|
||||
return strings.HasPrefix(line, "import") && (strings.HasSuffix(line, ".api") || strings.HasSuffix(line, `.api"`))
|
||||
}
|
||||
|
||||
func isTypeBeginLine(line string) bool {
|
||||
return strings.HasPrefix(line, "type")
|
||||
}
|
||||
|
||||
func isServiceBeginLine(line string) bool {
|
||||
return strings.HasPrefix(line, "@server") || (strings.HasPrefix(line, "service") && strings.HasSuffix(line, "{"))
|
||||
}
|
||||
|
||||
func lineBeginOfService(api string) int {
|
||||
scanner := bufio.NewScanner(strings.NewReader(api))
|
||||
var number = 0
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if isServiceBeginLine(line) {
|
||||
break
|
||||
}
|
||||
number++
|
||||
}
|
||||
return number
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *Parser) validateDuplicateProperty(tp spec.Type) (bool, string) {
|
||||
|
||||
func (p *Parser) validateDuplicateRouteHandler(api *spec.ApiSpec) (bool, string) {
|
||||
var names []string
|
||||
for _, r := range api.Service.Routes {
|
||||
for _, r := range api.Service.Routes() {
|
||||
handler, ok := util.GetAnnotationValue(r.Annotations, "server", "handler")
|
||||
if !ok {
|
||||
return false, fmt.Sprintf("missing handler annotation for %s", r.Path)
|
||||
|
||||
@@ -27,6 +27,14 @@ type Attribute struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (s Service) Routes() []Route {
|
||||
var result []Route
|
||||
for _, group := range s.Groups {
|
||||
result = append(result, group.Routes...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m Member) IsOptional() bool {
|
||||
var option string
|
||||
|
||||
|
||||
@@ -57,10 +57,8 @@ type (
|
||||
}
|
||||
|
||||
Service struct {
|
||||
Name string
|
||||
Annotations []Annotation
|
||||
Routes []Route
|
||||
Groups []Group
|
||||
Name string
|
||||
Groups []Group
|
||||
}
|
||||
|
||||
Type struct {
|
||||
|
||||
@@ -36,7 +36,7 @@ func genHandler(dir, webApi, caller string, api *spec.ApiSpec, unwrapApi bool) e
|
||||
defer fp.Close()
|
||||
|
||||
var localTypes []spec.Type
|
||||
for _, route := range api.Service.Routes {
|
||||
for _, route := range api.Service.Routes() {
|
||||
rts := apiutil.GetLocalTypes(api, route)
|
||||
localTypes = append(localTypes, rts...)
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func genTypes(localTypes []spec.Type, inlineType func(string) (*spec.Type, error
|
||||
|
||||
func genApi(api *spec.ApiSpec, localTypes []spec.Type, caller string, prefixForType func(string) string) (string, error) {
|
||||
var builder strings.Builder
|
||||
for _, route := range api.Service.Routes {
|
||||
for _, route := range api.Service.Routes() {
|
||||
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
|
||||
if !ok {
|
||||
return "", fmt.Errorf("missing handler annotation for route %q", route.Path)
|
||||
|
||||
@@ -130,7 +130,7 @@ func GetSharedTypes(api *spec.ApiSpec) []spec.Type {
|
||||
}
|
||||
return false
|
||||
}
|
||||
for _, route := range api.Service.Routes {
|
||||
for _, route := range api.Service.Routes() {
|
||||
var rts []spec.Type
|
||||
getTypeRecursive(route.RequestType, types, &rts)
|
||||
getTypeRecursive(route.ResponseType, types, &rts)
|
||||
|
||||
@@ -75,3 +75,9 @@ func ComponentName(api *spec.ApiSpec) string {
|
||||
}
|
||||
return name + "Components"
|
||||
}
|
||||
|
||||
func WriteIndent(writer io.Writer, indent int) {
|
||||
for i := 0; i < indent; i++ {
|
||||
fmt.Fprint(writer, "\t")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/gen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
@@ -26,7 +28,7 @@ func DockerCommand(c *cli.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return gen.GenerateDockerfile(goFile, "-f", "etc/"+cfg)
|
||||
return generateDockerfile(goFile, "-f", "etc/"+cfg)
|
||||
}
|
||||
|
||||
func findConfig(file, dir string) (string, error) {
|
||||
@@ -57,3 +59,56 @@ func findConfig(file, dir string) (string, error) {
|
||||
|
||||
return files[0], nil
|
||||
}
|
||||
|
||||
func generateDockerfile(goFile string, args ...string) error {
|
||||
projPath, err := getFilePath(filepath.Dir(goFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pos := strings.IndexByte(projPath, '/')
|
||||
if pos >= 0 {
|
||||
projPath = projPath[pos+1:]
|
||||
}
|
||||
|
||||
out, err := util.CreateIfNotExist("Dockerfile")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
text, err := ctlutil.LoadTemplate(category, dockerTemplateFile, dockerTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var builder strings.Builder
|
||||
for _, arg := range args {
|
||||
builder.WriteString(`, "` + arg + `"`)
|
||||
}
|
||||
|
||||
t := template.Must(template.New("dockerfile").Parse(text))
|
||||
return t.Execute(out, map[string]string{
|
||||
"goRelPath": projPath,
|
||||
"goFile": goFile,
|
||||
"exeFile": util.FileNameWithoutExt(filepath.Base(goFile)),
|
||||
"argument": builder.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func getFilePath(file string) (string, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
projPath, ok := util.FindGoModPath(filepath.Join(wd, file))
|
||||
if !ok {
|
||||
projPath, err = util.PathFromGoSrc()
|
||||
if err != nil {
|
||||
return "", errors.New("no go.mod found, or not in GOPATH")
|
||||
}
|
||||
}
|
||||
|
||||
return projPath, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
package gen
|
||||
package docker
|
||||
|
||||
const dockerTemplate = `FROM golang:alpine AS builder
|
||||
import (
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
const (
|
||||
category = "docker"
|
||||
dockerTemplateFile = "docker.tpl"
|
||||
dockerTemplate = `FROM golang:alpine AS builder
|
||||
|
||||
LABEL stage=gobuilder
|
||||
|
||||
@@ -27,3 +35,10 @@ COPY --from=builder /app/etc /app/etc
|
||||
|
||||
CMD ["./{{.exeFile}}"{{.argument}}]
|
||||
`
|
||||
)
|
||||
|
||||
func GenTemplates(_ *cli.Context) error {
|
||||
return util.InitTemplates(category, map[string]string{
|
||||
dockerTemplateFile: dockerTemplate,
|
||||
})
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
func GenerateDockerfile(goFile string, args ...string) error {
|
||||
projPath, err := getFilePath(filepath.Dir(goFile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pos := strings.IndexByte(projPath, '/')
|
||||
if pos >= 0 {
|
||||
projPath = projPath[pos+1:]
|
||||
}
|
||||
|
||||
out, err := util.CreateIfNotExist("Dockerfile")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
var builder strings.Builder
|
||||
for _, arg := range args {
|
||||
builder.WriteString(`, "` + arg + `"`)
|
||||
}
|
||||
|
||||
t := template.Must(template.New("dockerfile").Parse(dockerTemplate))
|
||||
return t.Execute(out, map[string]string{
|
||||
"goRelPath": projPath,
|
||||
"goFile": goFile,
|
||||
"exeFile": util.FileNameWithoutExt(filepath.Base(goFile)),
|
||||
"argument": builder.String(),
|
||||
})
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package gen
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
)
|
||||
|
||||
func getFilePath(file string) (string, error) {
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
projPath, ok := util.FindGoModPath(filepath.Join(wd, file))
|
||||
if !ok {
|
||||
projPath, err = util.PathFromGoSrc()
|
||||
if err != nil {
|
||||
return "", errors.New("no go.mod found, or not in GOPATH")
|
||||
}
|
||||
}
|
||||
|
||||
return projPath, nil
|
||||
}
|
||||
@@ -98,10 +98,6 @@ var (
|
||||
Name: "api",
|
||||
Usage: "the api file",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "force",
|
||||
Usage: "force override the exist files",
|
||||
},
|
||||
},
|
||||
Action: gogen.GoCommand,
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/logrusorgru/aurora"
|
||||
"github.com/tal-tech/go-zero/core/errorx"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
||||
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
@@ -25,6 +26,9 @@ func GenTemplates(ctx *cli.Context) error {
|
||||
func() error {
|
||||
return rpcgen.GenTemplates(ctx)
|
||||
},
|
||||
func() error {
|
||||
return docker.GenTemplates(ctx)
|
||||
},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,83 +1,11 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
import "google.golang.org/grpc"
|
||||
|
||||
func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
|
||||
return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...))
|
||||
return grpc.WithChainStreamInterceptor(interceptors...)
|
||||
}
|
||||
|
||||
func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
|
||||
return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...))
|
||||
}
|
||||
|
||||
func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||
method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
var chainStreamer grpc.Streamer
|
||||
var current int
|
||||
|
||||
chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn,
|
||||
curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
if current == last {
|
||||
return streamer(curCtx, curDesc, curCc, curMethod, curOpts...)
|
||||
}
|
||||
|
||||
current++
|
||||
clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...)
|
||||
current--
|
||||
|
||||
return clientStream, err
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
var chainInvoker grpc.UnaryInvoker
|
||||
var current int
|
||||
|
||||
chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{},
|
||||
curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error {
|
||||
if current == last {
|
||||
return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...)
|
||||
}
|
||||
|
||||
current++
|
||||
err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...)
|
||||
current--
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...)
|
||||
}
|
||||
}
|
||||
return grpc.WithChainUnaryInterceptor(interceptors...)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestWithStreamClientInterceptors(t *testing.T) {
|
||||
@@ -16,108 +14,4 @@ func TestWithStreamClientInterceptors(t *testing.T) {
|
||||
func TestWithUnaryClientInterceptors(t *testing.T) {
|
||||
opts := WithUnaryClientInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamClientInterceptors()
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 1)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
|
||||
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
|
||||
grpc.ClientStream, error) {
|
||||
vals = append(vals, 1)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 2)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
|
||||
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
|
||||
grpc.ClientStream, error) {
|
||||
vals = append(vals, 1)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 2)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 3)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryClientInterceptors()
|
||||
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 1)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
|
||||
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 1)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
})
|
||||
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 2)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
|
||||
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 1)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 2)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
})
|
||||
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 3)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
}
|
||||
@@ -1,81 +1,11 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
import "google.golang.org/grpc"
|
||||
|
||||
func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
|
||||
return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
|
||||
return grpc.ChainStreamInterceptor(interceptors...)
|
||||
}
|
||||
|
||||
func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
|
||||
return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
|
||||
}
|
||||
|
||||
func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
return handler(srv, stream)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
var chainHandler grpc.StreamHandler
|
||||
var current int
|
||||
|
||||
chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
|
||||
if current == last {
|
||||
return handler(curSrv, curStream)
|
||||
}
|
||||
|
||||
current++
|
||||
err := interceptors[current](curSrv, curStream, info, chainHandler)
|
||||
current--
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return interceptors[0](srv, stream, info, chainHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
interface{}, error) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
interface{}, error) {
|
||||
var chainHandler grpc.UnaryHandler
|
||||
var current int
|
||||
|
||||
chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
|
||||
if current == last {
|
||||
return handler(curCtx, curReq)
|
||||
}
|
||||
|
||||
current++
|
||||
resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
|
||||
current--
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, req, info, chainHandler)
|
||||
}
|
||||
}
|
||||
return grpc.ChainUnaryInterceptor(interceptors...)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestWithStreamServerInterceptors(t *testing.T) {
|
||||
@@ -16,96 +14,4 @@ func TestWithStreamServerInterceptors(t *testing.T) {
|
||||
func TestWithUnaryServerInterceptors(t *testing.T) {
|
||||
opts := WithUnaryServerInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestChainStreamServerInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamServerInterceptors()
|
||||
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
vals = append(vals, 1)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamServerInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 1)
|
||||
return handler(srv, ss)
|
||||
})
|
||||
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
vals = append(vals, 2)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamServerInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 1)
|
||||
return handler(srv, ss)
|
||||
}, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 2)
|
||||
return handler(srv, ss)
|
||||
})
|
||||
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
vals = append(vals, 3)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
|
||||
func TestChainUnaryServerInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryServerInterceptors()
|
||||
_, err := interceptors(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
vals = append(vals, 1)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestChainUnaryServerInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
|
||||
info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
vals = append(vals, 1)
|
||||
return handler(ctx, req)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
vals = append(vals, 2)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestChainUnaryServerInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
|
||||
info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
vals = append(vals, 1)
|
||||
return handler(ctx, req)
|
||||
}, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
vals = append(vals, 2)
|
||||
return handler(ctx, req)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
vals = append(vals, 3)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user