diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 2a78fc50..89cd49c9 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -419,6 +419,10 @@ var ( Name: "idea", Usage: "for idea plugin [optional]", }, + cli.StringFlag{ + Name: "database, db", + Usage: "the name of database [optional]", + }, }, Action: model.MysqlDDL, }, diff --git a/tools/goctl/model/sql/README.MD b/tools/goctl/model/sql/README.MD index 46af9ee7..6f98f630 100644 --- a/tools/goctl/model/sql/README.MD +++ b/tools/goctl/model/sql/README.MD @@ -264,6 +264,7 @@ OPTIONS: --style value the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md] --cache, -c generate code with cache [optional] --idea for idea plugin [optional] + --database, -db the name of database [optional] ``` * datasource diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index 213973e0..c3a228b0 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -24,6 +24,7 @@ const ( flagURL = "url" flagTable = "table" flagStyle = "style" + flagDatabase = "database" ) var errNotMatched = errors.New("sql not matched") @@ -35,12 +36,13 @@ func MysqlDDL(ctx *cli.Context) error { cache := ctx.Bool(flagCache) idea := ctx.Bool(flagIdea) style := ctx.String(flagStyle) + database := ctx.String(flagDatabase) cfg, err := config.NewConfig(style) if err != nil { return err } - return fromDDl(src, dir, cfg, cache, idea) + return fromDDl(src, dir, cfg, cache, idea, database) } // MyDataSource generates model code from datasource @@ -59,7 +61,7 @@ func MyDataSource(ctx *cli.Context) error { return fromDataSource(url, pattern, dir, cfg, cache, idea) } -func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error { +func fromDDl(src, dir string, cfg *config.Config, cache, idea bool, database string) error { log := console.NewConsole(idea) src = strings.TrimSpace(src) if len(src) == 0 { @@ -81,7 +83,7 @@ func fromDDl(src, dir string, cfg *config.Config, cache, idea bool) error { } for _, file := range files { - err = generator.StartFromDDL(file, cache) + err = generator.StartFromDDL(file, cache, database) if err != nil { return err } diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go index 6dbb9917..d2495772 100644 --- a/tools/goctl/model/sql/command/command_test.go +++ b/tools/goctl/model/sql/command/command_test.go @@ -24,12 +24,12 @@ func TestFromDDl(t *testing.T) { err := gen.Clean() assert.Nil(t, err) - err = fromDDl("./user.sql", t.TempDir(), cfg, true, false) + err = fromDDl("./user.sql", t.TempDir(), cfg, true, false, "go_zero") assert.Equal(t, errNotMatched, err) // case dir is not exists unknownDir := filepath.Join(t.TempDir(), "test", "user.sql") - err = fromDDl(unknownDir, t.TempDir(), cfg, true, false) + err = fromDDl(unknownDir, t.TempDir(), cfg, true, false, "go_zero") assert.True(t, func() bool { switch err.(type) { case *os.PathError: @@ -40,7 +40,7 @@ func TestFromDDl(t *testing.T) { }()) // case empty src - err = fromDDl("", t.TempDir(), cfg, true, false) + err = fromDDl("", t.TempDir(), cfg, true, false, "go_zero") if err != nil { assert.Equal(t, "expected path or path globbing patterns, but nothing found", err.Error()) } @@ -70,7 +70,7 @@ func TestFromDDl(t *testing.T) { _, err = os.Stat(user2Sql) assert.Nil(t, err) - err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false) + err = fromDDl(filepath.Join(tempDir, "user*.sql"), tempDir, cfg, true, false, "go_zero") assert.Nil(t, err) _, err = os.Stat(filepath.Join(tempDir, "usermodel.go")) diff --git a/tools/goctl/model/sql/gen/gen.go b/tools/goctl/model/sql/gen/gen.go index 3cb4f287..5ed1c20c 100644 --- a/tools/goctl/model/sql/gen/gen.go +++ b/tools/goctl/model/sql/gen/gen.go @@ -90,8 +90,8 @@ func newDefaultOption() Option { } } -func (g *defaultGenerator) StartFromDDL(filename string, withCache bool) error { - modelList, err := g.genFromDDL(filename, withCache) +func (g *defaultGenerator) StartFromDDL(filename string, withCache bool, database string) error { + modelList, err := g.genFromDDL(filename, withCache, database) if err != nil { return err } @@ -174,9 +174,9 @@ func (g *defaultGenerator) createFile(modelList map[string]string) error { } // ret1: key-table name,value-code -func (g *defaultGenerator) genFromDDL(filename string, withCache bool) (map[string]string, error) { +func (g *defaultGenerator) genFromDDL(filename string, withCache bool, database string) (map[string]string, error) { m := make(map[string]string) - tables, err := parser.Parse(filename) + tables, err := parser.Parse(filename, database) if err != nil { return nil, err } diff --git a/tools/goctl/model/sql/gen/gen_test.go b/tools/goctl/model/sql/gen/gen_test.go index b748641e..a173a36d 100644 --- a/tools/goctl/model/sql/gen/gen_test.go +++ b/tools/goctl/model/sql/gen/gen_test.go @@ -34,7 +34,7 @@ func TestCacheModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, true) + err = g.StartFromDDL(sqlFile, true, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go")) @@ -45,7 +45,7 @@ func TestCacheModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, false) + err = g.StartFromDDL(sqlFile, false, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go")) @@ -72,7 +72,7 @@ func TestNamingModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, true) + err = g.StartFromDDL(sqlFile, true, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go")) @@ -83,7 +83,7 @@ func TestNamingModel(t *testing.T) { }) assert.Nil(t, err) - err = g.StartFromDDL(sqlFile, true) + err = g.StartFromDDL(sqlFile, true, "go_zero") assert.Nil(t, err) assert.True(t, func() bool { _, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go")) diff --git a/tools/goctl/model/sql/gen/keys.go b/tools/goctl/model/sql/gen/keys.go index b7c0d441..9759c664 100644 --- a/tools/goctl/model/sql/gen/keys.go +++ b/tools/goctl/model/sql/gen/keys.go @@ -39,9 +39,9 @@ type Join []string func genCacheKeys(table parser.Table) (Key, []Key) { var primaryKey Key var uniqueKey []Key - primaryKey = genCacheKey(table.Name, []*parser.Field{&table.PrimaryKey.Field}) + primaryKey = genCacheKey(table.Db, table.Name, []*parser.Field{&table.PrimaryKey.Field}) for _, each := range table.UniqueIndex { - uniqueKey = append(uniqueKey, genCacheKey(table.Name, each)) + uniqueKey = append(uniqueKey, genCacheKey(table.Db, table.Name, each)) } sort.Slice(uniqueKey, func(i, j int) bool { return uniqueKey[i].VarLeft < uniqueKey[j].VarLeft @@ -50,7 +50,7 @@ func genCacheKeys(table parser.Table) (Key, []Key) { return primaryKey, uniqueKey } -func genCacheKey(table stringx.String, in []*parser.Field) Key { +func genCacheKey(db stringx.String, table stringx.String, in []*parser.Field) Key { var ( varLeftJoin, varRightJon, fieldNameJoin Join varLeft, varRight, varExpression string @@ -59,9 +59,9 @@ func genCacheKey(table stringx.String, in []*parser.Field) Key { keyLeft, keyRight, dataKeyRight, keyExpression, dataKeyExpression string ) - varLeftJoin = append(varLeftJoin, "cache", table.Source()) - varRightJon = append(varRightJon, "cache", table.Source()) - keyLeftJoin = append(keyLeftJoin, table.Source()) + varLeftJoin = append(varLeftJoin, "cache", db.Source(), table.Source()) + varRightJon = append(varRightJon, "cache", db.Source(), table.Source()) + keyLeftJoin = append(keyLeftJoin, db.Source(), table.Source()) for _, each := range in { varLeftJoin = append(varLeftJoin, each.Name.Source()) diff --git a/tools/goctl/model/sql/gen/keys_test.go b/tools/goctl/model/sql/gen/keys_test.go index 510ba82e..b1ec0264 100644 --- a/tools/goctl/model/sql/gen/keys_test.go +++ b/tools/goctl/model/sql/gen/keys_test.go @@ -36,6 +36,7 @@ func TestGenCacheKeys(t *testing.T) { } primariCacheKey, uniqueCacheKey := genCacheKeys(parser.Table{ Name: stringx.From("user"), + Db: stringx.From("go_zero"), PrimaryKey: parser.Primary{ Field: *primaryField, AutoIncrement: true, @@ -70,14 +71,14 @@ func TestGenCacheKeys(t *testing.T) { t.Run("primaryCacheKey", func(t *testing.T) { assert.Equal(t, true, func() bool { return cacheKeyEqual(primariCacheKey, Key{ - VarLeft: "cacheUserIdPrefix", - VarRight: `"cache:user:id:"`, - VarExpression: `cacheUserIdPrefix = "cache:user:id:"`, - KeyLeft: "userIdKey", - KeyRight: `fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`, - DataKeyRight: `fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`, - KeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`, - DataKeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`, + VarLeft: "cacheGoZeroUserIdPrefix", + VarRight: `"cache:goZero:user:id:"`, + VarExpression: `cacheGoZeroUserIdPrefix = "cache:goZero:user:id:"`, + KeyLeft: "goZeroUserIdKey", + KeyRight: `fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, id)`, + DataKeyRight: `fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, data.Id)`, + KeyExpression: `goZeroUserIdKey := fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, id)`, + DataKeyExpression: `goZeroUserIdKey := fmt.Sprintf("%s%v", cacheGoZeroUserIdPrefix, data.Id)`, FieldNameJoin: []string{"id"}, }) }()) @@ -87,25 +88,25 @@ func TestGenCacheKeys(t *testing.T) { assert.Equal(t, true, func() bool { expected := []Key{ { - VarLeft: "cacheUserClassNamePrefix", - VarRight: `"cache:user:class:name:"`, - VarExpression: `cacheUserClassNamePrefix = "cache:user:class:name:"`, - KeyLeft: "userClassNameKey", - KeyRight: `fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, class, name)`, - DataKeyRight: `fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, data.Class, data.Name)`, - KeyExpression: `userClassNameKey := fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, class, name)`, - DataKeyExpression: `userClassNameKey := fmt.Sprintf("%s%v:%v", cacheUserClassNamePrefix, data.Class, data.Name)`, + VarLeft: "cacheGoZeroUserClassNamePrefix", + VarRight: `"cache:goZero:user:class:name:"`, + VarExpression: `cacheGoZeroUserClassNamePrefix = "cache:goZero:user:class:name:"`, + KeyLeft: "goZeroUserClassNameKey", + KeyRight: `fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, class, name)`, + DataKeyRight: `fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, data.Class, data.Name)`, + KeyExpression: `goZeroUserClassNameKey := fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, class, name)`, + DataKeyExpression: `goZeroUserClassNameKey := fmt.Sprintf("%s%v:%v", cacheGoZeroUserClassNamePrefix, data.Class, data.Name)`, FieldNameJoin: []string{"class", "name"}, }, { - VarLeft: "cacheUserMobilePrefix", - VarRight: `"cache:user:mobile:"`, - VarExpression: `cacheUserMobilePrefix = "cache:user:mobile:"`, - KeyLeft: "userMobileKey", - KeyRight: `fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`, - DataKeyRight: `fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`, - KeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`, - DataKeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`, + VarLeft: "cacheGoZeroUserMobilePrefix", + VarRight: `"cache:goZero:user:mobile:"`, + VarExpression: `cacheGoZeroUserMobilePrefix = "cache:goZero:user:mobile:"`, + KeyLeft: "goZeroUserMobileKey", + KeyRight: `fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, mobile)`, + DataKeyRight: `fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, data.Mobile)`, + KeyExpression: `goZeroUserMobileKey := fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, mobile)`, + DataKeyExpression: `goZeroUserMobileKey := fmt.Sprintf("%s%v", cacheGoZeroUserMobilePrefix, data.Mobile)`, FieldNameJoin: []string{"mobile"}, }, } diff --git a/tools/goctl/model/sql/parser/parser.go b/tools/goctl/model/sql/parser/parser.go index 53a16a8e..233e3125 100644 --- a/tools/goctl/model/sql/parser/parser.go +++ b/tools/goctl/model/sql/parser/parser.go @@ -21,6 +21,7 @@ type ( // Table describes a mysql table Table struct { Name stringx.String + Db stringx.String PrimaryKey Primary UniqueIndex map[string][]*Field Fields []*Field @@ -46,7 +47,7 @@ type ( ) // Parse parses ddl into golang structure -func Parse(filename string) ([]*Table, error) { +func Parse(filename string, database string) ([]*Table, error) { p := parser.NewParser() tables, err := p.From(filename) if err != nil { @@ -145,6 +146,7 @@ func Parse(filename string) ([]*Table, error) { list = append(list, &Table{ Name: stringx.From(e.Name), + Db: stringx.From(database), PrimaryKey: primaryKey, UniqueIndex: uniqueIndex, Fields: fields, @@ -243,6 +245,7 @@ func ConvertDataType(table *model.Table) (*Table, error) { var reply Table reply.UniqueIndex = map[string][]*Field{} reply.Name = stringx.From(table.Table) + reply.Db = stringx.From(table.Db) seqInIndex := 0 if table.PrimaryKey.Index != nil { seqInIndex = table.PrimaryKey.Index.SeqInIndex diff --git a/tools/goctl/model/sql/parser/parser_test.go b/tools/goctl/model/sql/parser/parser_test.go index b31dc232..25485fb9 100644 --- a/tools/goctl/model/sql/parser/parser_test.go +++ b/tools/goctl/model/sql/parser/parser_test.go @@ -15,7 +15,7 @@ func TestParsePlainText(t *testing.T) { err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777) assert.Nil(t, err) - _, err = Parse(sqlFile) + _, err = Parse(sqlFile, "go_zero") assert.NotNil(t, err) } @@ -24,7 +24,7 @@ func TestParseSelect(t *testing.T) { err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777) assert.Nil(t, err) - tables, err := Parse(sqlFile) + tables, err := Parse(sqlFile, "go_zero") assert.Nil(t, err) assert.Equal(t, 0, len(tables)) } @@ -34,7 +34,7 @@ func TestParseCreateTable(t *testing.T) { err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0o777) assert.Nil(t, err) - tables, err := Parse(sqlFile) + tables, err := Parse(sqlFile, "go_zero") assert.Equal(t, 1, len(tables)) table := tables[0] assert.Nil(t, err)