Fix issues: #725, #740 (#813)

* Fix issues: #725, #740

* Update field sort

Co-authored-by: anqiansong <anqiansong@xiaoheiban.cn>
anqiansong
2021-07-16 22:55:39 +08:00
committed by GitHub
parent db87fd3239
commit 9b2a279948
11 changed files with 258 additions and 311 deletions

View File

@@ -1,11 +0,0 @@
package parser
import (
"errors"
)
var (
errUnsupportDDL = errors.New("unexpected type")
errTableBodyNotFound = errors.New("create table spec not found")
errPrimaryKey = errors.New("unexpected join primary key")
)
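
The three sentinel errors above disappear with this change; the reworked parser.go below reports failures inline via fmt.Errorf, prefixed with the base name of the DDL file being parsed. A minimal sketch of that pattern, assuming nothing beyond the standard library; the helper name wrapParseErr, the message text, and the sample path are illustrative, not part of the diff:

package main

import (
	"fmt"
	"path/filepath"
)

// wrapParseErr mirrors the inline-error style used by the new parser.go:
// rather than returning a shared sentinel such as errPrimaryKey, the message
// carries the base name of the offending DDL file for context.
func wrapParseErr(filename, msg string) error {
	prefix := filepath.Base(filename) // e.g. "user.sql"
	return fmt.Errorf("%s: %s", prefix, msg)
}

func main() {
	fmt.Println(wrapParseErr("/tmp/user.sql", "unexpected join primary key"))
	// prints: user.sql: unexpected join primary key
}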

View File

@@ -2,6 +2,7 @@ package parser
import (
"fmt"
"path/filepath"
"sort"
"strings"
@@ -11,7 +12,7 @@ import (
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
"github.com/tal-tech/go-zero/tools/goctl/util/console"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
"github.com/xwb1989/sqlparser"
"github.com/zeromicro/ddl-parser/parser"
)
const timeImport = "time.Time"
@@ -22,7 +23,6 @@ type (
Name stringx.String
PrimaryKey Primary
UniqueIndex map[string][]*Field
NormalIndex map[string][]*Field
Fields []*Field
}
@@ -35,7 +35,6 @@ type (
// Field describes a table field
Field struct {
Name stringx.String
DataBaseType string
DataType string
Comment string
SeqInIndex int
@@ -47,73 +46,115 @@ type (
)
// Parse parses ddl into golang structure
func Parse(ddl string) (*Table, error) {
stmt, err := sqlparser.ParseStrictDDL(ddl)
func Parse(filename string) ([]*Table, error) {
p := parser.NewParser()
tables, err := p.From(filename)
if err != nil {
return nil, err
}
ddlStmt, ok := stmt.(*sqlparser.DDL)
if !ok {
return nil, errUnsupportDDL
indexNameGen := func(column ...string) string {
return strings.Join(column, "_")
}
action := ddlStmt.Action
if action != sqlparser.CreateStr {
return nil, fmt.Errorf("expected [CREATE] action,but found: %s", action)
}
prefix := filepath.Base(filename)
var list []*Table
for _, e := range tables {
columns := e.Columns
tableName := ddlStmt.NewName.Name.String()
tableSpec := ddlStmt.TableSpec
if tableSpec == nil {
return nil, errTableBodyNotFound
}
var (
primaryColumnSet = collection.NewSet()
columns := tableSpec.Columns
indexes := tableSpec.Indexes
primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes)
if err != nil {
return nil, err
}
primaryColumn string
uniqueKeyMap = make(map[string][]string)
normalKeyMap = make(map[string][]string)
)
primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
if err != nil {
return nil, err
}
for _, column := range columns {
if column.Constraint != nil {
if column.Constraint.Primary {
primaryColumnSet.AddStr(column.Name)
}
var fields []*Field
for _, e := range fieldM {
fields = append(fields, e)
}
if column.Constraint.Unique {
indexName := indexNameGen(column.Name, "unique")
uniqueKeyMap[indexName] = []string{column.Name}
}
var (
uniqueIndex = make(map[string][]*Field)
normalIndex = make(map[string][]*Field)
)
for indexName, each := range uniqueKeyMap {
for _, columnName := range each {
uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
if column.Constraint.Key {
indexName := indexNameGen(column.Name, "idx")
uniqueKeyMap[indexName] = []string{column.Name}
}
}
}
}
for indexName, each := range normalKeyMap {
for _, columnName := range each {
normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
for _, e := range e.Constraints {
if len(e.ColumnPrimaryKey) > 1 {
return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
}
if len(e.ColumnPrimaryKey) == 1 {
primaryColumn = e.ColumnPrimaryKey[0]
primaryColumnSet.AddStr(e.ColumnPrimaryKey[0])
}
if len(e.ColumnUniqueKey) > 0 {
list := append([]string(nil), e.ColumnUniqueKey...)
list = append(list, "unique")
indexName := indexNameGen(list...)
uniqueKeyMap[indexName] = e.ColumnUniqueKey
}
}
if primaryColumnSet.Count() > 1 {
return nil, fmt.Errorf("%s: unexpected join primary key", prefix)
}
primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
if err != nil {
return nil, err
}
var fields []*Field
// sort
for _, c := range columns {
field, ok := fieldM[c.Name]
if ok {
fields = append(fields, field)
}
}
var (
uniqueIndex = make(map[string][]*Field)
normalIndex = make(map[string][]*Field)
)
for indexName, each := range uniqueKeyMap {
for _, columnName := range each {
uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
}
}
for indexName, each := range normalKeyMap {
for _, columnName := range each {
normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
}
}
checkDuplicateUniqueIndex(uniqueIndex, e.Name)
list = append(list, &Table{
Name: stringx.From(e.Name),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
Fields: fields,
})
}
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
return &Table{
Name: stringx.From(tableName),
PrimaryKey: primaryKey,
UniqueIndex: uniqueIndex,
NormalIndex: normalIndex,
Fields: fields,
}, nil
return list, nil
}
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string) {
log := console.NewColorConsole()
uniqueSet := collection.NewSet()
for k, i := range uniqueIndex {
@@ -131,26 +172,9 @@ func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string
uniqueSet.AddStr(joinRet)
}
normalIndexSet := collection.NewSet()
for k, i := range normalIndex {
var list []string
for _, e := range i {
list = append(list, e.Name.Source())
}
joinRet := strings.Join(list, ",")
if normalIndexSet.Contains(joinRet) {
log.Warning("table %s: duplicate index %s", tableName, joinRet)
delete(normalIndex, k)
continue
}
normalIndexSet.Add(joinRet)
}
}
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
func convertColumns(columns []*parser.Column, primaryColumn string) (Primary, map[string]*Field, error) {
var (
primaryKey Primary
fieldM = make(map[string]*Field)
@@ -161,35 +185,35 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string)
continue
}
var comment string
if column.Type.Comment != nil {
comment = string(column.Type.Comment.Val)
}
var (
comment string
isDefaultNull bool
)
isDefaultNull := true
if column.Type.NotNull {
isDefaultNull = false
} else {
if column.Type.Default != nil {
if column.Constraint != nil {
comment = column.Constraint.Comment
isDefaultNull = !column.Constraint.HasDefaultValue
if column.Name == primaryColumn && column.Constraint.AutoIncrement {
isDefaultNull = false
}
}
dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
dataType, err := converter.ConvertDataType(column.DataType.Type(), isDefaultNull)
if err != nil {
return Primary{}, nil, err
}
var field Field
field.Name = stringx.From(column.Name.String())
field.DataBaseType = column.Type.Type
field.Name = stringx.From(column.Name)
field.DataType = dataType
field.Comment = util.TrimNewLine(comment)
if field.Name.Source() == primaryColumn {
primaryKey = Primary{
Field: field,
AutoIncrement: bool(column.Type.Autoincrement),
Field: field,
}
if column.Constraint != nil {
primaryKey.AutoIncrement = column.Constraint.AutoIncrement
}
}
@@ -198,60 +222,6 @@ func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string)
return primaryKey, fieldM, nil
}
func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) {
var primaryColumn string
uniqueKeyMap := make(map[string][]string)
normalKeyMap := make(map[string][]string)
isCreateTimeOrUpdateTime := func(name string) bool {
camelColumnName := stringx.From(name).ToCamel()
// by default, createTime|updateTime findOne is not used.
return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime"
}
for _, index := range indexes {
info := index.Info
if info == nil {
continue
}
indexName := index.Info.Name.String()
if info.Primary {
if len(index.Columns) > 1 {
return "", nil, nil, errPrimaryKey
}
columnName := index.Columns[0].Column.String()
if isCreateTimeOrUpdateTime(columnName) {
continue
}
primaryColumn = columnName
continue
} else if info.Unique {
for _, each := range index.Columns {
columnName := each.Column.String()
if isCreateTimeOrUpdateTime(columnName) {
break
}
uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName)
}
} else if info.Spatial {
// do nothing
} else {
for _, each := range index.Columns {
columnName := each.Column.String()
if isCreateTimeOrUpdateTime(columnName) {
break
}
normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String())
}
}
}
return primaryColumn, uniqueKeyMap, normalKeyMap, nil
}
// ContainsTime returns true if contains golang type time.Time
func (t *Table) ContainsTime() bool {
for _, item := range t.Fields {
@@ -265,14 +235,13 @@ func (t *Table) ContainsTime() bool {
// ConvertDataType converts mysql data type into golang data type
func ConvertDataType(table *model.Table) (*Table, error) {
isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
primaryDataType, err := converter.ConvertStringDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
if err != nil {
return nil, err
}
var reply Table
reply.UniqueIndex = map[string][]*Field{}
reply.NormalIndex = map[string][]*Field{}
reply.Name = stringx.From(table.Table)
seqInIndex := 0
if table.PrimaryKey.Index != nil {
@@ -282,7 +251,6 @@ func ConvertDataType(table *model.Table) (*Table, error) {
reply.PrimaryKey = Primary{
Field: Field{
Name: stringx.From(table.PrimaryKey.Name),
DataBaseType: table.PrimaryKey.DataType,
DataType: primaryDataType,
Comment: table.PrimaryKey.Comment,
SeqInIndex: seqInIndex,
@@ -338,29 +306,6 @@ func ConvertDataType(table *model.Table) (*Table, error) {
reply.UniqueIndex[indexName] = list
}
normalIndexSet := collection.NewSet()
for indexName, each := range table.NormalIndex {
var list []*Field
var normalJoin []string
for _, c := range each {
list = append(list, fieldM[c.Name])
normalJoin = append(normalJoin, c.Name)
}
normalKey := strings.Join(normalJoin, ",")
if normalIndexSet.Contains(normalKey) {
log.Warning("table %s: duplicate index, %s", table.Table, normalKey)
continue
}
normalIndexSet.AddStr(normalKey)
sort.Slice(list, func(i, j int) bool {
return list[i].SeqInIndex < list[j].SeqInIndex
})
reply.NormalIndex[indexName] = list
}
return &reply, nil
}
@@ -368,7 +313,7 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
fieldM := make(map[string]*Field)
for _, each := range table.Columns {
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
dt, err := converter.ConvertStringDataType(each.DataType, isDefaultNull)
if err != nil {
return nil, err
}
@@ -379,7 +324,6 @@ func getTableFields(table *model.Table) (map[string]*Field, error) {
field := &Field{
Name: stringx.From(each.Name),
DataBaseType: each.DataType,
DataType: dt,
Comment: each.Comment,
SeqInIndex: columnSeqInIndex,
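
For reference, a minimal sketch of calling the reworked Parse, which now takes the path of a .sql file and returns every CREATE TABLE statement it contains, instead of parsing a single DDL string. The import path is assumed from the sibling tools/goctl/model/sql imports in this diff, and the temp-file setup, DDL text, and printed fields are illustrative only:

package main

import (
	"fmt"
	"io/ioutil"
	"log"
	"path/filepath"

	// Assumed import path, inferred from the other goctl imports above.
	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
)

func main() {
	// The new Parse reads DDL from a file rather than from an in-memory string.
	dir, err := ioutil.TempDir("", "ddl")
	if err != nil {
		log.Fatal(err)
	}
	sqlFile := filepath.Join(dir, "user.sql")
	ddl := "CREATE TABLE `user` (`id` bigint NOT NULL AUTO_INCREMENT, `name` varchar(255) NOT NULL, PRIMARY KEY (`id`));"
	if err := ioutil.WriteFile(sqlFile, []byte(ddl), 0666); err != nil {
		log.Fatal(err)
	}

	// Old signature: Parse(ddl string) (*Table, error)
	// New signature: Parse(filename string) ([]*Table, error)
	tables, err := parser.Parse(sqlFile)
	if err != nil {
		log.Fatal(err)
	}
	for _, t := range tables {
		fmt.Println(t.Name.Source(), t.PrimaryKey.Name.Source())
	}
}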

View File

@@ -1,88 +1,47 @@
package parser
import (
"sort"
"io/ioutil"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
"github.com/tal-tech/go-zero/tools/goctl/model/sql/util"
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)
func TestParsePlainText(t *testing.T) {
_, err := Parse("plain text")
sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte("plain text"), 0777)
assert.Nil(t, err)
_, err = Parse(sqlFile)
assert.NotNil(t, err)
}
func TestParseSelect(t *testing.T) {
_, err := Parse("select * from user")
assert.Equal(t, errUnsupportDDL, err)
sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte("select * from user"), 0777)
assert.Nil(t, err)
tables, err := Parse(sqlFile)
assert.Nil(t, err)
assert.Equal(t, 0, len(tables))
}
func TestParseCreateTable(t *testing.T) {
table, err := Parse("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;")
sqlFile := filepath.Join(t.TempDir(), "tmp.sql")
err := ioutil.WriteFile(sqlFile, []byte("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL comment '手\\t机 号',\n `class` bigint NOT NULL comment '班级',\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL comment '姓\n 名',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP comment '创建\\r时间',\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"), 0777)
assert.Nil(t, err)
tables, err := Parse(sqlFile)
assert.Equal(t, 1, len(tables))
table := tables[0]
assert.Nil(t, err)
assert.Equal(t, "test_user", table.Name.Source())
assert.Equal(t, "id", table.PrimaryKey.Name.Source())
assert.Equal(t, true, table.ContainsTime())
assert.Equal(t, true, func() bool {
mobileUniqueIndex, ok := table.UniqueIndex["mobile_unique"]
if !ok {
return false
}
classNameUniqueIndex, ok := table.UniqueIndex["class_name_unique"]
if !ok {
return false
}
equal := func(f1, f2 []*Field) bool {
sort.Slice(f1, func(i, j int) bool {
return f1[i].Name.Source() < f1[j].Name.Source()
})
sort.Slice(f2, func(i, j int) bool {
return f2[i].Name.Source() < f2[j].Name.Source()
})
if len(f1) != len(f2) {
return false
}
for index, f := range f1 {
if f2[index].Name.Source() != f.Name.Source() {
return false
}
}
return true
}
if !equal(mobileUniqueIndex, []*Field{
{
Name: stringx.From("mobile"),
DataBaseType: "varchar",
DataType: "string",
SeqInIndex: 1,
},
}) {
return false
}
return equal(classNameUniqueIndex, []*Field{
{
Name: stringx.From("class"),
DataBaseType: "bigint",
DataType: "int64",
SeqInIndex: 1,
},
{
Name: stringx.From("name"),
DataBaseType: "varchar",
DataType: "string",
SeqInIndex: 2,
},
})
}())
assert.Equal(t, 2, len(table.UniqueIndex))
assert.True(t, func() bool {
for _, e := range table.Fields {
if e.Comment != util.TrimNewLine(e.Comment) {