@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
|
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
|
||||||
defineStruct, ok := ty.(spec.DefineStruct)
|
defineStruct, done, err := c.checkStruct(ty)
|
||||||
if !ok {
|
if done {
|
||||||
return errors.New("unsupported type %s" + ty.Name())
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range c.requestTypes {
|
|
||||||
if item.Name() == defineStruct.Name() {
|
|
||||||
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
modelFile := util.Title(ty.Name()) + ".java"
|
modelFile := util.Title(ty.Name()) + ".java"
|
||||||
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
|
||||||
|
defineStruct, ok := ty.(spec.DefineStruct)
|
||||||
|
if !ok {
|
||||||
|
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range c.requestTypes {
|
||||||
|
if item.Name() == defineStruct.Name() {
|
||||||
|
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
||||||
|
return spec.DefineStruct{}, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defineStruct, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
|
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
if err := c.writeType(&builder, defineStruct); err != nil {
|
if err := c.writeType(&builder, defineStruct); err != nil {
|
||||||
|
|||||||
@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch valueType {
|
s := getBaseType(valueType)
|
||||||
case "int":
|
if len(s) == 0 {
|
||||||
return "Integer[]", nil
|
return s, errors.New("unsupported primitive type " + tp.Name())
|
||||||
case "long":
|
|
||||||
return "Long[]", nil
|
|
||||||
case "float":
|
|
||||||
return "Float[]", nil
|
|
||||||
case "double":
|
|
||||||
return "Double[]", nil
|
|
||||||
case "boolean":
|
|
||||||
return "Boolean[]", nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
|
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
|
||||||
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
|||||||
return "", errors.New("unsupported primitive type " + tp.Name())
|
return "", errors.New("unsupported primitive type " + tp.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getBaseType(valueType string) string {
|
||||||
|
switch valueType {
|
||||||
|
case "int":
|
||||||
|
return "Integer[]"
|
||||||
|
case "long":
|
||||||
|
return "Long[]"
|
||||||
|
case "float":
|
||||||
|
return "Float[]"
|
||||||
|
case "double":
|
||||||
|
return "Double[]"
|
||||||
|
case "boolean":
|
||||||
|
return "Boolean[]"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func primitiveType(tp string) (string, bool) {
|
func primitiveType(tp string) (string, bool) {
|
||||||
switch tp {
|
switch tp {
|
||||||
case "string":
|
case "string":
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||||
@@ -19,7 +21,10 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestFromDDl(t *testing.T) {
|
func TestFromDDl(t *testing.T) {
|
||||||
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
err := gen.Clean()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
||||||
assert.Equal(t, errNotMatched, err)
|
assert.Equal(t, errNotMatched, err)
|
||||||
|
|
||||||
// case dir is not exists
|
// case dir is not exists
|
||||||
|
|||||||
@@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
|||||||
var list []string
|
var list []string
|
||||||
camelTableName := table.Name.ToCamel()
|
camelTableName := table.Name.ToCamel()
|
||||||
for _, key := range table.UniqueCacheKey {
|
for _, key := range table.UniqueCacheKey {
|
||||||
var inJoin, paramJoin, argJoin Join
|
in, paramJoinString, originalFieldString := convertJoin(key)
|
||||||
for _, f := range key.Fields {
|
|
||||||
param := stringx.From(f.Name.ToCamel()).Untitle()
|
|
||||||
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
|
||||||
paramJoin = append(paramJoin, param)
|
|
||||||
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
|
||||||
}
|
|
||||||
var in string
|
|
||||||
if len(inJoin) > 0 {
|
|
||||||
in = inJoin.With(", ").Source()
|
|
||||||
}
|
|
||||||
|
|
||||||
var paramJoinString string
|
|
||||||
if len(paramJoin) > 0 {
|
|
||||||
paramJoinString = paramJoin.With(",").Source()
|
|
||||||
}
|
|
||||||
|
|
||||||
var originalFieldString string
|
|
||||||
if len(argJoin) > 0 {
|
|
||||||
originalFieldString = argJoin.With(" and ").Source()
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := t.Execute(map[string]interface{}{
|
output, err := t.Execute(map[string]interface{}{
|
||||||
"upperStartCamelObject": camelTableName,
|
"upperStartCamelObject": camelTableName,
|
||||||
@@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
|||||||
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
|
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
|
||||||
|
var inJoin, paramJoin, argJoin Join
|
||||||
|
for _, f := range key.Fields {
|
||||||
|
param := stringx.From(f.Name.ToCamel()).Untitle()
|
||||||
|
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
||||||
|
paramJoin = append(paramJoin, param)
|
||||||
|
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
||||||
|
}
|
||||||
|
if len(inJoin) > 0 {
|
||||||
|
in = inJoin.With(", ").Source()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(paramJoin) > 0 {
|
||||||
|
paramJoinString = paramJoin.With(",").Source()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(argJoin) > 0 {
|
||||||
|
originalFieldString = argJoin.With(" and ").Source()
|
||||||
|
}
|
||||||
|
return in, paramJoinString, originalFieldString
|
||||||
|
}
|
||||||
|
|||||||
@@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
|
||||||
|
return &Table{
|
||||||
|
Name: stringx.From(tableName),
|
||||||
|
PrimaryKey: primaryKey,
|
||||||
|
UniqueIndex: uniqueIndex,
|
||||||
|
NormalIndex: normalIndex,
|
||||||
|
Fields: fields,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
|
||||||
log := console.NewColorConsole()
|
log := console.NewColorConsole()
|
||||||
uniqueSet := collection.NewSet()
|
uniqueSet := collection.NewSet()
|
||||||
for k, i := range uniqueIndex {
|
for k, i := range uniqueIndex {
|
||||||
@@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
|
|||||||
|
|
||||||
normalIndexSet.Add(joinRet)
|
normalIndexSet.Add(joinRet)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Table{
|
|
||||||
Name: stringx.From(tableName),
|
|
||||||
PrimaryKey: primaryKey,
|
|
||||||
UniqueIndex: uniqueIndex,
|
|
||||||
NormalIndex: normalIndex,
|
|
||||||
Fields: fields,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
||||||
@@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
|||||||
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
||||||
}
|
}
|
||||||
|
|
||||||
fieldM := make(map[string]*Field)
|
fieldM, err := getTableFields(table)
|
||||||
for _, each := range table.Columns {
|
if err != nil {
|
||||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
return nil, err
|
||||||
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
columnSeqInIndex := 0
|
|
||||||
if each.Index != nil {
|
|
||||||
columnSeqInIndex = each.Index.SeqInIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
field := &Field{
|
|
||||||
Name: stringx.From(each.Name),
|
|
||||||
DataBaseType: each.DataType,
|
|
||||||
DataType: dt,
|
|
||||||
Comment: each.Comment,
|
|
||||||
SeqInIndex: columnSeqInIndex,
|
|
||||||
OrdinalPosition: each.OrdinalPosition,
|
|
||||||
}
|
|
||||||
fieldM[each.Name] = field
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, each := range fieldM {
|
for _, each := range fieldM {
|
||||||
@@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
|||||||
|
|
||||||
return &reply, nil
|
return &reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getTableFields(table *model.Table) (map[string]*Field, error) {
|
||||||
|
fieldM := make(map[string]*Field)
|
||||||
|
for _, each := range table.Columns {
|
||||||
|
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||||
|
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
columnSeqInIndex := 0
|
||||||
|
if each.Index != nil {
|
||||||
|
columnSeqInIndex = each.Index.SeqInIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
field := &Field{
|
||||||
|
Name: stringx.From(each.Name),
|
||||||
|
DataBaseType: each.DataType,
|
||||||
|
DataType: dt,
|
||||||
|
Comment: each.Comment,
|
||||||
|
SeqInIndex: columnSeqInIndex,
|
||||||
|
OrdinalPosition: each.OrdinalPosition,
|
||||||
|
}
|
||||||
|
fieldM[each.Name] = field
|
||||||
|
}
|
||||||
|
return fieldM, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user