diff --git a/config.yml b/config.yml index 2b333c7..4909fe9 100644 --- a/config.yml +++ b/config.yml @@ -14,6 +14,8 @@ is_gui : false # 是否ui模式显示 is_table_name : true # 是否直接生成表名,列名 is_null_to_point : false # 数据库默认 'DEFAULT NULL' 时设置结构为指针类型 table_prefix : "" # 表前缀, 如果有则使用, 没有留空 +table_names: "" # 指定表生成,多个表用,隔开 + db_info: host : 127.0.0.1 # type=1的时候,host为yml文件全路径 port : 3306 diff --git a/data/cmd/cmd.go b/data/cmd/cmd.go index 3226fda..5282182 100644 --- a/data/cmd/cmd.go +++ b/data/cmd/cmd.go @@ -2,6 +2,7 @@ package cmd import ( "os" + "strings" "github.com/xxjwxc/public/mylog" @@ -67,6 +68,9 @@ func init() { rootCmd.Flags().Int("port", 3306, "端口号") rootCmd.Flags().StringP("table_prefix", "t", "", "表前缀") + //增加表名称 + rootCmd.Flags().StringP("table_names", "b", "", "表名称") + } // initConfig reads in config file and ENV variables if set. @@ -117,4 +121,13 @@ func MergeMysqlDbInfo() { tablePrefix := config.GetTablePrefix() mycobra.IfReplace(rootCmd, "table_prefix", &tablePrefix) // 如果设置了,更新 config.SetTablePrefix(tablePrefix) + + //更新tableNames + tableNames := config.GetTableNames() + if tableNames != "" { + tableNames = strings.Replace(tableNames, "'", "", -1) + } + mycobra.IfReplace(rootCmd, "table_names", &tableNames) // 如果设置了,更新 + config.SetTableNames(tableNames) + } diff --git a/data/config/MyIni.go b/data/config/MyIni.go index f8e2aa0..bd3a75d 100644 --- a/data/config/MyIni.go +++ b/data/config/MyIni.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "strings" "github.com/xxjwxc/public/tools" ) @@ -27,6 +28,7 @@ type Config struct { SelfTypeDef map[string]string `yaml:"self_type_define"` OutFileName string `yaml:"out_file_name"` WebTagType int `yaml:"web_tag_type"` // 默认小驼峰 + TableNames string `yaml:"table_names"` // 表名(多个表名用","隔开) } // DBInfo mysql database information. mysql 数据库信息 @@ -254,3 +256,40 @@ func SetWebTagType(i int) { func GetWebTagType() int { return _map.WebTagType } + +//获取设置的表名 +func GetTableNames() string { + var sb strings.Builder + if _map.TableNames != "" { + tableNames := _map.TableNames + tableNames = strings.TrimLeft(tableNames, ",") + tableNames = strings.TrimRight(tableNames, ",") + if tableNames == "" { + return "" + } + + sarr := strings.Split(_map.TableNames, ",") + if len(sarr) == 0 { + fmt.Printf("tableNames is vailed, genmodel will by default global") + return "" + } + + for i, val := range sarr { + sb.WriteString(fmt.Sprintf("'%s'", val)) + if i != len(sarr)-1 { + sb.WriteString(",") + } + } + } + return sb.String() +} + +//获取设置的表名 +func GetOriginTableNames() string { + return _map.TableNames +} + +//设置生成的表名 +func SetTableNames(tableNames string) { + _map.TableNames = tableNames +} diff --git a/data/config/common.go b/data/config/common.go index 177d4cb..65a2085 100644 --- a/data/config/common.go +++ b/data/config/common.go @@ -46,6 +46,7 @@ var _map = Config{ TablePrefix: "", SelfTypeDef: make(map[string]string), WebTagType: 0, + TableNames: "", } var configPath string diff --git a/data/view/model/def_ifs.go b/data/view/model/def_ifs.go index 6c34f64..789d3bf 100644 --- a/data/view/model/def_ifs.go +++ b/data/view/model/def_ifs.go @@ -4,5 +4,6 @@ package model type IModel interface { GenModel() DBInfo GetDbName() string - GetPkgName() string // Getting package names through config outdir configuration.通过config outdir 配置获取包名 + GetPkgName() string // Getting package names through config outdir configuration.通过config outdir 配置获取包名 + GetTableNames() string //获取设置表名 } diff --git a/data/view/model/gencnf/gencnf.go b/data/view/model/gencnf/gencnf.go index 52b628a..8b712c5 100644 --- a/data/view/model/gencnf/gencnf.go +++ b/data/view/model/gencnf/gencnf.go @@ -51,6 +51,11 @@ func (m *cnfModel) GenModel() model.DBInfo { return dbInfo } +// GetTableNames get table name.获取指定的表名 +func (m *cnfModel) GetTableNames() string { + return config.GetTableNames() +} + // GetDbName get database name.获取数据库名字 func (m *cnfModel) GetDbName() string { dir := config.GetDbInfo().Host diff --git a/data/view/model/genmysql/genmysql.go b/data/view/model/genmysql/genmysql.go index 2e31d6e..361275e 100644 --- a/data/view/model/genmysql/genmysql.go +++ b/data/view/model/genmysql/genmysql.go @@ -1,6 +1,7 @@ package genmysql import ( + "database/sql" "fmt" "sort" "strings" @@ -34,6 +35,16 @@ func (m *mysqlModel) GetDbName() string { return config.GetDbInfo().Database } +// GetTableNames get table name.获取格式化后指定的表名 +func (m *mysqlModel) GetTableNames() string { + return config.GetTableNames() +} + +// GetTableNames get table name.获取原始指定的表名 +func (m *mysqlModel) GetOriginTableNames() string { + return config.GetOriginTableNames() +} + // GetPkgName package names through config outdir configuration.通过config outdir 配置获取包名 func (m *mysqlModel) GetPkgName() string { dir := config.GetOutDir() @@ -73,6 +84,7 @@ func (m *mysqlModel) getPackageInfo(orm *mysqldb.MySqlDB, info *model.DBInfo) { // } // tabls = newTabls // } + fmt.Println(tabls) for tabName, notes := range tabls { var tab model.TabInfo tab.Name = tabName @@ -99,6 +111,7 @@ func (m *mysqlModel) getPackageInfo(orm *mysqldb.MySqlDB, info *model.DBInfo) { info.TabList = append(info.TabList, tab) } + fmt.Println(info.TabList) // sort tables sort.Slice(info.TabList, func(i, j int) bool { return info.TabList[i].Name < info.TabList[j].Name @@ -207,24 +220,42 @@ func (m *mysqlModel) getTables(orm *mysqldb.MySqlDB) map[string]string { // Get column names.获取列名 var tables []string - rows, err := orm.Raw("show tables").Rows() - if err != nil { - if !config.GetIsGUI() { - fmt.Println(err) + if m.GetOriginTableNames() != "" { + sarr := strings.Split(m.GetOriginTableNames(), ",") + if len(sarr) != 0 { + for _, val := range sarr { + tbDesc[val] = "" + } + } + } else { + rows, err := orm.Raw("show tables").Rows() + if err != nil { + if !config.GetIsGUI() { + fmt.Println(err) + } + return tbDesc } - return tbDesc - } - for rows.Next() { - var table string - rows.Scan(&table) - tables = append(tables, table) - tbDesc[table] = "" + for rows.Next() { + var table string + rows.Scan(&table) + tables = append(tables, table) + tbDesc[table] = "" + } + rows.Close() } - rows.Close() // Get table annotations.获取表注释 - rows1, err := orm.Raw("SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.TABLES WHERE table_schema= '" + m.GetDbName() + "'").Rows() + var err error + var rows1 *sql.Rows + if m.GetTableNames() != "" { + rows1, err = orm.Raw("SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.TABLES WHERE table_schema= '" + m.GetDbName() + "'and TABLE_NAME IN(" + m.GetTableNames() + ")").Rows() + fmt.Println("getTables:" + m.GetTableNames()) + fmt.Println("SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.TABLES WHERE table_schema= '" + m.GetDbName() + "'and TABLE_NAME IN(" + m.GetTableNames() + ")") + } else { + rows1, err = orm.Raw("SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.TABLES WHERE table_schema= '" + m.GetDbName() + "'").Rows() + } + if err != nil { if !config.GetIsGUI() { fmt.Println(err) diff --git a/data/view/model/gensqlite/gensqlite.go b/data/view/model/gensqlite/gensqlite.go index 801d348..736d8c3 100644 --- a/data/view/model/gensqlite/gensqlite.go +++ b/data/view/model/gensqlite/gensqlite.go @@ -65,6 +65,11 @@ func (m *sqliteModel) GetDbName() string { return dbName } +// GetTableNames get table name.获取指定的表名 +func (m *sqliteModel) GetTableNames() string { + return config.GetTableNames() +} + // GetPkgName package names through config outdir configuration.通过config outdir 配置获取包名 func (m *sqliteModel) GetPkgName() string { dir := config.GetOutDir()