feature: file namestyle (#223)

* add api filename style

* new feature: config.yaml

* optimize

* optimize logic generation

* check hanlder valid

* optimize

* reactor naming style

* optimize

* optimize test

* optimize gen middleware

* format

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
Co-authored-by: kim <xutao@xiaoheiban.cn>
This commit is contained in:
kingxt
2020-11-24 15:11:18 +08:00
committed by GitHub
parent 702e8d79ce
commit b9ac51b6c3
40 changed files with 896 additions and 296 deletions

View File

@@ -16,6 +16,7 @@ import (
apiformat "github.com/tal-tech/go-zero/tools/goctl/api/format" apiformat "github.com/tal-tech/go-zero/tools/goctl/api/format"
"github.com/tal-tech/go-zero/tools/goctl/api/parser" "github.com/tal-tech/go-zero/tools/goctl/api/parser"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -27,6 +28,8 @@ var tmpDir = path.Join(os.TempDir(), "goctl")
func GoCommand(c *cli.Context) error { func GoCommand(c *cli.Context) error {
apiFile := c.String("api") apiFile := c.String("api")
dir := c.String("dir") dir := c.String("dir")
namingStyle := c.String("style")
if len(apiFile) == 0 { if len(apiFile) == 0 {
return errors.New("missing -api") return errors.New("missing -api")
} }
@@ -34,10 +37,10 @@ func GoCommand(c *cli.Context) error {
return errors.New("missing -dir") return errors.New("missing -dir")
} }
return DoGenProject(apiFile, dir) return DoGenProject(apiFile, dir, namingStyle)
} }
func DoGenProject(apiFile, dir string) error { func DoGenProject(apiFile, dir, style string) error {
p, err := parser.NewParser(apiFile) p, err := parser.NewParser(apiFile)
if err != nil { if err != nil {
return err return err
@@ -47,15 +50,21 @@ func DoGenProject(apiFile, dir string) error {
return err return err
} }
cfg, err := config.NewConfig(style)
if err != nil {
return err
}
logx.Must(util.MkdirIfNotExist(dir)) logx.Must(util.MkdirIfNotExist(dir))
logx.Must(genEtc(dir, api)) logx.Must(genEtc(dir, cfg, api))
logx.Must(genConfig(dir, api)) logx.Must(genConfig(dir, cfg, api))
logx.Must(genMain(dir, api)) logx.Must(genMain(dir, cfg, api))
logx.Must(genServiceContext(dir, api)) logx.Must(genServiceContext(dir, cfg, api))
logx.Must(genTypes(dir, api)) logx.Must(genTypes(dir, cfg, api))
logx.Must(genHandlers(dir, api)) logx.Must(genRoutes(dir, cfg, api))
logx.Must(genRoutes(dir, api)) logx.Must(genHandlers(dir, cfg, api))
logx.Must(genLogic(dir, api)) logx.Must(genLogic(dir, cfg, api))
logx.Must(genMiddleware(dir, cfg, api))
if err := backupAndSweep(apiFile); err != nil { if err := backupAndSweep(apiFile); err != nil {
return err return err

View File

@@ -534,6 +534,7 @@ func TestHasImportApi(t *testing.T) {
} }
} }
assert.True(t, hasInline) assert.True(t, hasInline)
validate(t, filename) validate(t, filename)
} }
@@ -558,15 +559,30 @@ func TestNestTypeApi(t *testing.T) {
err := ioutil.WriteFile(filename, []byte(nestTypeApi), os.ModePerm) err := ioutil.WriteFile(filename, []byte(nestTypeApi), os.ModePerm)
assert.Nil(t, err) assert.Nil(t, err)
defer os.Remove(filename) defer os.Remove(filename)
_, err = parser.NewParser(filename) _, err = parser.NewParser(filename)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestCamelStyle(t *testing.T) {
filename := "greet.api"
err := ioutil.WriteFile(filename, []byte(testApiTemplate), os.ModePerm)
assert.Nil(t, err)
defer os.Remove(filename)
_, err = parser.NewParser(filename)
assert.Nil(t, err)
validateWithCamel(t, filename, "GoZero")
}
func validate(t *testing.T, api string) { func validate(t *testing.T, api string) {
validateWithCamel(t, api, "gozero")
}
func validateWithCamel(t *testing.T, api, camel string) {
dir := "_go" dir := "_go"
os.RemoveAll(dir) os.RemoveAll(dir)
err := DoGenProject(api, dir) err := DoGenProject(api, dir, camel)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
assert.Nil(t, err) assert.Nil(t, err)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {

View File

@@ -8,12 +8,14 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
const ( const (
configFile = "config.go" configFile = "config"
configTemplate = `package config configTemplate = `package config
import {{.authImport}} import {{.authImport}}
@@ -31,8 +33,13 @@ type Config struct {
` `
) )
func genConfig(dir string, api *spec.ApiSpec) error { func genConfig(dir string, cfg *config.Config, api *spec.ApiSpec) error {
fp, created, err := util.MaybeCreateFile(dir, configDir, configFile) filename, err := format.FileNamingFormat(cfg.NamingFormat, configFile)
if err != nil {
return err
}
fp, created, err := util.MaybeCreateFile(dir, configDir, filename+".go")
if err != nil { if err != nil {
return err return err
} }

View File

@@ -8,7 +8,9 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
const ( const (
@@ -20,8 +22,13 @@ Port: {{.port}}
` `
) )
func genEtc(dir string, api *spec.ApiSpec) error { func genEtc(dir string, cfg *config.Config, api *spec.ApiSpec) error {
fp, created, err := util.MaybeCreateFile(dir, etcDir, fmt.Sprintf("%s.yaml", api.Service.Name)) filename, err := format.FileNamingFormat(cfg.NamingFormat, api.Service.Name)
if err != nil {
return err
}
fp, created, err := util.MaybeCreateFile(dir, etcDir, fmt.Sprintf("%s.yaml", filename))
if err != nil { if err != nil {
return err return err
} }

View File

@@ -2,14 +2,18 @@ package gogen
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"path" "path"
"strings" "strings"
"text/template" "text/template"
"unicode"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -50,13 +54,8 @@ type Handler struct {
HasRequest bool HasRequest bool
} }
func genHandler(dir string, group spec.Group, route spec.Route) error { func genHandler(dir string, cfg *config.Config, group spec.Group, route spec.Route) error {
handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler") handler := getHandlerName(route)
if !ok {
return fmt.Errorf("missing handler annotation for %q", route.Path)
}
handler = getHandlerName(handler)
if getHandlerFolderPath(group, route) != handlerDir { if getHandlerFolderPath(group, route) != handlerDir {
handler = strings.Title(handler) handler = strings.Title(handler)
} }
@@ -65,27 +64,24 @@ func genHandler(dir string, group spec.Group, route spec.Route) error {
return err return err
} }
return doGenToFile(dir, handler, group, route, Handler{ return doGenToFile(dir, handler, cfg, group, route, Handler{
ImportPackages: genHandlerImports(group, route, parentPkg), ImportPackages: genHandlerImports(group, route, parentPkg),
HandlerName: handler, HandlerName: handler,
RequestType: util.Title(route.RequestType.Name), RequestType: util.Title(route.RequestType.Name),
LogicType: strings.TrimSuffix(strings.Title(handler), "Handler") + "Logic", LogicType: strings.Title(getLogicName(route)),
Call: strings.Title(strings.TrimSuffix(handler, "Handler")), Call: strings.Title(strings.TrimSuffix(handler, "Handler")),
HasResp: len(route.ResponseType.Name) > 0, HasResp: len(route.ResponseType.Name) > 0,
HasRequest: len(route.RequestType.Name) > 0, HasRequest: len(route.RequestType.Name) > 0,
}) })
} }
func doGenToFile(dir, handler string, group spec.Group, route spec.Route, handleObj Handler) error { func doGenToFile(dir, handler string, cfg *config.Config, group spec.Group, route spec.Route, handleObj Handler) error {
if getHandlerFolderPath(group, route) != handlerDir { filename, err := format.FileNamingFormat(cfg.NamingFormat, handler)
handler = strings.Title(handler) if err != nil {
} return err
filename := strings.ToLower(handler)
if strings.HasSuffix(filename, "handler") {
filename = filename + ".go"
} else {
filename = filename + "handler.go"
} }
filename = filename + ".go"
fp, created, err := apiutil.MaybeCreateFile(dir, getHandlerFolderPath(group, route), filename) fp, created, err := apiutil.MaybeCreateFile(dir, getHandlerFolderPath(group, route), filename)
if err != nil { if err != nil {
return err return err
@@ -111,10 +107,10 @@ func doGenToFile(dir, handler string, group spec.Group, route spec.Route, handle
return err return err
} }
func genHandlers(dir string, api *spec.ApiSpec) error { func genHandlers(dir string, cfg *config.Config, api *spec.ApiSpec) error {
for _, group := range api.Service.Groups { for _, group := range api.Service.Groups {
for _, route := range group.Routes { for _, route := range group.Routes {
if err := genHandler(dir, group, route); err != nil { if err := genHandler(dir, cfg, group, route); err != nil {
return err return err
} }
} }
@@ -136,14 +132,23 @@ func genHandlerImports(group spec.Group, route spec.Route, parentPkg string) str
return strings.Join(imports, "\n\t") return strings.Join(imports, "\n\t")
} }
func getHandlerBaseName(handler string) string { func getHandlerBaseName(route spec.Route) (string, error) {
handlerName := util.Untitle(handler) handler, ok := apiutil.GetAnnotationValue(route.Annotations, "server", "handler")
if strings.HasSuffix(handlerName, "handler") { if !ok {
handlerName = strings.ReplaceAll(handlerName, "handler", "") return "", fmt.Errorf("missing handler annotation for %q", route.Path)
} else if strings.HasSuffix(handlerName, "Handler") {
handlerName = strings.ReplaceAll(handlerName, "Handler", "")
} }
return handlerName
for _, char := range handler {
if !unicode.IsDigit(char) && !unicode.IsLetter(char) {
return "", errors.New(fmt.Sprintf("route [%s] handler [%s] invalid, handler name should only contains letter or digit",
route.Path, handler))
}
}
handler = strings.TrimSpace(handler)
handler = strings.TrimSuffix(handler, "handler")
handler = strings.TrimSuffix(handler, "Handler")
return handler, nil
} }
func getHandlerFolderPath(group spec.Group, route spec.Route) string { func getHandlerFolderPath(group spec.Group, route spec.Route) string {
@@ -159,6 +164,20 @@ func getHandlerFolderPath(group spec.Group, route spec.Route) string {
return path.Join(handlerDir, folder) return path.Join(handlerDir, folder)
} }
func getHandlerName(handler string) string { func getHandlerName(route spec.Route) string {
return getHandlerBaseName(handler) + "Handler" handler, err := getHandlerBaseName(route)
if err != nil {
panic(err)
}
return handler + "Handler"
}
func getLogicName(route spec.Route) string {
handler, err := getHandlerBaseName(route)
if err != nil {
panic(err)
}
return handler + "Logic"
} }

View File

@@ -9,7 +9,9 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -40,10 +42,10 @@ func (l *{{.logic}}) {{.function}}({{.request}}) {{.responseType}} {
} }
` `
func genLogic(dir string, api *spec.ApiSpec) error { func genLogic(dir string, cfg *config.Config, api *spec.ApiSpec) error {
for _, g := range api.Service.Groups { for _, g := range api.Service.Groups {
for _, r := range g.Routes { for _, r := range g.Routes {
err := genLogicByRoute(dir, g, r) err := genLogicByRoute(dir, cfg, g, r)
if err != nil { if err != nil {
return err return err
} }
@@ -52,16 +54,14 @@ func genLogic(dir string, api *spec.ApiSpec) error {
return nil return nil
} }
func genLogicByRoute(dir string, group spec.Group, route spec.Route) error { func genLogicByRoute(dir string, cfg *config.Config, group spec.Group, route spec.Route) error {
handler, ok := util.GetAnnotationValue(route.Annotations, "server", "handler") logic := getLogicName(route)
if !ok { goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
return fmt.Errorf("missing handler annotation for %q", route.Path) if err != nil {
return err
} }
handler = strings.TrimSuffix(handler, "handler") goFile = goFile + ".go"
handler = strings.TrimSuffix(handler, "Handler")
filename := strings.ToLower(handler)
goFile := filename + "logic.go"
fp, created, err := util.MaybeCreateFile(dir, getLogicFolderPath(group, route), goFile) fp, created, err := util.MaybeCreateFile(dir, getLogicFolderPath(group, route), goFile)
if err != nil { if err != nil {
return err return err
@@ -102,8 +102,8 @@ func genLogicByRoute(dir string, group spec.Group, route spec.Route) error {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
err = t.Execute(fp, map[string]string{ err = t.Execute(fp, map[string]string{
"imports": imports, "imports": imports,
"logic": strings.Title(handler) + "Logic", "logic": strings.Title(logic),
"function": strings.Title(strings.TrimSuffix(handler, "Handler")), "function": strings.Title(strings.TrimSuffix(logic, "Logic")),
"responseType": responseString, "responseType": responseString,
"returnString": returnString, "returnString": returnString,
"request": requestString, "request": requestString,

View File

@@ -8,7 +8,9 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
@@ -40,12 +42,17 @@ func main() {
} }
` `
func genMain(dir string, api *spec.ApiSpec) error { func genMain(dir string, cfg *config.Config, api *spec.ApiSpec) error {
name := strings.ToLower(api.Service.Name) name := strings.ToLower(api.Service.Name)
if strings.HasSuffix(name, "-api") { if strings.HasSuffix(name, "-api") {
name = strings.ReplaceAll(name, "-api", "") name = strings.ReplaceAll(name, "-api", "")
} }
goFile := name + ".go" filename, err := format.FileNamingFormat(cfg.NamingFormat, name)
if err != nil {
return err
}
goFile := filename + ".go"
fp, created, err := util.MaybeCreateFile(dir, "", goFile) fp, created, err := util.MaybeCreateFile(dir, "", goFile)
if err != nil { if err != nil {
return err return err

View File

@@ -5,7 +5,10 @@ import (
"strings" "strings"
"text/template" "text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
var middlewareImplementCode = ` var middlewareImplementCode = `
@@ -30,9 +33,16 @@ func (m *{{.name}})Handle(next http.HandlerFunc) http.HandlerFunc {
} }
` `
func genMiddleware(dir string, middlewares []string) error { func genMiddleware(dir string, cfg *config.Config, api *spec.ApiSpec) error {
var middlewares = getMiddleware(api)
for _, item := range middlewares { for _, item := range middlewares {
filename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "middleware" + ".go" middlewareFilename := strings.TrimSuffix(strings.ToLower(item), "middleware") + "_middleware"
formatName, err := format.FileNamingFormat(cfg.NamingFormat, middlewareFilename)
if err != nil {
return err
}
filename := formatName + ".go"
fp, created, err := util.MaybeCreateFile(dir, middlewareDir, filename) fp, created, err := util.MaybeCreateFile(dir, middlewareDir, filename)
if err != nil { if err != nil {
return err return err

View File

@@ -12,12 +12,14 @@ import (
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
const ( const (
routesFilename = "routes.go" routesFilename = "routes"
routesTemplate = `// Code generated by goctl. DO NOT EDIT. routesTemplate = `// Code generated by goctl. DO NOT EDIT.
package handler package handler
@@ -62,7 +64,7 @@ type (
} }
) )
func genRoutes(dir string, api *spec.ApiSpec) error { func genRoutes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
var builder strings.Builder var builder strings.Builder
groups, err := getRoutes(api) groups, err := getRoutes(api)
if err != nil { if err != nil {
@@ -121,10 +123,16 @@ func genRoutes(dir string, api *spec.ApiSpec) error {
return err return err
} }
filename := path.Join(dir, handlerDir, routesFilename) routeFilename, err := format.FileNamingFormat(cfg.NamingFormat, routesFilename)
if err != nil {
return err
}
routeFilename = routeFilename + ".go"
filename := path.Join(dir, handlerDir, routeFilename)
os.Remove(filename) os.Remove(filename)
fp, created, err := apiutil.MaybeCreateFile(dir, handlerDir, routesFilename) fp, created, err := apiutil.MaybeCreateFile(dir, handlerDir, routeFilename)
if err != nil { if err != nil {
return err return err
} }
@@ -176,11 +184,8 @@ func getRoutes(api *spec.ApiSpec) ([]group, error) {
for _, g := range api.Service.Groups { for _, g := range api.Service.Groups {
var groupedRoutes group var groupedRoutes group
for _, r := range g.Routes { for _, r := range g.Routes {
handler, ok := apiutil.GetAnnotationValue(r.Annotations, "server", "handler") handler := getHandlerName(r)
if !ok { handler = handler + "(serverCtx)"
return nil, fmt.Errorf("missing handler annotation for route %q", r.Path)
}
handler = getHandlerBaseName(handler) + "Handler(serverCtx)"
folder, ok := apiutil.GetAnnotationValue(r.Annotations, "server", groupProperty) folder, ok := apiutil.GetAnnotationValue(r.Annotations, "server", groupProperty)
if ok { if ok {
handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:] handler = toPrefix(folder) + "." + strings.ToUpper(handler[:1]) + handler[1:]

View File

@@ -8,12 +8,14 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
"github.com/tal-tech/go-zero/tools/goctl/api/util" "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
ctlutil "github.com/tal-tech/go-zero/tools/goctl/util" ctlutil "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/vars" "github.com/tal-tech/go-zero/tools/goctl/vars"
) )
const ( const (
contextFilename = "servicecontext.go" contextFilename = "service_context"
contextTemplate = `package svc contextTemplate = `package svc
import ( import (
@@ -35,8 +37,13 @@ func NewServiceContext(c {{.config}}) *ServiceContext {
` `
) )
func genServiceContext(dir string, api *spec.ApiSpec) error { func genServiceContext(dir string, cfg *config.Config, api *spec.ApiSpec) error {
fp, created, err := util.MaybeCreateFile(dir, contextDir, contextFilename) filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
if err != nil {
return err
}
fp, created, err := util.MaybeCreateFile(dir, contextDir, filename+".go")
if err != nil { if err != nil {
return err return err
} }
@@ -64,10 +71,6 @@ func genServiceContext(dir string, api *spec.ApiSpec) error {
var middlewareStr string var middlewareStr string
var middlewareAssignment string var middlewareAssignment string
var middlewares = getMiddleware(api) var middlewares = getMiddleware(api)
err = genMiddleware(dir, middlewares)
if err != nil {
return err
}
for _, item := range middlewares { for _, item := range middlewares {
middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item) middlewareStr += fmt.Sprintf("%s rest.Middleware\n", item)

View File

@@ -12,11 +12,13 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/api/spec" "github.com/tal-tech/go-zero/tools/goctl/api/spec"
apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util" apiutil "github.com/tal-tech/go-zero/tools/goctl/api/util"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
const ( const (
typesFile = "types.go" typesFile = "types"
typesTemplate = `// Code generated by goctl. DO NOT EDIT. typesTemplate = `// Code generated by goctl. DO NOT EDIT.
package types{{if .containsTime}} package types{{if .containsTime}}
import ( import (
@@ -43,19 +45,25 @@ func BuildTypes(types []spec.Type) (string, error) {
return builder.String(), nil return builder.String(), nil
} }
func genTypes(dir string, api *spec.ApiSpec) error { func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
val, err := BuildTypes(api.Types) val, err := BuildTypes(api.Types)
if err != nil { if err != nil {
return err return err
} }
filename := path.Join(dir, typesDir, typesFile) typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
os.Remove(filename)
fp, created, err := apiutil.MaybeCreateFile(dir, typesDir, typesFile)
if err != nil { if err != nil {
return err return err
} }
typeFilename = typeFilename + ".go"
filename := path.Join(dir, typesDir, typeFilename)
os.Remove(filename)
fp, created, err := apiutil.MaybeCreateFile(dir, typesDir, typeFilename)
if err != nil {
return err
}
if !created { if !created {
return nil return nil
} }

View File

@@ -7,6 +7,7 @@ import (
"text/template" "text/template"
"github.com/tal-tech/go-zero/tools/goctl/api/gogen" "github.com/tal-tech/go-zero/tools/goctl/api/gogen"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/urfave/cli" "github.com/urfave/cli"
) )
@@ -60,6 +61,6 @@ func NewService(c *cli.Context) error {
return err return err
} }
err = gogen.DoGenProject(apiFilePath, abs) err = gogen.DoGenProject(apiFilePath, abs, conf.DefaultFormat)
return err return err
} }

View File

@@ -0,0 +1,122 @@
package config
import (
"errors"
"io/ioutil"
"os"
"path/filepath"
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util"
"gopkg.in/yaml.v2"
)
const (
configFile = "config.yaml"
configFolder = "config"
DefaultFormat = "gozero"
)
const defaultYaml = `# namingFormat is used to define the naming format of the generated file name.
# just like time formatting, you can specify the formatting style through the
# two format characters go, and zero. for example: snake format you can
# define as go_zero, camel case format you can it is defined as goZero,
# and even split characters can be specified, such as go#zero. in theory,
# any combination can be used, but the prerequisite must meet the naming conventions
# of each operating system file name. if you want to independently control the file
# naming style of the api, rpc, and model layers, you can set it through apiNamingFormat,
# rpcNamingFormat, modelNamingFormat, and independent control is not enabled by default.
# for more information, please see #{apiNamingFormat},#{rpcNamingFormat},#{modelNamingFormat}
# Note: namingFormat is based on snake or camel string
namingFormat: gozero
`
type Config struct {
// NamingFormat is used to define the naming format of the generated file name.
// just like time formatting, you can specify the formatting style through the
// two format characters go, and zero. for example: snake format you can
// define as go_zero, camel case format you can it is defined as goZero,
// and even split characters can be specified, such as go#zero. in theory,
// any combination can be used, but the prerequisite must meet the naming conventions
// of each operating system file name.
// Note: NamingFormat is based on snake or camel string
NamingFormat string `yaml:"namingFormat"`
}
func NewConfig(format string) (*Config, error) {
if len(format) == 0 {
format = DefaultFormat
}
cfg := &Config{NamingFormat: format}
err := validate(cfg)
return cfg, err
}
func InitOrGetConfig() (*Config, error) {
var (
defaultConfig Config
)
err := yaml.Unmarshal([]byte(defaultYaml), &defaultConfig)
if err != nil {
return nil, err
}
goctlHome, err := util.GetGoctlHome()
if err != nil {
return nil, err
}
configDir := filepath.Join(goctlHome, configFolder)
configFilename := filepath.Join(configDir, configFile)
if util.FileExists(configFilename) {
data, err := ioutil.ReadFile(configFilename)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(data, &defaultConfig)
if err != nil {
return nil, err
}
err = validate(&defaultConfig)
if err != nil {
return nil, err
}
return &defaultConfig, nil
}
err = util.MkdirIfNotExist(configDir)
if err != nil {
return nil, err
}
f, err := os.Create(configFilename)
if err != nil {
return nil, err
}
defer func() {
_ = f.Close()
}()
_, err = f.WriteString(defaultYaml)
if err != nil {
return nil, err
}
err = validate(&defaultConfig)
if err != nil {
return nil, err
}
return &defaultConfig, nil
}
func validate(cfg *Config) error {
if len(strings.TrimSpace(cfg.NamingFormat)) == 0 {
return errors.New("missing namingFormat")
}
return nil
}

View File

@@ -0,0 +1,50 @@
# 配置项管理
| 名称 | 是否可选 | 说明 |
|-------------------|----------|-----------------------------------------------|
| namingFormat | YES | 文件名称格式化符 |
# naming-format
`namingFormat`可以用于对生成代码的文件名称进行格式化和日期格式化符yyyy-MM-dd类似在代码生成时可以根据这些配置项的格式化符进行格式化。
## 格式化符(gozero)
格式化符有`go`,`zero`组成,如常见的三种格式化风格你可以这样编写:
* lower: `gozero`
* camel: `goZero`
* snake: `go_zero`
常见格式化符生成示例
源字符welcome_to_go_zero
| 格式化符 | 格式化结果 | 说明 |
|------------|-----------------------|---------------------------|
| gozero | welcometogozero | 小写 |
| goZero | welcomeToGoZero | 驼峰 |
| go_zero | welcome_to_go_zero | snake |
| Go#zero | Welcome#to#go#zero | #号分割Title类型 |
| GOZERO | WELCOMETOGOZERO | 大写 |
| \_go#zero_ | \_welcome#to#go#zero_ | 下划线做前后缀,并且#分割 |
错误格式化符示例
* go
* gOZero
* zero
* goZEro
* goZERo
* goZeRo
* tal
# 使用方法
目前可通过在生成api、rpc、model时通过`--style`参数指定format格式
```shell script
goctl api go test.api -dir . -style gozero
```
```shell script
goctl rpc proto -src test.proto -dir . -style go_zero
```
```shell script
goctl model mysql datasource -url="" -table="*" -dir ./snake -style GoZero
```
# 默认值
当不指定-style时默认值为`gozero`

View File

@@ -25,7 +25,7 @@ import (
) )
var ( var (
BuildVersion = "20201108" BuildVersion = "20201119-beta"
commands = []cli.Command{ commands = []cli.Command{
{ {
Name: "api", Name: "api",
@@ -98,6 +98,11 @@ var (
Name: "api", Name: "api",
Usage: "the api file", Usage: "the api file",
}, },
cli.StringFlag{
Name: "style",
Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
},
}, },
Action: gogen.GoCommand, Action: gogen.GoCommand,
}, },
@@ -203,7 +208,7 @@ var (
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Usage: "the file naming style, lower|camel|snake,default is lower", Usage: "the file naming style, lower|camel|snake,default is lower, [deprecated,use config.yaml instead]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
@@ -240,8 +245,9 @@ var (
Usage: `the target path of the code`, Usage: `the target path of the code`,
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Usage: "the file naming style, lower|camel|snake,default is lower", Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",
@@ -273,8 +279,9 @@ var (
Usage: "the target dir", Usage: "the target dir",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Usage: "the file naming style, lower|camel|snake,default is lower", Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "cache, c", Name: "cache, c",
@@ -308,8 +315,9 @@ var (
Usage: "the target dir", Usage: "the target dir",
}, },
cli.StringFlag{ cli.StringFlag{
Name: "style", Name: "style",
Usage: "the file naming style, lower|camel|snake, default is lower", Required: false,
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
}, },
cli.BoolFlag{ cli.BoolFlag{
Name: "idea", Name: "idea",

View File

@@ -2,7 +2,6 @@ package command
import ( import (
"errors" "errors"
"fmt"
"io/ioutil" "io/ioutil"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -10,6 +9,7 @@ import (
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stores/sqlx" "github.com/tal-tech/go-zero/core/stores/sqlx"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util" "github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
@@ -24,9 +24,9 @@ const (
flagDir = "dir" flagDir = "dir"
flagCache = "cache" flagCache = "cache"
flagIdea = "idea" flagIdea = "idea"
flagStyle = "style"
flagUrl = "url" flagUrl = "url"
flagTable = "table" flagTable = "table"
flagStyle = "style"
) )
func MysqlDDL(ctx *cli.Context) error { func MysqlDDL(ctx *cli.Context) error {
@@ -34,8 +34,13 @@ func MysqlDDL(ctx *cli.Context) error {
dir := ctx.String(flagDir) dir := ctx.String(flagDir)
cache := ctx.Bool(flagCache) cache := ctx.Bool(flagCache)
idea := ctx.Bool(flagIdea) idea := ctx.Bool(flagIdea)
namingStyle := strings.TrimSpace(ctx.String(flagStyle)) style := ctx.String(flagStyle)
return fromDDl(src, dir, namingStyle, cache, idea) cfg, err := config.NewConfig(style)
if err != nil {
return err
}
return fromDDl(src, dir, cfg, cache, idea)
} }
func MyDataSource(ctx *cli.Context) error { func MyDataSource(ctx *cli.Context) error {
@@ -43,26 +48,23 @@ func MyDataSource(ctx *cli.Context) error {
dir := strings.TrimSpace(ctx.String(flagDir)) dir := strings.TrimSpace(ctx.String(flagDir))
cache := ctx.Bool(flagCache) cache := ctx.Bool(flagCache)
idea := ctx.Bool(flagIdea) idea := ctx.Bool(flagIdea)
namingStyle := strings.TrimSpace(ctx.String(flagStyle)) style := ctx.String(flagStyle)
pattern := strings.TrimSpace(ctx.String(flagTable)) pattern := strings.TrimSpace(ctx.String(flagTable))
return fromDataSource(url, pattern, dir, namingStyle, cache, idea) cfg, err := config.NewConfig(style)
if err != nil {
return err
}
return fromDataSource(url, pattern, dir, cfg, cache, idea)
} }
func fromDDl(src, dir, namingStyle string, cache, idea bool) error { func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error {
log := console.NewConsole(idea) log := console.NewConsole(idea)
src = strings.TrimSpace(src) src = strings.TrimSpace(src)
if len(src) == 0 { if len(src) == 0 {
return errors.New("expected path or path globbing patterns, but nothing found") return errors.New("expected path or path globbing patterns, but nothing found")
} }
switch namingStyle {
case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "":
namingStyle = gen.NamingLower
default:
return fmt.Errorf("unexpected naming style: %s", namingStyle)
}
files, err := util.MatchFiles(src) files, err := util.MatchFiles(src)
if err != nil { if err != nil {
return err return err
@@ -81,7 +83,7 @@ func fromDDl(src, dir, namingStyle string, cache, idea bool) error {
source = append(source, string(data)) source = append(source, string(data))
} }
generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log)) generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
if err != nil { if err != nil {
return err return err
} }
@@ -90,7 +92,7 @@ func fromDDl(src, dir, namingStyle string, cache, idea bool) error {
return err return err
} }
func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) error { func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error {
log := console.NewConsole(idea) log := console.NewConsole(idea)
if len(url) == 0 { if len(url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found") log.Error("%v", "expected data source of mysql, but nothing found")
@@ -102,25 +104,17 @@ func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) err
return nil return nil
} }
switch namingStyle { dsn, err := mysql.ParseDSN(url)
case gen.NamingLower, gen.NamingCamel, gen.NamingSnake:
case "":
namingStyle = gen.NamingLower
default:
return fmt.Errorf("unexpected naming style: %s", namingStyle)
}
cfg, err := mysql.ParseDSN(url)
if err != nil { if err != nil {
return err return err
} }
logx.Disable() logx.Disable()
databaseSource := strings.TrimSuffix(url, "/"+cfg.DBName) + "/information_schema" databaseSource := strings.TrimSuffix(url, "/"+dsn.DBName) + "/information_schema"
db := sqlx.NewMysql(databaseSource) db := sqlx.NewMysql(databaseSource)
im := model.NewInformationSchemaModel(db) im := model.NewInformationSchemaModel(db)
tables, err := im.GetAllTables(cfg.DBName) tables, err := im.GetAllTables(dsn.DBName)
if err != nil { if err != nil {
return err return err
} }
@@ -135,7 +129,7 @@ func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) err
if !match { if !match {
continue continue
} }
columns, err := im.FindByTableName(cfg.DBName, item) columns, err := im.FindByTableName(dsn.DBName, item)
if err != nil { if err != nil {
return err return err
} }
@@ -146,11 +140,11 @@ func fromDataSource(url, pattern, dir, namingStyle string, cache, idea bool) err
return errors.New("no tables matched") return errors.New("no tables matched")
} }
generator, err := gen.NewDefaultGenerator(dir, namingStyle, gen.WithConsoleOption(log)) generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
if err != nil { if err != nil {
return err return err
} }
err = generator.StartFromInformationSchema(cfg.DBName, matchTables, cache) err = generator.StartFromInformationSchema(dsn.DBName, matchTables, cache)
return err return err
} }

View File

@@ -7,19 +7,22 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen" "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
) )
var sql = "-- 用户表 --\nCREATE TABLE `user` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n" var sql = "-- 用户表 --\nCREATE TABLE `user` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n UNIQUE KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;\n\n"
var cfg = &config.Config{
NamingFormat: "gozero",
}
func TestFromDDl(t *testing.T) { func TestFromDDl(t *testing.T) {
err := fromDDl("./user.sql", t.TempDir(), gen.NamingCamel, true, false) err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
assert.Equal(t, errNotMatched, err) assert.Equal(t, errNotMatched, err)
// case dir is not exists // case dir is not exists
unknownDir := filepath.Join(t.TempDir(), "test", "user.sql") unknownDir := filepath.Join(t.TempDir(), "test", "user.sql")
err = fromDDl(unknownDir, t.TempDir(), gen.NamingCamel, true, false) err = fromDDl(unknownDir, t.TempDir(), cfg, true, false)
assert.True(t, func() bool { assert.True(t, func() bool {
switch err.(type) { switch err.(type) {
case *os.PathError: case *os.PathError:
@@ -30,18 +33,11 @@ func TestFromDDl(t *testing.T) {
}()) }())
// case empty src // case empty src
err = fromDDl("", t.TempDir(), gen.NamingCamel, true, false) err = fromDDl("", t.TempDir(), cfg, true, false)
if err != nil { if err != nil {
assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
} }
// case unknown naming style
tmp := filepath.Join(t.TempDir(), "user.sql")
err = fromDDl(tmp, t.TempDir(), "lower1", true, false)
if err != nil {
assert.Equal(t, "unexpected naming style: lower1", err.Error())
}
tempDir := filepath.Join(t.TempDir(), "test") tempDir := filepath.Join(t.TempDir(), "test")
err = util.MkdirIfNotExist(tempDir) err = util.MkdirIfNotExist(tempDir)
if err != nil { if err != nil {
@@ -67,7 +63,7 @@ func TestFromDDl(t *testing.T) {
_, err = os.Stat(user2Sql) _, err = os.Stat(user2Sql)
assert.Nil(t, err) assert.Nil(t, err)
err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, gen.NamingLower, true, false) err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false)
assert.Nil(t, err) assert.Nil(t, err)
_, err = os.Stat(filepath.Join(tempDir, "usermodel.go")) _, err = os.Stat(filepath.Join(tempDir, "usermodel.go"))

View File

@@ -7,11 +7,13 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model" "github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser" "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template" "github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -28,13 +30,13 @@ type (
//source string //source string
dir string dir string
console.Console console.Console
pkg string pkg string
namingStyle string cfg *config.Config
} }
Option func(generator *defaultGenerator) Option func(generator *defaultGenerator)
) )
func NewDefaultGenerator(dir, namingStyle string, opt ...Option) (*defaultGenerator, error) { func NewDefaultGenerator(dir string, cfg *config.Config, opt ...Option) (*defaultGenerator, error) {
if dir == "" { if dir == "" {
dir = pwd dir = pwd
} }
@@ -50,7 +52,7 @@ func NewDefaultGenerator(dir, namingStyle string, opt ...Option) (*defaultGenera
return nil, err return nil, err
} }
generator := &defaultGenerator{dir: dir, namingStyle: namingStyle, pkg: pkg} generator := &defaultGenerator{dir: dir, cfg: cfg, pkg: pkg}
var optionList []Option var optionList []Option
optionList = append(optionList, newDefaultOption()) optionList = append(optionList, newDefaultOption())
optionList = append(optionList, opt...) optionList = append(optionList, opt...)
@@ -114,13 +116,12 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
for tableName, code := range modelList { for tableName, code := range modelList {
tn := stringx.From(tableName) tn := stringx.From(tableName)
name := fmt.Sprintf("%smodel.go", strings.ToLower(tn.ToCamel())) modelFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, fmt.Sprintf("%s_model", tn.Source()))
switch g.namingStyle { if err != nil {
case NamingCamel: return err
name = fmt.Sprintf("%sModel.go", tn.ToCamel())
case NamingSnake:
name = fmt.Sprintf("%s_model.go", tn.ToSnake())
} }
name := modelFilename + ".go"
filename := filepath.Join(dirAbs, name) filename := filepath.Join(dirAbs, name)
if util.FileExists(filename) { if util.FileExists(filename) {
g.Warning("%s already exists, ignored.", name) g.Warning("%s already exists, ignored.", name)
@@ -132,10 +133,12 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error {
} }
} }
// generate error file // generate error file
filename := filepath.Join(dirAbs, "vars.go") varFilename, err := format.FileNamingFormat(g.cfg.NamingFormat, "vars")
if g.namingStyle == NamingCamel { if err != nil {
filename = filepath.Join(dirAbs, "Vars.go") return err
} }
filename := filepath.Join(dirAbs, varFilename+".go")
text, err := util.LoadTemplate(category, errTemplateFile, template.Error) text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
if err != nil { if err != nil {
return err return err

View File

@@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/tools/goctl/config"
) )
var ( var (
@@ -22,7 +23,9 @@ func TestCacheModel(t *testing.T) {
defer func() { defer func() {
_ = os.RemoveAll(dir) _ = os.RemoveAll(dir)
}() }()
g, err := NewDefaultGenerator(cacheDir, NamingCamel) g, err := NewDefaultGenerator(cacheDir, &config.Config{
NamingFormat: "GoZero",
})
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(source, true) err = g.StartFromDDL(source, true)
@@ -31,7 +34,9 @@ func TestCacheModel(t *testing.T) {
_, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go")) _, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go"))
return err == nil return err == nil
}()) }())
g, err = NewDefaultGenerator(noCacheDir, NamingLower) g, err = NewDefaultGenerator(noCacheDir, &config.Config{
NamingFormat: "gozero",
})
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(source, false) err = g.StartFromDDL(source, false)
@@ -51,7 +56,9 @@ func TestNamingModel(t *testing.T) {
defer func() { defer func() {
_ = os.RemoveAll(dir) _ = os.RemoveAll(dir)
}() }()
g, err := NewDefaultGenerator(camelDir, NamingCamel) g, err := NewDefaultGenerator(camelDir, &config.Config{
NamingFormat: "GoZero",
})
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(source, true) err = g.StartFromDDL(source, true)
@@ -60,7 +67,9 @@ func TestNamingModel(t *testing.T) {
_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go")) _, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
return err == nil return err == nil
}()) }())
g, err = NewDefaultGenerator(snakeDir, NamingSnake) g, err = NewDefaultGenerator(snakeDir, &config.Config{
NamingFormat: "go_zero",
})
assert.Nil(t, err) assert.Nil(t, err)
err = g.StartFromDDL(source, true) err = g.StartFromDDL(source, true)

View File

@@ -24,32 +24,26 @@ func Rpc(c *cli.Context) error {
return errors.New("missing -dir") return errors.New("missing -dir")
} }
namingStyle, valid := generator.IsNamingValid(style) g, err := generator.NewDefaultRpcGenerator(style)
if !valid { if err != nil {
return fmt.Errorf("unexpected naming style %s", style) return err
} }
g := generator.NewDefaultRpcGenerator(namingStyle)
return g.Generate(src, out, protoImportPath) return g.Generate(src, out, protoImportPath)
} }
// RpcNew is to generate rpc greet service, this greet service can speed // RpcNew is to generate rpc greet service, this greet service can speed
// up your understanding of the zrpc service structure // up your understanding of the zrpc service structure
func RpcNew(c *cli.Context) error { func RpcNew(c *cli.Context) error {
name := c.Args().First() rpcname := c.Args().First()
ext := filepath.Ext(name) ext := filepath.Ext(rpcname)
if len(ext) > 0 { if len(ext) > 0 {
return fmt.Errorf("unexpected ext: %s", ext) return fmt.Errorf("unexpected ext: %s", ext)
} }
style := c.String("style") style := c.String("style")
namingStyle, valid := generator.IsNamingValid(style)
if !valid {
return fmt.Errorf("expected naming style [lower|camel|snake], but found %s", style)
}
protoName := name + ".proto" protoName := rpcname + ".proto"
filename := filepath.Join(".", name, protoName) filename := filepath.Join(".", rpcname, protoName)
src, err := filepath.Abs(filename) src, err := filepath.Abs(filename)
if err != nil { if err != nil {
return err return err
@@ -60,7 +54,11 @@ func RpcNew(c *cli.Context) error {
return err return err
} }
g := generator.NewDefaultRpcGenerator(namingStyle) g, err := generator.NewDefaultRpcGenerator(style)
if err != nil {
return err
}
return g.Generate(src, filepath.Dir(src), nil) return g.Generate(src, filepath.Dir(src), nil)
} }

View File

@@ -1,18 +0,0 @@
package generator
import (
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
func formatFilename(filename string, style NamingStyle) string {
switch style {
case namingCamel:
return stringx.From(filename).ToCamel()
case namingSnake:
return stringx.From(filename).ToSnake()
default:
return strings.ToLower(stringx.From(filename).ToCamel())
}
}

View File

@@ -1,17 +0,0 @@
package generator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatFilename(t *testing.T) {
assert.Equal(t, "abc", formatFilename("a_b_c", namingLower))
assert.Equal(t, "ABC", formatFilename("a_b_c", namingCamel))
assert.Equal(t, "a_b_c", formatFilename("a_b_c", namingSnake))
assert.Equal(t, "a", formatFilename("a", namingSnake))
assert.Equal(t, "A", formatFilename("a", namingCamel))
// no flag to convert to snake
assert.Equal(t, "abc", formatFilename("abc", namingSnake))
}

View File

@@ -3,6 +3,7 @@ package generator
import ( import (
"path/filepath" "path/filepath"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console" "github.com/tal-tech/go-zero/tools/goctl/util/console"
@@ -10,18 +11,22 @@ import (
) )
type RpcGenerator struct { type RpcGenerator struct {
g Generator g Generator
style NamingStyle cfg *conf.Config
} }
func NewDefaultRpcGenerator(style NamingStyle) *RpcGenerator { func NewDefaultRpcGenerator(style string) (*RpcGenerator, error) {
return NewRpcGenerator(NewDefaultGenerator(), style) cfg, err := conf.NewConfig(style)
if err != nil {
return nil, err
}
return NewRpcGenerator(NewDefaultGenerator(), cfg), nil
} }
func NewRpcGenerator(g Generator, style NamingStyle) *RpcGenerator { func NewRpcGenerator(g Generator, cfg *conf.Config) *RpcGenerator {
return &RpcGenerator{ return &RpcGenerator{
g: g, g: g,
style: style, cfg: cfg,
} }
} }
@@ -57,42 +62,42 @@ func (g *RpcGenerator) Generate(src, target string, protoImportPath []string) er
return err return err
} }
err = g.g.GenEtc(dirCtx, proto, g.style) err = g.g.GenEtc(dirCtx, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenPb(dirCtx, protoImportPath, proto, g.style) err = g.g.GenPb(dirCtx, protoImportPath, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenConfig(dirCtx, proto, g.style) err = g.g.GenConfig(dirCtx, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenSvc(dirCtx, proto, g.style) err = g.g.GenSvc(dirCtx, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenLogic(dirCtx, proto, g.style) err = g.g.GenLogic(dirCtx, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenServer(dirCtx, proto, g.style) err = g.g.GenServer(dirCtx, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenMain(dirCtx, proto, g.style) err = g.g.GenMain(dirCtx, proto, g.cfg)
if err != nil { if err != nil {
return err return err
} }
err = g.g.GenCall(dirCtx, proto, g.style) err = g.g.GenCall(dirCtx, proto, g.cfg)
console.NewColorConsole().MarkDone() console.NewColorConsole().MarkDone()

View File

@@ -9,9 +9,14 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stringx" "github.com/tal-tech/go-zero/core/stringx"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
) )
var cfg = &conf.Config{
NamingFormat: "gozero",
}
func TestRpcGenerate(t *testing.T) { func TestRpcGenerate(t *testing.T) {
_ = Clean() _ = Clean()
dispatcher := NewDefaultGenerator() dispatcher := NewDefaultGenerator()
@@ -21,7 +26,7 @@ func TestRpcGenerate(t *testing.T) {
return return
} }
projectName := stringx.Rand() projectName := stringx.Rand()
g := NewRpcGenerator(dispatcher, namingLower) g := NewRpcGenerator(dispatcher, cfg)
// case go path // case go path
src := filepath.Join(build.Default.GOPATH, "src") src := filepath.Join(build.Default.GOPATH, "src")

View File

@@ -6,8 +6,10 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -59,12 +61,17 @@ func (m *default{{.serviceName}}) {{.method}}(ctx context.Context,in *{{.pbReque
` `
) )
func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetCall() dir := ctx.GetCall()
service := proto.Service service := proto.Service
head := util.GetHead(proto.Name) head := util.GetHead(proto.Name)
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", formatFilename(service.Name, namingStyle))) callFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name)
if err != nil {
return err
}
filename := filepath.Join(dir.Filename, fmt.Sprintf("%s.go", callFilename))
functions, err := g.genFunction(proto.PbPackage, service) functions, err := g.genFunction(proto.PbPackage, service)
if err != nil { if err != nil {
return err return err
@@ -86,7 +93,7 @@ func (g *defaultGenerator) GenCall(ctx DirContext, proto parser.Proto, namingSty
} }
err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{ err = util.With("shared").GoFmt(true).Parse(text).SaveTo(map[string]interface{}{
"name": formatFilename(service.Name, namingStyle), "name": callFilename,
"alias": strings.Join(alias.KeysStr(), util.NL), "alias": strings.Join(alias.KeysStr(), util.NL),
"head": head, "head": head,
"filePackage": dir.Base, "filePackage": dir.Base,

View File

@@ -5,8 +5,10 @@ import (
"os" "os"
"path/filepath" "path/filepath"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
const configTemplate = `package config const configTemplate = `package config
@@ -18,9 +20,14 @@ type Config struct {
} }
` `
func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenConfig(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
dir := ctx.GetConfig() dir := ctx.GetConfig()
fileName := filepath.Join(dir.Filename, formatFilename("config", namingStyle)+".go") configFilename, err := format.FileNamingFormat(cfg.NamingFormat, "config")
if err != nil {
return err
}
fileName := filepath.Join(dir.Filename, configFilename+".go")
if util.FileExists(fileName) { if util.FileExists(fileName) {
return nil return nil
} }

View File

@@ -1,15 +1,18 @@
package generator package generator
import "github.com/tal-tech/go-zero/tools/goctl/rpc/parser" import (
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
)
type Generator interface { type Generator interface {
Prepare() error Prepare() error
GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenCall(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenCall(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenEtc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenEtc(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenConfig(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenConfig(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenSvc(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error GenSvc(ctx DirContext, proto parser.Proto, cfg *conf.Config) error
GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, cfg *conf.Config) error
} }

View File

@@ -5,8 +5,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -18,10 +20,14 @@ Etcd:
Key: {{.serviceName}}.rpc Key: {{.serviceName}}.rpc
` `
func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenEtc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
dir := ctx.GetEtc() dir := ctx.GetEtc()
serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle) etcFilename, err := format.FileNamingFormat(cfg.NamingFormat, ctx.GetMain().Base)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", serviceNameLower)) if err != nil {
return err
}
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.yaml", etcFilename))
text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate) text, err := util.LoadTemplate(category, etcTemplateFileFile, etcTemplate)
if err != nil { if err != nil {

View File

@@ -6,8 +6,10 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -46,10 +48,15 @@ func (l *{{.logicName}}) {{.method}} (in {{.request}}) ({{.response}}, error) {
` `
) )
func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenLogic(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetLogic() dir := ctx.GetLogic()
for _, rpc := range proto.Service.RPC { for _, rpc := range proto.Service.RPC {
filename := filepath.Join(dir.Filename, formatFilename(rpc.Name+"_logic", namingStyle)+".go") logicFilename, err := format.FileNamingFormat(cfg.NamingFormat, rpc.Name+"_logic")
if err != nil {
return err
}
filename := filepath.Join(dir.Filename, logicFilename+".go")
functions, err := g.genLogicFunction(proto.PbPackage, rpc) functions, err := g.genLogicFunction(proto.PbPackage, rpc)
if err != nil { if err != nil {
return err return err

View File

@@ -5,8 +5,10 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -45,10 +47,14 @@ func main() {
} }
` `
func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenMain(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetMain() dir := ctx.GetMain()
serviceNameLower := formatFilename(ctx.GetMain().Base, namingStyle) mainFilename, err := format.FileNamingFormat(cfg.NamingFormat, ctx.GetMain().Base)
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", serviceNameLower)) if err != nil {
return err
}
fileName := filepath.Join(dir.Filename, fmt.Sprintf("%v.go", mainFilename))
imports := make([]string, 0) imports := make([]string, 0)
pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package) pbImport := fmt.Sprintf(`"%v"`, ctx.GetPb().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package) svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)

View File

@@ -5,11 +5,12 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/execx" "github.com/tal-tech/go-zero/tools/goctl/rpc/execx"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
) )
func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenPb(ctx DirContext, protoImportPath []string, proto parser.Proto, _ *conf.Config) error {
dir := ctx.GetPb() dir := ctx.GetPb()
cw := new(bytes.Buffer) cw := new(bytes.Buffer)
base := filepath.Dir(proto.Src) base := filepath.Dir(proto.Src)

View File

@@ -6,8 +6,10 @@ import (
"strings" "strings"
"github.com/tal-tech/go-zero/core/collection" "github.com/tal-tech/go-zero/core/collection"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx" "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
) )
@@ -43,7 +45,7 @@ func (s *{{.server}}Server) {{.method}} (ctx context.Context, in {{.request}}) (
` `
) )
func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto, cfg *conf.Config) error {
dir := ctx.GetServer() dir := ctx.GetServer()
logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package) logicImport := fmt.Sprintf(`"%v"`, ctx.GetLogic().Package)
svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package) svcImport := fmt.Sprintf(`"%v"`, ctx.GetSvc().Package)
@@ -54,7 +56,12 @@ func (g *defaultGenerator) GenServer(ctx DirContext, proto parser.Proto, namingS
head := util.GetHead(proto.Name) head := util.GetHead(proto.Name)
service := proto.Service service := proto.Service
serverFile := filepath.Join(dir.Filename, formatFilename(service.Name+"_server", namingStyle)+".go") serverFilename, err := format.FileNamingFormat(cfg.NamingFormat, service.Name+"_server")
if err != nil {
return err
}
serverFile := filepath.Join(dir.Filename, serverFilename+".go")
funcList, err := g.genFunctions(proto.PbPackage, service) funcList, err := g.genFunctions(proto.PbPackage, service)
if err != nil { if err != nil {
return err return err

View File

@@ -4,8 +4,10 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
conf "github.com/tal-tech/go-zero/tools/goctl/config"
"github.com/tal-tech/go-zero/tools/goctl/rpc/parser" "github.com/tal-tech/go-zero/tools/goctl/rpc/parser"
"github.com/tal-tech/go-zero/tools/goctl/util" "github.com/tal-tech/go-zero/tools/goctl/util"
"github.com/tal-tech/go-zero/tools/goctl/util/format"
) )
const svcTemplate = `package svc const svcTemplate = `package svc
@@ -23,9 +25,14 @@ func NewServiceContext(c config.Config) *ServiceContext {
} }
` `
func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, namingStyle NamingStyle) error { func (g *defaultGenerator) GenSvc(ctx DirContext, _ parser.Proto, cfg *conf.Config) error {
dir := ctx.GetSvc() dir := ctx.GetSvc()
fileName := filepath.Join(dir.Filename, formatFilename("service_context", namingStyle)+".go") svcFilename, err := format.FileNamingFormat(cfg.NamingFormat, "service_context")
if err != nil {
return err
}
fileName := filepath.Join(dir.Filename, svcFilename+".go")
text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate) text, err := util.LoadTemplate(category, svcTemplateFile, svcTemplate)
if err != nil { if err != nil {
return err return err

View File

@@ -1,24 +0,0 @@
package generator
type NamingStyle = string
const (
namingLower NamingStyle = "lower"
namingCamel NamingStyle = "camel"
namingSnake NamingStyle = "snake"
)
// IsNamingValid validates whether the namingStyle is valid or not,return
// namingStyle and true if it is valid, or else return empty string
// and false, and it is a valid value even namingStyle is empty string
func IsNamingValid(namingStyle string) (NamingStyle, bool) {
if len(namingStyle) == 0 {
namingStyle = namingLower
}
switch namingStyle {
case namingLower, namingCamel, namingSnake:
return namingStyle, true
default:
return "", false
}
}

View File

@@ -1,25 +0,0 @@
package generator
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsNamingValid(t *testing.T) {
style, valid := IsNamingValid("")
assert.True(t, valid)
assert.Equal(t, namingLower, style)
_, valid = IsNamingValid("lower1")
assert.False(t, valid)
_, valid = IsNamingValid("lower")
assert.True(t, valid)
_, valid = IsNamingValid("snake")
assert.True(t, valid)
_, valid = IsNamingValid("camel")
assert.True(t, valid)
}

View File

@@ -8,13 +8,21 @@ import (
const goctlDir = ".goctl" const goctlDir = ".goctl"
func GetTemplateDir(category string) (string, error) { func GetGoctlHome() (string, error) {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
return filepath.Join(home, goctlDir), nil
}
return filepath.Join(home, goctlDir, category), nil func GetTemplateDir(category string) (string, error) {
goctlHome, err := GetGoctlHome()
if err != nil {
return "", err
}
return filepath.Join(goctlHome, category), nil
} }
func InitTemplates(category string, templates map[string]string) error { func InitTemplates(category string, templates map[string]string) error {

View File

@@ -0,0 +1,155 @@
package format
import (
"bytes"
"errors"
"fmt"
"io"
"strings"
)
const (
flagGo = "GO"
flagZero = "ZERO"
unknown style = iota
title
lower
upper
)
var ErrNamingFormat = errors.New("unsupported format")
type (
styleFormat struct {
before string
through string
after string
goStyle style
zeroStyle style
}
style int
)
// FileNamingFormat is used to format the file name. You can define the format style
// through the go and zero formatting characters. For example, you can define the snake
// format as go_zero, and the camel case format as goZero. You can even specify the split
// character, such as go#Zero, theoretically any combination can be used, but the prerequisite
// must meet the naming conventions of each operating system file name.
// Note: Formatting is based on snake or camel string
func FileNamingFormat(format, content string) (string, error) {
upperFormat := strings.ToUpper(format)
indexGo := strings.Index(upperFormat, flagGo)
indexZero := strings.Index(upperFormat, flagZero)
if indexGo < 0 || indexZero < 0 || indexGo > indexZero {
return "", ErrNamingFormat
}
var (
before, through, after string
flagGo, flagZero string
goStyle, zeroStyle style
err error
)
before = format[:indexGo]
flagGo = format[indexGo : indexGo+2]
through = format[indexGo+2 : indexZero]
flagZero = format[indexZero : indexZero+4]
after = format[indexZero+4:]
goStyle, err = getStyle(flagGo)
if err != nil {
return "", err
}
zeroStyle, err = getStyle(flagZero)
if err != nil {
return "", err
}
var formatStyle styleFormat
formatStyle.goStyle = goStyle
formatStyle.zeroStyle = zeroStyle
formatStyle.before = before
formatStyle.through = through
formatStyle.after = after
return doFormat(formatStyle, content)
}
func doFormat(f styleFormat, content string) (string, error) {
splits, err := split(content)
if err != nil {
return "", err
}
var join []string
for index, split := range splits {
if index == 0 {
join = append(join, transferTo(split, f.goStyle))
continue
}
join = append(join, transferTo(split, f.zeroStyle))
}
joined := strings.Join(join, f.through)
return f.before + joined + f.after, nil
}
func transferTo(in string, style style) string {
switch style {
case upper:
return strings.ToUpper(in)
case lower:
return strings.ToLower(in)
case title:
return strings.Title(in)
default:
return in
}
}
func split(content string) ([]string, error) {
var (
list []string
reader = strings.NewReader(content)
buffer = bytes.NewBuffer(nil)
)
for {
r, _, err := reader.ReadRune()
if err != nil {
if err == io.EOF {
if buffer.Len() > 0 {
list = append(list, buffer.String())
}
return list, nil
}
return nil, err
}
if r == '_' {
if buffer.Len() > 0 {
list = append(list, buffer.String())
}
buffer.Reset()
continue
}
if r >= 'A' && r <= 'Z' {
if buffer.Len() > 0 {
list = append(list, buffer.String())
}
buffer.Reset()
}
buffer.WriteRune(r)
}
}
func getStyle(flag string) (style, error) {
compare := strings.ToLower(flag)
switch flag {
case strings.ToLower(compare):
return lower, nil
case strings.ToUpper(compare):
return upper, nil
case strings.Title(compare):
return title, nil
default:
return unknown, fmt.Errorf("unexpected format: %s", flag)
}
}

View File

@@ -0,0 +1,112 @@
package format
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSplit(t *testing.T) {
list, err := split("A")
assert.Nil(t, err)
assert.Equal(t, []string{"A"}, list)
list, err = split("goZero")
assert.Nil(t, err)
assert.Equal(t, []string{"go", "Zero"}, list)
list, err = split("Gozero")
assert.Nil(t, err)
assert.Equal(t, []string{"Gozero"}, list)
list, err = split("go_zero")
assert.Nil(t, err)
assert.Equal(t, []string{"go", "zero"}, list)
list, err = split("talGo_zero")
assert.Nil(t, err)
assert.Equal(t, []string{"tal", "Go", "zero"}, list)
list, err = split("GOZERO")
assert.Nil(t, err)
assert.Equal(t, []string{"G", "O", "Z", "E", "R", "O"}, list)
list, err = split("gozero")
assert.Nil(t, err)
assert.Equal(t, []string{"gozero"}, list)
list, err = split("")
assert.Nil(t, err)
assert.Equal(t, 0, len(list))
list, err = split("a_b_CD_EF")
assert.Nil(t, err)
assert.Equal(t, []string{"a", "b", "C", "D", "E", "F"}, list)
list, err = split("_")
assert.Nil(t, err)
assert.Equal(t, 0, len(list))
list, err = split("__")
assert.Nil(t, err)
assert.Equal(t, 0, len(list))
list, err = split("_A")
assert.Nil(t, err)
assert.Equal(t, []string{"A"}, list)
list, err = split("_A_")
assert.Nil(t, err)
assert.Equal(t, []string{"A"}, list)
list, err = split("A_")
assert.Nil(t, err)
assert.Equal(t, []string{"A"}, list)
list, err = split("welcome_to_go_zero")
assert.Nil(t, err)
assert.Equal(t, []string{"welcome", "to", "go", "zero"}, list)
}
func TestFileNamingFormat(t *testing.T) {
testFileNamingFormat(t, "gozero", "welcome_to_go_zero", "welcometogozero")
testFileNamingFormat(t, "_go#zero_", "welcome_to_go_zero", "_welcome#to#go#zero_")
testFileNamingFormat(t, "Go#zero", "welcome_to_go_zero", "Welcome#to#go#zero")
testFileNamingFormat(t, "Go#Zero", "welcome_to_go_zero", "Welcome#To#Go#Zero")
testFileNamingFormat(t, "Go_Zero", "welcome_to_go_zero", "Welcome_To_Go_Zero")
testFileNamingFormat(t, "go_Zero", "welcome_to_go_zero", "welcome_To_Go_Zero")
testFileNamingFormat(t, "goZero", "welcome_to_go_zero", "welcomeToGoZero")
testFileNamingFormat(t, "GoZero", "welcome_to_go_zero", "WelcomeToGoZero")
testFileNamingFormat(t, "GOZero", "welcome_to_go_zero", "WELCOMEToGoZero")
testFileNamingFormat(t, "GoZERO", "welcome_to_go_zero", "WelcomeTOGOZERO")
testFileNamingFormat(t, "GOZERO", "welcome_to_go_zero", "WELCOMETOGOZERO")
testFileNamingFormat(t, "GO*ZERO", "welcome_to_go_zero", "WELCOME*TO*GO*ZERO")
testFileNamingFormat(t, "[GO#ZERO]", "welcome_to_go_zero", "[WELCOME#TO#GO#ZERO]")
testFileNamingFormat(t, "{go###zero}", "welcome_to_go_zero", "{welcome###to###go###zero}")
testFileNamingFormat(t, "{go###zerogo_zero}", "welcome_to_go_zero", "{welcome###to###go###zerogo_zero}")
testFileNamingFormat(t, "GogoZerozero", "welcome_to_go_zero", "WelcomegoTogoGogoZerozero")
testFileNamingFormat(t, "前缀GoZero后缀", "welcome_to_go_zero", "前缀WelcomeToGoZero后缀")
testFileNamingFormat(t, "GoZero", "welcometogozero", "Welcometogozero")
testFileNamingFormat(t, "GoZero", "WelcomeToGoZero", "WelcomeToGoZero")
testFileNamingFormat(t, "gozero", "WelcomeToGoZero", "welcometogozero")
testFileNamingFormat(t, "go_zero", "WelcomeToGoZero", "welcome_to_go_zero")
testFileNamingFormat(t, "Go_Zero", "WelcomeToGoZero", "Welcome_To_Go_Zero")
testFileNamingFormat(t, "Go_Zero", "", "")
testFileNamingFormatErr(t, "go", "")
testFileNamingFormatErr(t, "gOZero", "")
testFileNamingFormatErr(t, "zero", "")
testFileNamingFormatErr(t, "goZEro", "welcome_to_go_zero")
testFileNamingFormatErr(t, "goZERo", "welcome_to_go_zero")
testFileNamingFormatErr(t, "zerogo", "welcome_to_go_zero")
}
func testFileNamingFormat(t *testing.T, format, in, expected string) {
format, err := FileNamingFormat(format, in)
assert.Nil(t, err)
assert.Equal(t, expected, format)
}
func testFileNamingFormatErr(t *testing.T, format, in string) {
_, err := FileNamingFormat(format, in)
assert.Error(t, err)
}

View File

@@ -0,0 +1,41 @@
package name
import (
"strings"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
type NamingStyle = string
const (
NamingLower NamingStyle = "lower"
NamingCamel NamingStyle = "camel"
NamingSnake NamingStyle = "snake"
)
// IsNamingValid validates whether the namingStyle is valid or not,return
// namingStyle and true if it is valid, or else return empty string
// and false, and it is a valid value even namingStyle is empty string
func IsNamingValid(namingStyle string) (NamingStyle, bool) {
if len(namingStyle) == 0 {
namingStyle = NamingLower
}
switch namingStyle {
case NamingLower, NamingCamel, NamingSnake:
return namingStyle, true
default:
return "", false
}
}
func FormatFilename(filename string, style NamingStyle) string {
switch style {
case NamingCamel:
return stringx.From(filename).ToCamel()
case NamingSnake:
return stringx.From(filename).ToSnake()
default:
return strings.ToLower(stringx.From(filename).ToCamel())
}
}

View File

@@ -0,0 +1,35 @@
package name
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsNamingValid(t *testing.T) {
style, valid := IsNamingValid("")
assert.True(t, valid)
assert.Equal(t, NamingLower, style)
_, valid = IsNamingValid("lower1")
assert.False(t, valid)
_, valid = IsNamingValid("lower")
assert.True(t, valid)
_, valid = IsNamingValid("snake")
assert.True(t, valid)
_, valid = IsNamingValid("camel")
assert.True(t, valid)
}
func TestFormatFilename(t *testing.T) {
assert.Equal(t, "abc", FormatFilename("a_b_c", NamingLower))
assert.Equal(t, "ABC", FormatFilename("a_b_c", NamingCamel))
assert.Equal(t, "a_b_c", FormatFilename("a_b_c", NamingSnake))
assert.Equal(t, "a", FormatFilename("a", NamingSnake))
assert.Equal(t, "A", FormatFilename("a", NamingCamel))
// no flag to convert to snake
assert.Equal(t, "abc", FormatFilename("abc", NamingSnake))
}