Add strict flag (#2248)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
anqiansong
2022-08-28 18:55:52 +08:00
committed by GitHub
parent a1466e1707
commit f70805ee60
9 changed files with 126 additions and 57 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/postgres"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/command/migrationnotes"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/gen"
@@ -47,6 +48,8 @@ var (
VarStringRemote string
// VarStringBranch describes the git branch of the repository.
VarStringBranch string
// VarBoolStrict describes whether the strict mode is enabled.
VarBoolStrict bool
)
var errNotMatched = errors.New("sql not matched")
@@ -77,7 +80,16 @@ func MysqlDDL(_ *cobra.Command, _ []string) error {
return err
}
return fromDDL(src, dir, cfg, cache, idea, database)
arg := ddlArg{
src: src,
dir: dir,
cfg: cfg,
cache: cache,
idea: idea,
database: database,
strict: VarBoolStrict,
}
return fromDDL(arg)
}
// MySqlDataSource generates model code from datasource
@@ -108,7 +120,16 @@ func MySqlDataSource(_ *cobra.Command, _ []string) error {
return err
}
return fromMysqlDataSource(url, dir, patterns, cfg, cache, idea)
arg := dataSourceArg{
url: url,
dir: dir,
tablePat: patterns,
cfg: cfg,
cache: cache,
idea: idea,
strict: VarBoolStrict,
}
return fromMysqlDataSource(arg)
}
type pattern map[string]struct{}
@@ -180,12 +201,20 @@ func PostgreSqlDataSource(_ *cobra.Command, _ []string) error {
return err
}
return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea)
return fromPostgreSqlDataSource(url, pattern, dir, schema, cfg, cache, idea, VarBoolStrict)
}
func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database string) error {
log := console.NewConsole(idea)
src = strings.TrimSpace(src)
type ddlArg struct {
src, dir string
cfg *config.Config
cache, idea bool
database string
strict bool
}
func fromDDL(arg ddlArg) error {
log := console.NewConsole(arg.idea)
src := strings.TrimSpace(arg.src)
if len(src) == 0 {
return errors.New("expected path or path globbing patterns, but nothing found")
}
@@ -199,13 +228,13 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
return errNotMatched
}
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log))
if err != nil {
return err
}
for _, file := range files {
err = generator.StartFromDDL(file, cache, database)
err = generator.StartFromDDL(file, arg.cache, arg.strict, arg.database)
if err != nil {
return err
}
@@ -214,25 +243,33 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str
return nil
}
func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, cache, idea bool) error {
log := console.NewConsole(idea)
if len(url) == 0 {
type dataSourceArg struct {
url, dir string
tablePat pattern
cfg *config.Config
cache, idea bool
strict bool
}
func fromMysqlDataSource(arg dataSourceArg) error {
log := console.NewConsole(arg.idea)
if len(arg.url) == 0 {
log.Error("%v", "expected data source of mysql, but nothing found")
return nil
}
if len(tablePat) == 0 {
if len(arg.tablePat) == 0 {
log.Error("%v", "expected table or table globbing patterns, but nothing found")
return nil
}
dsn, err := mysql.ParseDSN(url)
dsn, err := mysql.ParseDSN(arg.url)
if err != nil {
return err
}
logx.Disable()
databaseSource := strings.TrimSuffix(url, "/"+dsn.DBName) + "/information_schema"
databaseSource := strings.TrimSuffix(arg.url, "/"+dsn.DBName) + "/information_schema"
db := sqlx.NewMysql(databaseSource)
im := model.NewInformationSchemaModel(db)
@@ -243,7 +280,7 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config,
matchTables := make(map[string]*model.Table)
for _, item := range tables {
if !tablePat.Match(item) {
if !arg.tablePat.Match(item) {
continue
}
@@ -264,15 +301,15 @@ func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config,
return errors.New("no tables matched")
}
generator, err := gen.NewDefaultGenerator(dir, cfg, gen.WithConsoleOption(log))
generator, err := gen.NewDefaultGenerator(arg.dir, arg.cfg, gen.WithConsoleOption(log))
if err != nil {
return err
}
return generator.StartFromInformationSchema(matchTables, cache)
return generator.StartFromInformationSchema(matchTables, arg.cache, arg.strict)
}
func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea bool) error {
func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Config, cache, idea, strict bool) error {
log := console.NewConsole(idea)
if len(url) == 0 {
log.Error("%v", "expected data source of postgresql, but nothing found")
@@ -324,5 +361,5 @@ func fromPostgreSqlDataSource(url, pattern, dir, schema string, cfg *config.Conf
return err
}
return generator.StartFromInformationSchema(matchTables, cache)
return generator.StartFromInformationSchema(matchTables, cache, strict)
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/gen"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@@ -27,12 +28,25 @@ func TestFromDDl(t *testing.T) {
err := gen.Clean()
assert.Nil(t, err)
err = fromDDL("./user.sql", pathx.MustTempDir(), cfg, true, false, "go_zero")
err = fromDDL(ddlArg{
src: "./user.sql",
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go-zero",
strict: false,
})
assert.Equal(t, errNotMatched, err)
// case dir is not exists
unknownDir := filepath.Join(pathx.MustTempDir(), "test", "user.sql")
err = fromDDL(unknownDir, pathx.MustTempDir(), cfg, true, false, "go_zero")
err = fromDDL(ddlArg{
src: unknownDir,
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go_zero",
})
assert.True(t, func() bool {
switch err.(type) {
case *os.PathError:
@@ -43,7 +57,12 @@ func TestFromDDl(t *testing.T) {
}())
// case empty src
err = fromDDL("", pathx.MustTempDir(), cfg, true, false, "go_zero")
err = fromDDL(ddlArg{
dir: pathx.MustTempDir(),
cfg: cfg,
cache: true,
database: "go_zero",
})
if err != nil {
assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error())
}
@@ -75,7 +94,13 @@ func TestFromDDl(t *testing.T) {
filename := filepath.Join(tempDir, "usermodel.go")
fromDDL := func(db string) {
err = fromDDL(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, db)
err = fromDDL(ddlArg{
src: filepath.Join(tempDir, "user*.sql"),
dir: tempDir,
cfg: cfg,
cache: true,
database: db,
})
assert.Nil(t, err)
_, err = os.Stat(filename)