Add strict flag (#2248)
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/ddl-parser/parser"
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/converter"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
|
||||
@@ -61,7 +62,7 @@ func parseNameOriginal(ts []*parser.Table) (nameOriginals [][]string) {
|
||||
}
|
||||
|
||||
// Parse parses ddl into golang structure
|
||||
func Parse(filename, database string) ([]*Table, error) {
|
||||
func Parse(filename, database string, strict bool) ([]*Table, error) {
|
||||
p := parser.NewParser()
|
||||
tables, err := p.From(filename)
|
||||
if err != nil {
|
||||
@@ -124,7 +125,7 @@ func Parse(filename, database string) ([]*Table, error) {
|
||||
return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
|
||||
}
|
||||
|
||||
primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
|
||||
primaryKey, fieldM, err := convertColumns(columns, primaryColumn, strict)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -190,7 +191,7 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
|
||||
}
|
||||
}
|
||||
|
||||
func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
|
||||
func convertColumns(columns []*parser.Column, primaryColumn string, strict bool) (Primary, map[string]*Field, error) {
|
||||
var (
|
||||
primaryKey Primary
|
||||
fieldM = make(map[string]*Field)
|
||||
@@ -219,7 +220,7 @@ func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, ma
|
||||
}
|
||||
}
|
||||
|
||||
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned())
|
||||
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull, column.DataType.Unsigned(), strict)
|
||||
if err != nil {
|
||||
return Primary{}, nil, err
|
||||
}
|
||||
@@ -264,10 +265,10 @@ func (t *Table) ContainsTime() bool {
|
||||
}
|
||||
|
||||
// ConvertDataType converts mysql data type into golang data type
|
||||
func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
func ConvertDataType(table *model.Table, strict bool) (*Table, error) {
|
||||
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
|
||||
isPrimaryUnsigned := strings.Contains(table.PrimaryKey.DbColumn.ColumnType, "unsigned")
|
||||
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned)
|
||||
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull, isPrimaryUnsigned, strict)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -292,7 +293,7 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
||||
}
|
||||
|
||||
fieldM, err := getTableFields(table)
|
||||
fieldM, err := getTableFields(table, strict)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -342,12 +343,12 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
return &reply, nil
|
||||
}
|
||||
|
||||
func getTableFields(table *model.Table) (map[string]*Field, error) {
|
||||
func getTableFields(table *model.Table, strict bool) (map[string]*Field, error) {
|
||||
fieldM := make(map[string]*Field)
|
||||
for _, each := range table.Columns {
|
||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||
isPrimaryUnsigned := strings.Contains(each.ColumnType, "unsigned")
|
||||
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned)
|
||||
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull, isPrimaryUnsigned, strict)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/model"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/model/sql/util"
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
@@ -17,7 +18,7 @@ func TestParsePlainText(t *testing.T) {
|
||||
err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0o777)
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = Parse(sqlFile, "go_zero")
|
||||
_, err = Parse(sqlFile, "go_zero", false)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
@@ -26,7 +27,7 @@ func TestParseSelect(t *testing.T) {
|
||||
err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0o777)
|
||||
assert.Nil(t, err)
|
||||
|
||||
tables, err := Parse(sqlFile, "go_zero")
|
||||
tables, err := Parse(sqlFile, "go_zero", false)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(tables))
|
||||
}
|
||||
@@ -39,7 +40,7 @@ func TestParseCreateTable(t *testing.T) {
|
||||
err := ioutil.WriteFile(sqlFile, []byte(user), 0o777)
|
||||
assert.Nil(t, err)
|
||||
|
||||
tables, err := Parse(sqlFile, "go_zero")
|
||||
tables, err := Parse(sqlFile, "go_zero", false)
|
||||
assert.Equal(t, 1, len(tables))
|
||||
table := tables[0]
|
||||
assert.Nil(t, err)
|
||||
|
||||
Reference in New Issue
Block a user