diff --git a/tools/goctl/goctl.go b/tools/goctl/goctl.go index 445dfc60..2bf6b32f 100644 --- a/tools/goctl/goctl.go +++ b/tools/goctl/goctl.go @@ -725,7 +725,7 @@ var commands = []cli.Command{ Name: "url", Usage: `the data source of database,like "root:password@tcp(127.0.0.1:3306)/database"`, }, - cli.StringFlag{ + cli.StringSliceFlag{ Name: "table, t", Usage: `the table or table globbing patterns in the database`, }, diff --git a/tools/goctl/model/sql/command/command.go b/tools/goctl/model/sql/command/command.go index 326be929..eac3f014 100644 --- a/tools/goctl/model/sql/command/command.go +++ b/tools/goctl/model/sql/command/command.go @@ -87,13 +87,51 @@ func MySqlDataSource(ctx *cli.Context) error { pathx.RegisterGoctlHome(home) } - pattern := strings.TrimSpace(ctx.String(flagTable)) + tableValue := ctx.StringSlice(flagTable) + patterns := parseTableList(tableValue) cfg, err := config.NewConfig(style) if err != nil { return err } - return fromMysqlDataSource(url, pattern, dir, cfg, cache, idea) + return fromMysqlDataSource(url, dir, patterns, cfg, cache, idea) +} + +type pattern map[string]struct{} + +func (p pattern) Match(s string) bool { + for v := range p { + match, err := filepath.Match(v, s) + if err != nil { + console.Error("%+v", err) + continue + } + if match { + return true + } + } + return false +} + +func (p pattern) list() []string { + var ret []string + for v := range p { + ret = append(ret, v) + } + return ret +} + +func parseTableList(tableValue []string) pattern { + tablePattern := make(pattern) + for _, v := range tableValue { + fields := strings.FieldsFunc(v, func(r rune) bool { + return r == ',' + }) + for _, f := range fields { + tablePattern[f] = struct{}{} + } + } + return tablePattern } // PostgreSqlDataSource generates model code from datasource @@ -162,14 +200,14 @@ func fromDDL(src, dir string, cfg *config.Config, cache, idea bool, database str return nil } -func fromMysqlDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bool) error { +func fromMysqlDataSource(url, dir string, tablePat pattern, cfg *config.Config, cache, idea bool) error { log := console.NewConsole(idea) if len(url) == 0 { log.Error("%v", "expected data source of mysql, but nothing found") return nil } - if len(pattern) == 0 { + if len(tablePat) == 0 { log.Error("%v", "expected table or table globbing patterns, but nothing found") return nil } @@ -191,12 +229,7 @@ func fromMysqlDataSource(url, pattern, dir string, cfg *config.Config, cache, id matchTables := make(map[string]*model.Table) for _, item := range tables { - match, err := filepath.Match(pattern, item) - if err != nil { - return err - } - - if !match { + if !tablePat.Match(item) { continue } diff --git a/tools/goctl/model/sql/command/command_test.go b/tools/goctl/model/sql/command/command_test.go index ee363213..bd35713f 100644 --- a/tools/goctl/model/sql/command/command_test.go +++ b/tools/goctl/model/sql/command/command_test.go @@ -5,6 +5,8 @@ import ( "io/ioutil" "os" "path/filepath" + "sort" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -86,3 +88,30 @@ func TestFromDDl(t *testing.T) { _ = os.Remove(filename) fromDDL("1gozero") } + +func Test_parseTableList(t *testing.T) { + testData := []string{"foo", "b*", "bar", "back_up", "foo,bar,b*"} + patterns := parseTableList(testData) + actual := patterns.list() + expected := []string{"foo", "b*", "bar", "back_up"} + sort.Slice(actual, func(i, j int) bool { + return actual[i] > actual[j] + }) + sort.Slice(expected, func(i, j int) bool { + return expected[i] > expected[j] + }) + assert.Equal(t, strings.Join(expected, ","), strings.Join(actual, ",")) + + matchTestData := map[string]bool{ + "foo": true, + "bar": true, + "back_up": true, + "bit": true, + "ab": false, + "b": true, + } + for v, expected := range matchTestData { + actual := patterns.Match(v) + assert.Equal(t, expected, actual) + } +}