Add strict flag (#2248)

Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
anqiansong
2022-08-28 18:55:52 +08:00
committed by GitHub
parent a1466e1707
commit f70805ee60
9 changed files with 126 additions and 57 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/zeromicro/ddl-parser/parser"
"github.com/zeromicro/go-zero/core/collection"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/converter"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
@@ -61,7 +62,7 @@ func parseNameOriginal(ts []*parser.Table) (nameOriginals [][]string) {
}
// Parse parses ddl into golang structure
func Parse(filename, database string) ([]*Table, error) {
func Parse(filename, database string, strict bool) ([]*Table, error) {
p := parser.NewParser()
tables, err := p.From(filename)
if err != nil {
@@ -124,7 +125,7 @@ func Parse(filename, database string) ([]*Table, error) {
return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
}
primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
primaryKey, fieldM, err := convertColumns(columns, primaryColumn, strict)
if err != nil {
return nil, err
}
@@ -190,7 +191,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
}
}
func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
func convertColumns(columns []*parser.Column, primaryColumn string, strict bool) (Primary, map[string]*Field, error) {
var (
primaryKey Primary
fieldM = make(map[string]*Field)
@@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma
}
}
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned())
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
if err != nil {
return Primary{}, nil, err
}
@@ -264,10 +265,10 @@ func (t *Table) ContainsTime() bool {
}
// ConvertDataType converts mysql data type into golang data type
func ConvertDataType(table *model.Table) (*Table, error) {
func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned)
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}
@@ -292,7 +293,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
}
fieldM, err := getTableFields(table)
fieldM, err := getTableFields(table, strict)
if err != nil {
return nil, err
}
@@ -342,12 +343,12 @@ func ConvertDataType(table *model.Table) (*Table, error) {
return &reply, nil
}
func getTableFields(table *model.Table) (map[string]*Field, error) {
func getTableFields(table *model.Table, strict bool) (map[string]*Field, error) {
fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned)
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
if err != nil {
return nil, err
}

View File

@@ -7,6 +7,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
@@ -17,7 +18,7 @@ func TestParsePlainText(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
assert.Nil(t, err)
_, err = Parse(sqlFile, "go_zero")
_, err = Parse(sqlFile, "go_zero", false)
assert.NotNil(t, err)
}
@@ -26,7 +27,7 @@ func TestParseSelect(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
assert.Nil(t, err)
tables, err := Parse(sqlFile, "go_zero")
tables, err := Parse(sqlFile, "go_zero", false)
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
}
@@ -39,7 +40,7 @@ func TestParseCreateTable(t *testing.T) {
err := ioutil.WriteFile(sqlFile, []byte(user), 0o777)
assert.Nil(t, err)
tables, err := Parse(sqlFile, "go_zero")
tables, err := Parse(sqlFile, "go_zero", false)
assert.Equal(t, 1, len(tables))
table := tables[0]
assert.Nil(t, err)