@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
@@ -19,7 +21,10 @@ var (
|
||||
)
|
||||
|
||||
func TestFromDDl(t *testing.T) {
|
||||
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
||||
err := gen.Clean()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
||||
assert.Equal(t, errNotMatched, err)
|
||||
|
||||
// case dir is not exists
|
||||
|
||||
@@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
||||
var list []string
|
||||
camelTableName := table.Name.ToCamel()
|
||||
for _, key := range table.UniqueCacheKey {
|
||||
var inJoin, paramJoin, argJoin Join
|
||||
for _, f := range key.Fields {
|
||||
param := stringx.From(f.Name.ToCamel()).Untitle()
|
||||
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
||||
paramJoin = append(paramJoin, param)
|
||||
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
||||
}
|
||||
var in string
|
||||
if len(inJoin) > 0 {
|
||||
in = inJoin.With(", ").Source()
|
||||
}
|
||||
|
||||
var paramJoinString string
|
||||
if len(paramJoin) > 0 {
|
||||
paramJoinString = paramJoin.With(",").Source()
|
||||
}
|
||||
|
||||
var originalFieldString string
|
||||
if len(argJoin) > 0 {
|
||||
originalFieldString = argJoin.With(" and ").Source()
|
||||
}
|
||||
in, paramJoinString, originalFieldString := convertJoin(key)
|
||||
|
||||
output, err := t.Execute(map[string]interface{}{
|
||||
"upperStartCamelObject": camelTableName,
|
||||
@@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
||||
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
|
||||
var inJoin, paramJoin, argJoin Join
|
||||
for _, f := range key.Fields {
|
||||
param := stringx.From(f.Name.ToCamel()).Untitle()
|
||||
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
||||
paramJoin = append(paramJoin, param)
|
||||
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
||||
}
|
||||
if len(inJoin) > 0 {
|
||||
in = inJoin.With(", ").Source()
|
||||
}
|
||||
|
||||
if len(paramJoin) > 0 {
|
||||
paramJoinString = paramJoin.With(",").Source()
|
||||
}
|
||||
|
||||
if len(argJoin) > 0 {
|
||||
originalFieldString = argJoin.With(" and ").Source()
|
||||
}
|
||||
return in, paramJoinString, originalFieldString
|
||||
}
|
||||
|
||||
@@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
|
||||
}
|
||||
}
|
||||
|
||||
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
|
||||
return &Table{
|
||||
Name: stringx.From(tableName),
|
||||
PrimaryKey: primaryKey,
|
||||
UniqueIndex: uniqueIndex,
|
||||
NormalIndex: normalIndex,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
|
||||
log := console.NewColorConsole()
|
||||
uniqueSet := collection.NewSet()
|
||||
for k, i := range uniqueIndex {
|
||||
@@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
|
||||
|
||||
normalIndexSet.Add(joinRet)
|
||||
}
|
||||
|
||||
return &Table{
|
||||
Name: stringx.From(tableName),
|
||||
PrimaryKey: primaryKey,
|
||||
UniqueIndex: uniqueIndex,
|
||||
NormalIndex: normalIndex,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
||||
@@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
||||
}
|
||||
|
||||
fieldM := make(map[string]*Field)
|
||||
for _, each := range table.Columns {
|
||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columnSeqInIndex := 0
|
||||
if each.Index != nil {
|
||||
columnSeqInIndex = each.Index.SeqInIndex
|
||||
}
|
||||
|
||||
field := &Field{
|
||||
Name: stringx.From(each.Name),
|
||||
DataBaseType: each.DataType,
|
||||
DataType: dt,
|
||||
Comment: each.Comment,
|
||||
SeqInIndex: columnSeqInIndex,
|
||||
OrdinalPosition: each.OrdinalPosition,
|
||||
}
|
||||
fieldM[each.Name] = field
|
||||
fieldM, err := getTableFields(table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, each := range fieldM {
|
||||
@@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
|
||||
return &reply, nil
|
||||
}
|
||||
|
||||
func getTableFields(table *model.Table) (map[string]*Field, error) {
|
||||
fieldM := make(map[string]*Field)
|
||||
for _, each := range table.Columns {
|
||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columnSeqInIndex := 0
|
||||
if each.Index != nil {
|
||||
columnSeqInIndex = each.Index.SeqInIndex
|
||||
}
|
||||
|
||||
field := &Field{
|
||||
Name: stringx.From(each.Name),
|
||||
DataBaseType: each.DataType,
|
||||
DataType: dt,
|
||||
Comment: each.Comment,
|
||||
SeqInIndex: columnSeqInIndex,
|
||||
OrdinalPosition: each.OrdinalPosition,
|
||||
}
|
||||
fieldM[each.Name] = field
|
||||
}
|
||||
return fieldM, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user