feature 1.1.5 (#411)
@@ -1,6 +1,7 @@
 package new
 
 import (
+"errors"
 "os"
 "path/filepath"
 "strings"
@@ -35,6 +36,10 @@ func CreateServiceCommand(c *cli.Context) error {
 dirName = "greet"
 }
 
+if strings.Contains(dirName, "-") {
+return errors.New("api new command service name not support strikethrough, because this will used by function name")
+}
+
 abs, err := filepath.Abs(dirName)
 if err != nil {
 return err
@@ -7,7 +7,6 @@ import (
 
 "github.com/antlr/antlr4/runtime/Go/antlr"
 "github.com/tal-tech/go-zero/tools/goctl/api/parser/g4/gen/api"
-"github.com/tal-tech/go-zero/tools/goctl/api/util"
 "github.com/tal-tech/go-zero/tools/goctl/util/console"
 )
 
@@ -323,18 +322,3 @@ func (v *ApiVisitor) getHiddenTokensToRight(t TokenStream, channel int) []Expr {
 
 return list
 }
-
-func (v *ApiVisitor) exportCheck(expr Expr) {
-if expr == nil || !expr.IsNotNil() {
-return
-}
-
-if api.IsBasicType(expr.Text()) {
-return
-}
-
-if util.UnExport(expr.Text()) {
-v.log.Warning("%s line %d:%d unexported declaration '%s', use %s instead", expr.Prefix(), expr.Line(),
-expr.Column(), expr.Text(), strings.Title(expr.Text()))
-}
-}
@@ -219,7 +219,6 @@ func (v *ApiVisitor) VisitBody(ctx *api.BodyContext) interface{} {
 if api.IsGolangKeyWord(idRxpr.Text()) {
 v.panic(idRxpr, fmt.Sprintf("expecting 'ID', but found golang keyword '%s'", idRxpr.Text()))
 }
-v.exportCheck(idRxpr)
 
 return &Body{
 Lp: v.newExprWithToken(ctx.GetLp()),
@@ -250,7 +249,6 @@ func (v *ApiVisitor) VisitReplybody(ctx *api.ReplybodyContext) interface{} {
 default:
 v.panic(dt.Expr(), fmt.Sprintf("unsupport %s", dt.Expr().Text()))
 }
-v.log.Warning("%s %d:%d deprecated array type near '%s'", v.prefix, dataType.ArrayExpr.Line(), dataType.ArrayExpr.Column(), dataType.ArrayExpr.Text())
 case *Literal:
 lit := dataType.Literal.Text()
 if api.IsGolangKeyWord(dataType.Literal.Text()) {
@@ -153,7 +153,6 @@ func (v *ApiVisitor) VisitTypeBlockBody(ctx *api.TypeBlockBodyContext) interface
 func (v *ApiVisitor) VisitTypeStruct(ctx *api.TypeStructContext) interface{} {
 var st TypeStruct
 st.Name = v.newExprWithToken(ctx.GetStructName())
-v.exportCheck(st.Name)
 
 if util.UnExport(ctx.GetStructName().GetText()) {
 
@@ -189,7 +188,6 @@ func (v *ApiVisitor) VisitTypeStruct(ctx *api.TypeStructContext) interface{} {
 func (v *ApiVisitor) VisitTypeBlockStruct(ctx *api.TypeBlockStructContext) interface{} {
 var st TypeStruct
 st.Name = v.newExprWithToken(ctx.GetStructName())
-v.exportCheck(st.Name)
 
 if ctx.GetStructToken() != nil {
 structExpr := v.newExprWithToken(ctx.GetStructToken())
@@ -261,7 +259,6 @@ func (v *ApiVisitor) VisitField(ctx *api.FieldContext) interface{} {
 func (v *ApiVisitor) VisitNormalField(ctx *api.NormalFieldContext) interface{} {
 var field TypeField
 field.Name = v.newExprWithToken(ctx.GetFieldName())
-v.exportCheck(field.Name)
 
 iDataTypeContext := ctx.DataType()
 if iDataTypeContext != nil {
@@ -289,7 +286,6 @@ func (v *ApiVisitor) VisitAnonymousFiled(ctx *api.AnonymousFiledContext) interfa
 field.IsAnonymous = true
 if ctx.GetStar() != nil {
 nameExpr := v.newExprWithTerminalNode(ctx.ID())
-v.exportCheck(nameExpr)
 field.DataType = &Pointer{
 PointerExpr: v.newExprWithText(ctx.GetStar().GetText()+ctx.ID().GetText(), start.GetLine(), start.GetColumn(), start.GetStart(), stop.GetStop()),
 Star: v.newExprWithToken(ctx.GetStar()),
@@ -297,7 +293,6 @@ func (v *ApiVisitor) VisitAnonymousFiled(ctx *api.AnonymousFiledContext) interfa
 }
 } else {
 nameExpr := v.newExprWithTerminalNode(ctx.ID())
-v.exportCheck(nameExpr)
 field.DataType = &Literal{Literal: nameExpr}
 }
 field.DocExpr = v.getDoc(ctx)
@@ -309,7 +304,6 @@ func (v *ApiVisitor) VisitAnonymousFiled(ctx *api.AnonymousFiledContext) interfa
 func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
 if ctx.ID() != nil {
 idExpr := v.newExprWithTerminalNode(ctx.ID())
-v.exportCheck(idExpr)
 return &Literal{Literal: idExpr}
 }
 if ctx.MapType() != nil {
@@ -337,7 +331,6 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
 // VisitPointerType implements from api.BaseApiParserVisitor
 func (v *ApiVisitor) VisitPointerType(ctx *api.PointerTypeContext) interface{} {
 nameExpr := v.newExprWithTerminalNode(ctx.ID())
-v.exportCheck(nameExpr)
 return &Pointer{
 PointerExpr: v.newExprWithText(ctx.GetText(), ctx.GetStar().GetLine(), ctx.GetStar().GetColumn(), ctx.GetStar().GetStart(), ctx.ID().GetSymbol().GetStop()),
 Star: v.newExprWithToken(ctx.GetStar()),
@@ -121,7 +121,7 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo
 return err
 }
 
-matchTables := make(map[string][]*model.Column)
+matchTables := make(map[string]*model.Table)
 for _, item := range tables {
 match, err := filepath.Match(pattern, item)
 if err != nil {
@@ -131,11 +131,18 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo
 if !match {
 continue
 }
-columns, err := im.FindByTableName(dsn.DBName, item)
+columnData, err := im.FindColumns(dsn.DBName, item)
 if err != nil {
 return err
 }
-matchTables[item] = columns
+table, err := columnData.Convert()
+if err != nil {
+return err
+}
+
+matchTables[item] = table
 }
 
 if len(matchTables) == 0 {
@@ -147,5 +154,5 @@ func fromDataSource(url, pattern, dir string, cfg *config.Config, cache, idea bo
 return err
 }
 
-return generator.StartFromInformationSchema(dsn.DBName, matchTables, cache)
+return generator.StartFromInformationSchema(matchTables, cache)
 }
@@ -11,8 +11,8 @@ fromDDLWithoutCache:
 
 
 # generate model with cache from data source
-user=root
-password=password
+user=ugozero
+password=
 datasource=127.0.0.1:3306
 database=gozero
 
@@ -17,10 +17,12 @@ CREATE TABLE `user` (
 
 CREATE TABLE `student` (
 `id` bigint NOT NULL AUTO_INCREMENT,
+`class` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
 `name` varchar(255) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
 `age` tinyint DEFAULT NULL,
 `score` float(10,0) DEFAULT NULL,
 `create_time` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
 `update_time` timestamp NULL DEFAULT NULL,
-PRIMARY KEY (`id`) USING BTREE
+PRIMARY KEY (`id`) USING BTREE,
+UNIQUE KEY `class_name_index` (`class`,`name`)
 ) ENGINE=InnoDB AUTO_INCREMENT=4 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;
@@ -12,13 +12,11 @@ import (
 func genDelete(table Table, withCache bool) (string, string, error) {
 keySet := collection.NewSet()
 keyVariableSet := collection.NewSet()
-for fieldName, key := range table.CacheKey {
-if fieldName == table.PrimaryKey.Name.Source() {
-keySet.AddStr(key.KeyExpression)
-} else {
-keySet.AddStr(key.DataKeyExpression)
-}
-keyVariableSet.AddStr(key.Variable)
+keySet.AddStr(table.PrimaryCacheKey.KeyExpression)
+keyVariableSet.AddStr(table.PrimaryCacheKey.KeyLeft)
+for _, key := range table.UniqueCacheKey {
+keySet.AddStr(key.DataKeyExpression)
+keyVariableSet.AddStr(key.KeyLeft)
 }
 
 camel := table.Name.ToCamel()
@@ -32,7 +30,7 @@ func genDelete(table Table, withCache bool) (string, string, error) {
 Execute(map[string]interface{}{
 "upperStartCamelObject": camel,
 "withCache": withCache,
-"containsIndexCache": table.ContainsUniqueKey,
+"containsIndexCache": table.ContainsUniqueCacheKey,
 "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
 "dataType": table.PrimaryKey.DataType,
 "keys": strings.Join(keySet.KeysStr(), "\n"),
@@ -8,7 +8,7 @@ import (
 "github.com/tal-tech/go-zero/tools/goctl/util"
 )
 
-func genFields(fields []parser.Field) (string, error) {
+func genFields(fields []*parser.Field) (string, error) {
 var list []string
 
 for _, field := range fields {
@@ -23,7 +23,7 @@ func genFields(fields []parser.Field) (string, error) {
 return strings.Join(list, "\n"), nil
 }
 
-func genField(field parser.Field) (string, error) {
+func genField(field *parser.Field) (string, error) {
 tag, err := genTag(field.Name.Source())
 if err != nil {
 return "", err
@@ -22,8 +22,8 @@ func genFindOne(table Table, withCache bool) (string, string, error) {
 "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
 "lowerStartCamelPrimaryKey": stringx.From(table.PrimaryKey.Name.ToCamel()).Untitle(),
 "dataType": table.PrimaryKey.DataType,
-"cacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].KeyExpression,
-"cacheKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
+"cacheKey": table.PrimaryCacheKey.KeyExpression,
+"cacheKeyVariable": table.PrimaryCacheKey.KeyLeft,
 })
 if err != nil {
 return "", "", err
@@ -24,22 +24,40 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
 t := util.With("findOneByField").Parse(text)
 var list []string
 camelTableName := table.Name.ToCamel()
-for _, field := range table.Fields {
-if field.IsPrimaryKey || !field.IsUniqueKey {
-continue
+for _, key := range table.UniqueCacheKey {
+var inJoin, paramJoin, argJoin Join
+for _, f := range key.Fields {
+param := stringx.From(f.Name.ToCamel()).Untitle()
+inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
+paramJoin = append(paramJoin, param)
+argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
 }
-camelFieldName := field.Name.ToCamel()
+var in string
+if len(inJoin) > 0 {
+in = inJoin.With(", ").Source()
+}
+
+var paramJoinString string
+if len(paramJoin) > 0 {
+paramJoinString = paramJoin.With(",").Source()
+}
+
+var originalFieldString string
+if len(argJoin) > 0 {
+originalFieldString = argJoin.With(" and ").Source()
+}
+
 output, err := t.Execute(map[string]interface{}{
 "upperStartCamelObject": camelTableName,
-"upperField": camelFieldName,
-"in": fmt.Sprintf("%s %s", stringx.From(camelFieldName).Untitle(), field.DataType),
+"upperField": key.FieldNameJoin.Camel().With("").Source(),
+"in": in,
 "withCache": withCache,
-"cacheKey": table.CacheKey[field.Name.Source()].KeyExpression,
-"cacheKeyVariable": table.CacheKey[field.Name.Source()].Variable,
+"cacheKey": key.KeyExpression,
+"cacheKeyVariable": key.KeyLeft,
 "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
-"lowerStartCamelField": stringx.From(camelFieldName).Untitle(),
+"lowerStartCamelField": paramJoinString,
 "upperStartCamelPrimaryKey": table.PrimaryKey.Name.ToCamel(),
-"originalField": wrapWithRawString(field.Name.Source()),
+"originalField": originalFieldString,
 })
 if err != nil {
 return nil, err
@@ -55,15 +73,22 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
 
 t = util.With("findOneByFieldMethod").Parse(text)
 var listMethod []string
-for _, field := range table.Fields {
-if field.IsPrimaryKey || !field.IsUniqueKey {
-continue
+for _, key := range table.UniqueCacheKey {
+var inJoin, paramJoin Join
+for _, f := range key.Fields {
+param := stringx.From(f.Name.ToCamel()).Untitle()
+inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
+paramJoin = append(paramJoin, param)
+}
+
+var in string
+if len(inJoin) > 0 {
+in = inJoin.With(", ").Source()
 }
-camelFieldName := field.Name.ToCamel()
 output, err := t.Execute(map[string]interface{}{
 "upperStartCamelObject": camelTableName,
-"upperField": camelFieldName,
-"in": fmt.Sprintf("%s %s", stringx.From(camelFieldName).Untitle(), field.DataType),
+"upperField": key.FieldNameJoin.Camel().With("").Source(),
+"in": in,
 })
 if err != nil {
 return nil, err
@@ -80,7 +105,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
 
 out, err := util.With("findOneByFieldExtraMethod").Parse(text).Execute(map[string]interface{}{
 "upperStartCamelObject": camelTableName,
-"primaryKeyLeft": table.CacheKey[table.PrimaryKey.Name.Source()].Left,
+"primaryKeyLeft": table.PrimaryCacheKey.VarLeft,
 "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
 "originalPrimaryField": wrapWithRawString(table.PrimaryKey.Name.Source()),
 })
@@ -99,10 +99,10 @@ func (g *defaultGenerator) StartFromDDL(source string, withCache bool) error {
 return g.createFile(modelList)
 }
 
-func (g *defaultGenerator) StartFromInformationSchema(db string, columns map[string][]*model.Column, withCache bool) error {
+func (g *defaultGenerator) StartFromInformationSchema(tables map[string]*model.Table, withCache bool) error {
 m := make(map[string]string)
-for tableName, column := range columns {
-table, err := parser.ConvertColumn(db, tableName, column)
+for _, each := range tables {
+table, err := parser.ConvertDataType(each)
 if err != nil {
 return err
 }
@@ -182,10 +182,12 @@ func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string
 if err != nil {
 return nil, err
 }
+
 code, err := g.genModel(*table, withCache)
 if err != nil {
 return nil, err
 }
+
 m[table.Name.Source()] = code
 }
 
@@ -195,8 +197,9 @@ func (g *defaultGenerator) genFromDDL(source string, withCache bool) (map[string
 // Table defines mysql table
 type Table struct {
 parser.Table
-CacheKey map[string]Key
-ContainsUniqueKey bool
+PrimaryCacheKey Key
+UniqueCacheKey []Key
+ContainsUniqueCacheKey bool
 }
 
 func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, error) {
@@ -204,10 +207,7 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
 return "", fmt.Errorf("table %s: missing primary key", in.Name.Source())
 }
 
-m, err := genCacheKeys(in)
-if err != nil {
-return "", err
-}
+primaryKey, uniqueKey := genCacheKeys(in)
 
 importsCode, err := genImports(withCache, in.ContainsTime())
 if err != nil {
@@ -216,15 +216,9 @@ func (g *defaultGenerator) genModel(in parser.Table, withCache bool) (string, er
 
 var table Table
 table.Table = in
-table.CacheKey = m
-var containsUniqueCache = false
-for _, item := range table.Fields {
-if item.IsUniqueKey {
-containsUniqueCache = true
-break
-}
-}
-table.ContainsUniqueKey = containsUniqueCache
+table.PrimaryCacheKey = primaryKey
+table.UniqueCacheKey = uniqueKey
+table.ContainsUniqueCacheKey = len(uniqueKey) > 0
 
 varsCode, err := genVars(table, withCache)
 if err != nil {
@@ -16,18 +16,15 @@ import (
 )
 
 var (
-source = "CREATE TABLE `test_user_info` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `nanosecond` bigint NOT NULL DEFAULT '0',\n `data` varchar(255) DEFAULT '',\n `content` json DEFAULT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `nanosecond_unique` (`nanosecond`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci;"
+source = "CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;"
 )
 
 func TestCacheModel(t *testing.T) {
 logx.Disable()
 _ = Clean()
-dir, _ := filepath.Abs("./testmodel")
+dir := filepath.Join(t.TempDir(), "./testmodel")
 cacheDir := filepath.Join(dir, "cache")
 noCacheDir := filepath.Join(dir, "nocache")
-defer func() {
-_ = os.RemoveAll(dir)
-}()
 g, err := NewDefaultGenerator(cacheDir, &config.Config{
 NamingFormat: "GoZero",
 })
@@ -36,7 +33,7 @@ func TestCacheModel(t *testing.T) {
 err = g.StartFromDDL(source, true)
 assert.Nil(t, err)
 assert.True(t, func() bool {
-_, err := os.Stat(filepath.Join(cacheDir, "TestUserInfoModel.go"))
+_, err := os.Stat(filepath.Join(cacheDir, "TestUserModel.go"))
 return err == nil
 }())
 g, err = NewDefaultGenerator(noCacheDir, &config.Config{
@@ -47,7 +44,7 @@ func TestCacheModel(t *testing.T) {
 err = g.StartFromDDL(source, false)
 assert.Nil(t, err)
 assert.True(t, func() bool {
-_, err := os.Stat(filepath.Join(noCacheDir, "testuserinfomodel.go"))
+_, err := os.Stat(filepath.Join(noCacheDir, "testusermodel.go"))
 return err == nil
 }())
 }
@@ -69,7 +66,7 @@ func TestNamingModel(t *testing.T) {
 err = g.StartFromDDL(source, true)
 assert.Nil(t, err)
 assert.True(t, func() bool {
-_, err := os.Stat(filepath.Join(camelDir, "TestUserInfoModel.go"))
+_, err := os.Stat(filepath.Join(camelDir, "TestUserModel.go"))
 return err == nil
 }())
 g, err = NewDefaultGenerator(snakeDir, &config.Config{
@@ -80,7 +77,7 @@ func TestNamingModel(t *testing.T) {
 err = g.StartFromDDL(source, true)
 assert.Nil(t, err)
 assert.True(t, func() bool {
-_, err := os.Stat(filepath.Join(snakeDir, "test_user_info_model.go"))
+_, err := os.Stat(filepath.Join(snakeDir, "test_user_model.go"))
 return err == nil
 }())
 }
@@ -12,12 +12,9 @@ import (
 func genInsert(table Table, withCache bool) (string, string, error) {
 keySet := collection.NewSet()
 keyVariableSet := collection.NewSet()
-for fieldName, key := range table.CacheKey {
-if fieldName == table.PrimaryKey.Name.Source() {
-continue
-}
+for _, key := range table.UniqueCacheKey {
 keySet.AddStr(key.DataKeyExpression)
-keyVariableSet.AddStr(key.Variable)
+keyVariableSet.AddStr(key.KeyLeft)
 }
 
 expressions := make([]string, 0)
@@ -27,12 +24,17 @@ func genInsert(table Table, withCache bool) (string, string, error) {
 if camel == "CreateTime" || camel == "UpdateTime" {
 continue
 }
-if field.IsPrimaryKey && table.PrimaryKey.AutoIncrement {
-continue
+if field.Name.Source() == table.PrimaryKey.Name.Source() {
+if table.PrimaryKey.AutoIncrement {
+continue
+}
 }
+
 expressions = append(expressions, "?")
 expressionValues = append(expressionValues, "data."+camel)
 }
+
 camel := table.Name.ToCamel()
 text, err := util.LoadTemplate(category, insertTemplateFile, template.Insert)
 if err != nil {
@@ -43,7 +45,7 @@ func genInsert(table Table, withCache bool) (string, string, error) {
 Parse(text).
 Execute(map[string]interface{}{
 "withCache": withCache,
-"containsIndexCache": table.ContainsUniqueKey,
+"containsIndexCache": table.ContainsUniqueCacheKey,
 "upperStartCamelObject": camel,
 "lowerStartCamelObject": stringx.From(camel).Untitle(),
 "expression": strings.Join(expressions, ", "),
@@ -61,11 +63,9 @@ func genInsert(table Table, withCache bool) (string, string, error) {
 return "", "", err
 }
 
-insertMethodOutput, err := util.With("insertMethod").
-Parse(text).
-Execute(map[string]interface{}{
-"upperStartCamelObject": camel,
-})
+insertMethodOutput, err := util.With("insertMethod").Parse(text).Execute(map[string]interface{}{
+"upperStartCamelObject": camel,
+})
 if err != nil {
 return "", "", err
 }
@@ -2,61 +2,163 @@ package gen
 
 import (
 "fmt"
+"sort"
 "strings"
 
 "github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
 "github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
-// Key defines cache key variable for generating code
+// Key describes cache key
 type Key struct {
-// VarExpression likes cacheUserIdPrefix = "cache#User#id#"
+// VarLeft describes the varible of cache key expression which likes cacheUserIdPrefix
+VarLeft string
+// VarRight describes the value of cache key expression which likes "cache#user#id#"
+VarRight string
+// VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#"
 VarExpression string
-// Left likes cacheUserIdPrefix
-Left string
-// Right likes cache#user#id#
-Right string
-// Variable likes userIdKey
-Variable string
-// KeyExpression likes userIdKey: = fmt.Sprintf("cache#user#id#%v", userId)
+// KeyLeft describes the varible of key definiation expression which likes userKey
+KeyLeft string
+// KeyRight describes the value of key definiation expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user)
+KeyRight string
+// DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User)
+DataKeyRight string
+// KeyExpression describes key expression likes userKey := fmt.Sprintf("%s%v", cacheUserPrefix, user)
 KeyExpression string
-// DataKeyExpression likes userIdKey: = fmt.Sprintf("cache#user#id#%v", data.userId)
+// DataKeyExpression describes data key expression likes userKey := fmt.Sprintf("%s%v", cacheUserPrefix, data.User)
 DataKeyExpression string
-// RespKeyExpression likes userIdKey: = fmt.Sprintf("cache#user#id#%v", resp.userId)
-RespKeyExpression string
+// FieldNameJoin describes the filed slice of table
+FieldNameJoin Join
+// Fields describes the fields of table
+Fields []*parser.Field
 }
 
-// key-数据库原始字段名，value-缓存key相关数据
-func genCacheKeys(table parser.Table) (map[string]Key, error) {
-fields := table.Fields
-m := make(map[string]Key)
-camelTableName := table.Name.ToCamel()
-lowerStartCamelTableName := stringx.From(camelTableName).Untitle()
-for _, field := range fields {
-if field.IsUniqueKey || field.IsPrimaryKey {
-camelFieldName := field.Name.ToCamel()
-lowerStartCamelFieldName := stringx.From(camelFieldName).Untitle()
-left := fmt.Sprintf("cache%s%sPrefix", camelTableName, camelFieldName)
-if strings.ToLower(camelFieldName) == strings.ToLower(camelTableName) {
-left = fmt.Sprintf("cache%sPrefix", camelTableName)
-}
-right := fmt.Sprintf("cache#%s#%s#", camelTableName, lowerStartCamelFieldName)
-variable := fmt.Sprintf("%s%sKey", lowerStartCamelTableName, camelFieldName)
-if strings.ToLower(lowerStartCamelTableName) == strings.ToLower(camelFieldName) {
-variable = fmt.Sprintf("%sKey", lowerStartCamelTableName)
-}
+// Join describes an alias of string slice
+type Join []string
 
-m[field.Name.Source()] = Key{
-VarExpression: fmt.Sprintf(`%s = "%s"`, left, right),
-Left: left,
-Right: right,
-Variable: variable,
-KeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("%s%s", %s,%s)`, variable, "%s", "%v", left, lowerStartCamelFieldName),
-DataKeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("%s%s",%s, data.%s)`, variable, "%s", "%v", left, camelFieldName),
-RespKeyExpression: fmt.Sprintf(`%s := fmt.Sprintf("%s%s", %s,resp.%s)`, variable, "%s", "%v", left, camelFieldName),
-}
-}
+func genCacheKeys(table parser.Table) (Key, []Key) {
+var primaryKey Key
+var uniqueKey []Key
+primaryKey = genCacheKey(table.Name, []*parser.Field{&table.PrimaryKey.Field})
+for _, each := range table.UniqueIndex {
+uniqueKey = append(uniqueKey, genCacheKey(table.Name, each))
+}
+sort.Slice(uniqueKey, func(i, j int) bool {
+return uniqueKey[i].VarLeft < uniqueKey[j].VarLeft
+})
+
+return primaryKey, uniqueKey
+}
+
+func genCacheKey(table stringx.String, in []*parser.Field) Key {
+var (
+varLeftJoin, varRightJon, fieldNameJoin Join
+varLeft, varRight, varExpression string
+
+keyLeftJoin, keyRightJoin, keyRightArgJoin, dataRightJoin Join
+keyLeft, keyRight, dataKeyRight, keyExpression, dataKeyExpression string
+)
+
+varLeftJoin = append(varLeftJoin, "cache", table.Source())
+varRightJon = append(varRightJon, "cache", table.Source())
+keyLeftJoin = append(keyLeftJoin, table.Source())
+
+for _, each := range in {
+varLeftJoin = append(varLeftJoin, each.Name.Source())
+varRightJon = append(varRightJon, each.Name.Source())
+keyLeftJoin = append(keyLeftJoin, each.Name.Source())
+keyRightJoin = append(keyRightJoin, stringx.From(each.Name.ToCamel()).Untitle())
+keyRightArgJoin = append(keyRightArgJoin, "%v")
+dataRightJoin = append(dataRightJoin, "data."+each.Name.ToCamel())
+fieldNameJoin = append(fieldNameJoin, each.Name.Source())
+}
+varLeftJoin = append(varLeftJoin, "prefix")
+keyLeftJoin = append(keyLeftJoin, "key")
+
+varLeft = varLeftJoin.Camel().With("").Untitle()
+varRight = fmt.Sprintf(`"%s"`, varRightJon.Camel().Untitle().With("#").Source()+"#")
+varExpression = fmt.Sprintf(`%s = %s`, varLeft, varRight)
+
+keyLeft = keyLeftJoin.Camel().With("").Untitle()
+keyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With("").Source(), varLeft, keyRightJoin.With(", ").Source())
+dataKeyRight = fmt.Sprintf(`fmt.Sprintf("%s%s", %s, %s)`, "%s", keyRightArgJoin.With("").Source(), varLeft, dataRightJoin.With(", ").Source())
+keyExpression = fmt.Sprintf("%s := %s", keyLeft, keyRight)
+dataKeyExpression = fmt.Sprintf("%s := %s", keyLeft, dataKeyRight)
+
+return Key{
+VarLeft: varLeft,
+VarRight: varRight,
+VarExpression: varExpression,
+KeyLeft: keyLeft,
+KeyRight: keyRight,
+DataKeyRight: dataKeyRight,
+KeyExpression: keyExpression,
+DataKeyExpression: dataKeyExpression,
+Fields: in,
+FieldNameJoin: fieldNameJoin,
+}
+}
+
+// Title convert items into Title and return
+func (j Join) Title() Join {
+var join Join
+for _, each := range j {
+join = append(join, stringx.From(each).Title())
 }
 
-return m, nil
+return join
+}
+
+// Camel convert items into Camel and return
+func (j Join) Camel() Join {
+var join Join
+for _, each := range j {
+join = append(join, stringx.From(each).ToCamel())
+}
+return join
+}
+
+// Snake convert items into Snake and return
+func (j Join) Snake() Join {
+var join Join
+for _, each := range j {
+join = append(join, stringx.From(each).ToSnake())
+}
+
+return join
+}
+
+// Snake convert items into Untitle and return
+func (j Join) Untitle() Join {
+var join Join
+for _, each := range j {
+join = append(join, stringx.From(each).Untitle())
+}
+
+return join
+}
+
+// Upper convert items into Upper and return
+func (j Join) Upper() Join {
+var join Join
+for _, each := range j {
+join = append(join, stringx.From(each).Upper())
+}
+
+return join
+}
+
+// Lower convert items into Lower and return
+func (j Join) Lower() Join {
+var join Join
+for _, each := range j {
+join = append(join, stringx.From(each).Lower())
+}
+
+return join
+}
+
+// With convert items into With and return
+func (j Join) With(sep string) stringx.String {
+return stringx.From(strings.Join(j, sep))
 }
@@ -1,7 +1,7 @@
 package gen
 
 import (
-"fmt"
+"sort"
 "testing"
 
 "github.com/stretchr/testify/assert"
@@ -10,62 +10,156 @@ import (
 )
 
 func TestGenCacheKeys(t *testing.T) {
-m, err := genCacheKeys(parser.Table{
+primaryField := &parser.Field{
+Name: stringx.From("id"),
+DataBaseType: "bigint",
+DataType: "int64",
+Comment: "自增id",
+SeqInIndex: 1,
+}
+mobileField := &parser.Field{
+Name: stringx.From("mobile"),
+DataBaseType: "varchar",
+DataType: "string",
+Comment: "手机号",
+SeqInIndex: 1,
+}
+classField := &parser.Field{
+Name: stringx.From("class"),
+DataBaseType: "varchar",
+DataType: "string",
+Comment: "班级",
+SeqInIndex: 1,
+}
+nameField := &parser.Field{
+Name: stringx.From("name"),
+DataBaseType: "varchar",
+DataType: "string",
+Comment: "姓名",
+SeqInIndex: 2,
+}
+primariCacheKey, uniqueCacheKey := genCacheKeys(parser.Table{
 Name: stringx.From("user"),
 PrimaryKey: parser.Primary{
-Field: parser.Field{
-Name: stringx.From("id"),
-DataBaseType: "bigint",
-DataType: "int64",
-IsPrimaryKey: true,
-IsUniqueKey: false,
-Comment: "自增id",
-},
+Field: *primaryField,
 AutoIncrement: true,
 },
-Fields: []parser.Field{
-{
-Name: stringx.From("mobile"),
-DataBaseType: "varchar",
-DataType: "string",
-IsPrimaryKey: false,
-IsUniqueKey: true,
-Comment: "手机号",
+UniqueIndex: map[string][]*parser.Field{
+"mobile_unique": []*parser.Field{
+mobileField,
 },
-{
-Name: stringx.From("name"),
-DataBaseType: "varchar",
-DataType: "string",
-IsPrimaryKey: false,
-IsUniqueKey: true,
-Comment: "姓名",
+"class_name_unique": []*parser.Field{
+classField,
+nameField,
 },
+},
+NormalIndex: nil,
+Fields: []*parser.Field{
+primaryField,
+mobileField,
+classField,
+nameField,
 {
 Name: stringx.From("createTime"),
 DataBaseType: "timestamp",
 DataType: "time.Time",
-IsPrimaryKey: false,
-IsUniqueKey: false,
 Comment: "创建时间",
 },
 {
 Name: stringx.From("updateTime"),
 DataBaseType: "timestamp",
 DataType: "time.Time",
-IsPrimaryKey: false,
-IsUniqueKey: false,
 Comment: "更新时间",
 },
 },
 })
-assert.Nil(t, err)
 
-for fieldName, key := range m {
-name := stringx.From(fieldName)
-assert.Equal(t, fmt.Sprintf(`cacheUser%sPrefix = "cache#User#%s#"`, name.ToCamel(), name.Untitle()), key.VarExpression)
-assert.Equal(t, fmt.Sprintf(`cacheUser%sPrefix`, name.ToCamel()), key.Left)
-assert.Equal(t, fmt.Sprintf(`cache#User#%s#`, name.Untitle()), key.Right)
-assert.Equal(t, fmt.Sprintf(`user%sKey`, name.ToCamel()), key.Variable)
-assert.Equal(t, `user`+name.ToCamel()+`Key := fmt.Sprintf("%s%v", cacheUser`+name.ToCamel()+`Prefix,`+name.Untitle()+`)`, key.KeyExpression)
-}
+t.Run("primaryCacheKey", func(t *testing.T) {
+assert.Equal(t, true, func() bool {
+return cacheKeyEqual(primariCacheKey, Key{
+VarLeft: "cacheUserIdPrefix",
+VarRight: `"cache#user#id#"`,
+VarExpression: `cacheUserIdPrefix = "cache#user#id#"`,
+KeyLeft: "userIdKey",
+KeyRight: `fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`,
+DataKeyRight: `fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`,
+KeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, id)`,
+DataKeyExpression: `userIdKey := fmt.Sprintf("%s%v", cacheUserIdPrefix, data.Id)`,
+FieldNameJoin: []string{"id"},
+})
+}())
+})
+
+t.Run("uniqueCacheKey", func(t *testing.T) {
+assert.Equal(t, true, func() bool {
+expected := []Key{
+{
+VarLeft: "cacheUserClassNamePrefix",
+VarRight: `"cache#user#class#name#"`,
+VarExpression: `cacheUserClassNamePrefix = "cache#user#class#name#"`,
+KeyLeft: "userClassNameKey",
+KeyRight: `fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, class, name)`,
+DataKeyRight: `fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, data.Class, data.Name)`,
+KeyExpression: `userClassNameKey := fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, class, name)`,
+DataKeyExpression: `userClassNameKey := fmt.Sprintf("%s%v%v", cacheUserClassNamePrefix, data.Class, data.Name)`,
+FieldNameJoin: []string{"class", "name"},
+},
+{
+VarLeft: "cacheUserMobilePrefix",
+VarRight: `"cache#user#mobile#"`,
+VarExpression: `cacheUserMobilePrefix = "cache#user#mobile#"`,
+KeyLeft: "userMobileKey",
+KeyRight: `fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`,
+DataKeyRight: `fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`,
+KeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, mobile)`,
+DataKeyExpression: `userMobileKey := fmt.Sprintf("%s%v", cacheUserMobilePrefix, data.Mobile)`,
+FieldNameJoin: []string{"mobile"},
+},
+}
+sort.Slice(uniqueCacheKey, func(i, j int) bool {
+return uniqueCacheKey[i].VarLeft < uniqueCacheKey[j].VarLeft
+})
+
+if len(expected) != len(uniqueCacheKey) {
+return false
+}
+
+for index, each := range uniqueCacheKey {
+expecting := expected[index]
+if !cacheKeyEqual(expecting, each) {
+return false
+}
+}
+
+return true
+}())
+})
+
+}
+
+func cacheKeyEqual(k1 Key, k2 Key) bool {
+k1Join := k1.FieldNameJoin
+k2Join := k2.FieldNameJoin
+sort.Strings(k1Join)
+sort.Strings(k2Join)
+if len(k1Join) != len(k2Join) {
+return false
+}
+
+for index, each := range k1Join {
+k2Item := k2Join[index]
+if each != k2Item {
+return false
+}
+}
+
+return k1.VarLeft == k2.VarLeft &&
+k1.VarRight == k2.VarRight &&
+k1.VarExpression == k2.VarExpression &&
+k1.KeyLeft == k2.KeyLeft &&
+k1.KeyRight == k2.KeyRight &&
+k1.DataKeyRight == k2.DataKeyRight &&
+k1.DataKeyExpression == k2.DataKeyExpression &&
+k1.KeyExpression == k2.KeyExpression
+
 }
@@ -16,7 +16,7 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
 continue
 }
 
-if field.IsPrimaryKey {
+if field.Name.Source() == table.PrimaryKey.Name.Source() {
 continue
 }
 
@@ -35,8 +35,8 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
 Execute(map[string]interface{}{
 "withCache": withCache,
 "upperStartCamelObject": camelTableName,
-"primaryCacheKey": table.CacheKey[table.PrimaryKey.Name.Source()].DataKeyExpression,
-"primaryKeyVariable": table.CacheKey[table.PrimaryKey.Name.Source()].Variable,
+"primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression,
+"primaryKeyVariable": table.PrimaryCacheKey.KeyLeft,
 "lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
 "originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
 "expressionValues": strings.Join(expressionValues, ", "),
@@ -10,26 +10,26 @@ import (
 
 func genVars(table Table, withCache bool) (string, error) {
 keys := make([]string, 0)
-for _, v := range table.CacheKey {
+keys = append(keys, table.PrimaryCacheKey.VarExpression)
+for _, v := range table.UniqueCacheKey {
 keys = append(keys, v.VarExpression)
 }
 
 camel := table.Name.ToCamel()
 text, err := util.LoadTemplate(category, varTemplateFile, template.Vars)
 if err != nil {
 return "", err
 }
 
-output, err := util.With("var").
-Parse(text).
-GoFmt(true).
-Execute(map[string]interface{}{
-"lowerStartCamelObject": stringx.From(camel).Untitle(),
-"upperStartCamelObject": camel,
-"cacheKeys": strings.Join(keys, "\n"),
-"autoIncrement": table.PrimaryKey.AutoIncrement,
-"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
-"withCache": withCache,
-})
+output, err := util.With("var").Parse(text).
+GoFmt(true).Execute(map[string]interface{}{
+"lowerStartCamelObject": stringx.From(camel).Untitle(),
+"upperStartCamelObject": camel,
+"cacheKeys": strings.Join(keys, "\n"),
+"autoIncrement": table.PrimaryKey.AutoIncrement,
+"originalPrimaryKey": wrapWithRawString(table.PrimaryKey.Name.Source()),
+"withCache": withCache,
+})
 if err != nil {
 return "", err
 }
@@ -1,6 +1,13 @@
 package model
 
-import "github.com/tal-tech/go-zero/core/stores/sqlx"
+import (
+"fmt"
+"sort"
+
+"github.com/tal-tech/go-zero/core/stores/sqlx"
+)
+
+const indexPri = "PRIMARY"
 
 type (
 // InformationSchemaModel defines information schema model
@@ -10,13 +17,53 @@ type (
 
 // Column defines column in table
 Column struct {
-Name string `db:"COLUMN_NAME"`
-DataType string `db:"DATA_TYPE"`
-Key string `db:"COLUMN_KEY"`
-Extra string `db:"EXTRA"`
-Comment string `db:"COLUMN_COMMENT"`
-ColumnDefault interface{} `db:"COLUMN_DEFAULT"`
-IsNullAble string `db:"IS_NULLABLE"`
+*DbColumn
+Index *DbIndex
+}
+
+// DbColumn defines column info of columns
+DbColumn struct {
+Name string `db:"COLUMN_NAME"`
+DataType string `db:"DATA_TYPE"`
+Extra string `db:"EXTRA"`
+Comment string `db:"COLUMN_COMMENT"`
+ColumnDefault interface{} `db:"COLUMN_DEFAULT"`
+IsNullAble string `db:"IS_NULLABLE"`
+OrdinalPosition int `db:"ORDINAL_POSITION"`
+}
+
+// DbIndex defines index of columns in information_schema.statistic
+DbIndex struct {
+IndexName string `db:"INDEX_NAME"`
+NonUnique int `db:"NON_UNIQUE"`
+SeqInIndex int `db:"SEQ_IN_INDEX"`
+}
+
+// ColumnData describes the columns of table
+ColumnData struct {
+Db string
+Table string
+Columns []*Column
+}
+
+// Table describes mysql table which contains database name, table name, columns, keys
+Table struct {
+Db string
+Table string
+Columns []*Column
+// Primary key not included
+UniqueIndex map[string][]*Column
+PrimaryKey *Column
+NormalIndex map[string][]*Column
+}
+
+// IndexType describes an alias of string
+IndexType string
+
+// Index describes a column index
+Index struct {
+IndexType IndexType
+Columns []*Column
 }
 )
 
@@ -37,10 +84,102 @@ func (m *InformationSchemaModel) GetAllTables(database string) ([]string, error)
 return tables, nil
 }
 
-// FindByTableName finds out the target table by name
-func (m *InformationSchemaModel) FindByTableName(db, table string) ([]*Column, error) {
-querySQL := `select COLUMN_NAME,COLUMN_DEFAULT,IS_NULLABLE,DATA_TYPE,COLUMN_KEY,EXTRA,COLUMN_COMMENT from COLUMNS where TABLE_SCHEMA = ? and TABLE_NAME = ?`
-var reply []*Column
-err := m.conn.QueryRows(&reply, querySQL, db, table)
-return reply, err
+// FindColumns return columns in specified database and table
+func (m *InformationSchemaModel) FindColumns(db, table string) (*ColumnData, error) {
+querySql := `SELECT c.COLUMN_NAME,c.DATA_TYPE,EXTRA,c.COLUMN_COMMENT,c.COLUMN_DEFAULT,c.IS_NULLABLE,c.ORDINAL_POSITION from COLUMNS c WHERE c.TABLE_SCHEMA = ? and c.TABLE_NAME = ? `
+var reply []*DbColumn
+err := m.conn.QueryRowsPartial(&reply, querySql, db, table)
+if err != nil {
+return nil, err
+}
+
+var list []*Column
+for _, item := range reply {
+index, err := m.FindIndex(db, table, item.Name)
+if err != nil {
+if err != sqlx.ErrNotFound {
+return nil, err
+}
+continue
+}
+
+if len(index) > 0 {
+for _, i := range index {
+list = append(list, &Column{
+DbColumn: item,
+Index: i,
+})
+}
+} else {
+list = append(list, &Column{
+DbColumn: item,
+})
+}
+}
+
+sort.Slice(list, func(i, j int) bool {
+return list[i].OrdinalPosition < list[j].OrdinalPosition
+})
+
+var columnData ColumnData
+columnData.Db = db
+columnData.Table = table
+columnData.Columns = list
+return &columnData, nil
+}
+
+func (m *InformationSchemaModel) FindIndex(db, table, column string) ([]*DbIndex, error) {
+querySql := `SELECT s.INDEX_NAME,s.NON_UNIQUE,s.SEQ_IN_INDEX from STATISTICS s WHERE s.TABLE_SCHEMA = ? and s.TABLE_NAME = ? and s.COLUMN_NAME = ?`
+var reply []*DbIndex
+err := m.conn.QueryRowsPartial(&reply, querySql, db, table, column)
+if err != nil {
+return nil, err
+}
+
+return reply, nil
+}
+
+// Convert converts column data into Table
+func (c *ColumnData) Convert() (*Table, error) {
+var table Table
+table.Table = c.Table
+table.Db = c.Db
+table.Columns = c.Columns
+table.UniqueIndex = map[string][]*Column{}
+table.NormalIndex = map[string][]*Column{}
+
+m := make(map[string][]*Column)
+for _, each := range c.Columns {
+if each.Index != nil {
+m[each.Index.IndexName] = append(m[each.Index.IndexName], each)
+}
+}
+
+primaryColumns := m[indexPri]
+if len(primaryColumns) == 0 {
+return nil, fmt.Errorf("db:%s, table:%s, missing primary key", c.Db, c.Table)
+}
+
+if len(primaryColumns) > 1 {
+return nil, fmt.Errorf("db:%s, table:%s, joint primary key is not supported", c.Db, c.Table)
+}
+
+table.PrimaryKey = primaryColumns[0]
+for indexName, columns := range m {
+if indexName == indexPri {
+continue
+}
+
+for _, one := range columns {
+if one.Index != nil {
+if one.Index.NonUnique == 0 {
+table.UniqueIndex[indexName] = columns
+} else {
+table.NormalIndex[indexName] = columns
+}
+}
+}
+}
+
+return &table, nil
 }
@@ -2,30 +2,27 @@ package parser
 
 import (
 	"fmt"
+	"sort"
 	"strings"
 
+	"github.com/tal-tech/go-zero/core/collection"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/converter"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
+	"github.com/tal-tech/go-zero/tools/goctl/util/console"
 	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 	"github.com/xwb1989/sqlparser"
 )
 
-const (
-	_ = iota
-	primary
-	unique
-	normal
-	spatial
-)
-
 const timeImport = "time.Time"
 
 type (
 	// Table describes a mysql table
 	Table struct {
 		Name stringx.String
 		PrimaryKey Primary
-		Fields []Field
+		UniqueIndex map[string][]*Field
+		NormalIndex map[string][]*Field
+		Fields []*Field
 	}
 
 	// Primary describes a primary key
@@ -36,12 +33,12 @@ type (
 
 	// Field describes a table field
 	Field struct {
 		Name stringx.String
 		DataBaseType string
 		DataType string
-		IsPrimaryKey bool
-		IsUniqueKey bool
-		Comment string
+		Comment string
+		SeqInIndex int
+		OrdinalPosition int
 	}
 
 	// KeyType types alias of int
@@ -73,34 +70,58 @@ func Parse(ddl string) (*Table, error) {
 
 	columns := tableSpec.Columns
 	indexes := tableSpec.Indexes
-	keyMap, err := getIndexKeyType(indexes)
+	primaryColumn, uniqueKeyMap, normalKeyMap, err := convertIndexes(indexes)
 	if err != nil {
 		return nil, err
 	}
 
-	fields, primaryKey, err := convertFileds(columns, keyMap)
+	fields, primaryKey, fieldM, err := convertColumns(columns, primaryColumn)
 	if err != nil {
 		return nil, err
 	}
 
+	var (
+		uniqueIndex = make(map[string][]*Field)
+		normalIndex = make(map[string][]*Field)
+	)
+	for indexName, each := range uniqueKeyMap {
+		for _, columnName := range each {
+			uniqueIndex[indexName] = append(uniqueIndex[indexName], fieldM[columnName])
+		}
+	}
+
+	for indexName, each := range normalKeyMap {
+		for _, columnName := range each {
+			normalIndex[indexName] = append(normalIndex[indexName], fieldM[columnName])
+		}
+	}
+
 	return &Table{
 		Name: stringx.From(tableName),
 		PrimaryKey: primaryKey,
-		Fields: fields,
+		UniqueIndex: uniqueIndex,
+		NormalIndex: normalIndex,
+		Fields: fields,
 	}, nil
 }
 
-func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyType) ([]Field, Primary, error) {
-	var fields []Field
-	var primaryKey Primary
+func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) ([]*Field, Primary, map[string]*Field, error) {
+	var (
+		fields []*Field
+		primaryKey Primary
+		fieldM = make(map[string]*Field)
+	)
 
 	for _, column := range columns {
 		if column == nil {
 			continue
 		}
 
 		var comment string
 		if column.Type.Comment != nil {
			comment = string(column.Type.Comment.Val)
 		}
 
 		var isDefaultNull = true
 		if column.Type.NotNull {
 			isDefaultNull = false
@@ -111,9 +132,10 @@ func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyT
 				isDefaultNull = false
 			}
 		}
 
 		dataType, err := converter.ConvertDataType(column.Type.Type, isDefaultNull)
 		if err != nil {
-			return nil, primaryKey, err
+			return nil, Primary{}, nil, err
 		}
 
 		var field Field
@@ -121,60 +143,75 @@ func convertFileds(columns []*sqlparser.ColumnDefinition, keyMap map[string]KeyT
 		field.DataBaseType = column.Type.Type
 		field.DataType = dataType
 		field.Comment = comment
-		key, ok := keyMap[column.Name.String()]
-		if ok {
-			field.IsPrimaryKey = key == primary
-			field.IsUniqueKey = key == unique
-			if field.IsPrimaryKey {
-				primaryKey.Field = field
-				if column.Type.Autoincrement {
-					primaryKey.AutoIncrement = true
-				}
+
+		if field.Name.Source() == primaryColumn {
+			primaryKey = Primary{
+				Field: field,
+				AutoIncrement: bool(column.Type.Autoincrement),
 			}
 		}
-		fields = append(fields, field)
+
+		fields = append(fields, &field)
+		fieldM[field.Name.Source()] = &field
 	}
-	return fields, primaryKey, nil
+	return fields, primaryKey, fieldM, nil
 }
 
-func getIndexKeyType(indexes []*sqlparser.IndexDefinition) (map[string]KeyType, error) {
-	keyMap := make(map[string]KeyType)
+func convertIndexes(indexes []*sqlparser.IndexDefinition) (string, map[string][]string, map[string][]string, error) {
+	var primaryColumn string
+	uniqueKeyMap := make(map[string][]string)
+	normalKeyMap := make(map[string][]string)
+
+	isCreateTimeOrUpdateTime := func(name string) bool {
+		camelColumnName := stringx.From(name).ToCamel()
+		// by default, createTime|updateTime findOne is not used.
+		return camelColumnName == "CreateTime" || camelColumnName == "UpdateTime"
+	}
+
 	for _, index := range indexes {
 		info := index.Info
 		if info == nil {
 			continue
 		}
+
+		indexName := index.Info.Name.String()
 		if info.Primary {
 			if len(index.Columns) > 1 {
-				return nil, errPrimaryKey
+				return "", nil, nil, errPrimaryKey
+			}
+			columnName := index.Columns[0].Column.String()
+			if isCreateTimeOrUpdateTime(columnName) {
+				continue
 			}
 
-			keyMap[index.Columns[0].Column.String()] = primary
+			primaryColumn = columnName
 			continue
-		}
-		// can optimize
-		if len(index.Columns) > 1 {
-			continue
-		}
-		column := index.Columns[0]
-		columnName := column.Column.String()
-		camelColumnName := stringx.From(columnName).ToCamel()
-		// by default, createTime|updateTime findOne is not used.
-		if camelColumnName == "CreateTime" || camelColumnName == "UpdateTime" {
-			continue
-		}
-		if info.Unique {
-			keyMap[columnName] = unique
+		} else if info.Unique {
+			for _, each := range index.Columns {
+				columnName := each.Column.String()
+				if isCreateTimeOrUpdateTime(columnName) {
+					break
+				}
+
+				uniqueKeyMap[indexName] = append(uniqueKeyMap[indexName], columnName)
+			}
 		} else if info.Spatial {
-			keyMap[columnName] = spatial
+			// do nothing
 		} else {
-			keyMap[columnName] = normal
+			for _, each := range index.Columns {
+				columnName := each.Column.String()
+				if isCreateTimeOrUpdateTime(columnName) {
+					break
+				}
+
+				normalKeyMap[indexName] = append(normalKeyMap[indexName], each.Column.String())
+			}
 		}
 	}
-	return keyMap, nil
+	return primaryColumn, uniqueKeyMap, normalKeyMap, nil
 }
 
-// ContainsTime determines whether the table field contains time.Time
+// ContainsTime returns true if contains golang type time.Time
 func (t *Table) ContainsTime() bool {
 	for _, item := range t.Fields {
 		if item.DataType == timeImport {
@@ -184,63 +221,110 @@ func (t *Table) ContainsTime() bool {
 	return false
 }
 
-// ConvertColumn provides type conversion for mysql clolumn, primary key lookup
-func ConvertColumn(db, table string, in []*model.Column) (*Table, error) {
-	var reply Table
-	reply.Name = stringx.From(table)
-	keyMap := make(map[string][]*model.Column)
-
-	for _, column := range in {
-		keyMap[column.Key] = append(keyMap[column.Key], column)
-	}
-	primaryColumns := keyMap["PRI"]
-	if len(primaryColumns) == 0 {
-		return nil, fmt.Errorf("database:%s, table %s: missing primary key", db, table)
-	}
-
-	if len(primaryColumns) > 1 {
-		return nil, fmt.Errorf("database:%s, table %s: only one primary key expected", db, table)
-	}
-
-	primaryColumn := primaryColumns[0]
-	isDefaultNull := primaryColumn.ColumnDefault == nil && primaryColumn.IsNullAble == "YES"
-	primaryFt, err := converter.ConvertDataType(primaryColumn.DataType, isDefaultNull)
+// ConvertDataType converts mysql data type into golang data type
+func ConvertDataType(table *model.Table) (*Table, error) {
+	isPrimaryDefaultNull := table.PrimaryKey.ColumnDefault == nil && table.PrimaryKey.IsNullAble == "YES"
+	primaryDataType, err := converter.ConvertDataType(table.PrimaryKey.DataType, isPrimaryDefaultNull)
 	if err != nil {
 		return nil, err
 	}
 
-	primaryField := Field{
-		Name: stringx.From(primaryColumn.Name),
-		DataBaseType: primaryColumn.DataType,
-		DataType: primaryFt,
-		IsUniqueKey: true,
-		IsPrimaryKey: true,
-		Comment: primaryColumn.Comment,
+	var reply Table
+	reply.UniqueIndex = map[string][]*Field{}
+	reply.NormalIndex = map[string][]*Field{}
+	reply.Name = stringx.From(table.Table)
+	seqInIndex := 0
+	if table.PrimaryKey.Index != nil {
+		seqInIndex = table.PrimaryKey.Index.SeqInIndex
 	}
-	reply.PrimaryKey = Primary{
-		Field: primaryField,
-		AutoIncrement: strings.Contains(primaryColumn.Extra, "auto_increment"),
-	}
-	for key, columns := range keyMap {
-		for _, item := range columns {
-			isColumnDefaultNull := item.ColumnDefault == nil && item.IsNullAble == "YES"
-			dt, err := converter.ConvertDataType(item.DataType, isColumnDefaultNull)
-			if err != nil {
-				return nil, err
-			}
-
-			f := Field{
-				Name: stringx.From(item.Name),
-				DataBaseType: item.DataType,
-				DataType: dt,
-				IsPrimaryKey: primaryColumn.Name == item.Name,
-				Comment: item.Comment,
-			}
-			if key == "UNI" {
-				f.IsUniqueKey = true
-			}
-			reply.Fields = append(reply.Fields, f)
+
+	reply.PrimaryKey = Primary{
+		Field: Field{
+			Name: stringx.From(table.PrimaryKey.Name),
+			DataBaseType: table.PrimaryKey.DataType,
+			DataType: primaryDataType,
+			Comment: table.PrimaryKey.Comment,
+			SeqInIndex: seqInIndex,
+			OrdinalPosition: table.PrimaryKey.OrdinalPosition,
+		},
+		AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
+	}
+
+	fieldM := make(map[string]*Field)
+	for _, each := range table.Columns {
+		isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
+		dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
+		if err != nil {
+			return nil, err
 		}
+		columnSeqInIndex := 0
+		if each.Index != nil {
+			columnSeqInIndex = each.Index.SeqInIndex
+		}
+
+		field := &Field{
+			Name: stringx.From(each.Name),
+			DataBaseType: each.DataType,
+			DataType: dt,
+			Comment: each.Comment,
+			SeqInIndex: columnSeqInIndex,
+			OrdinalPosition: each.OrdinalPosition,
+		}
+		fieldM[each.Name] = field
+	}
+
+	for _, each := range fieldM {
+		reply.Fields = append(reply.Fields, each)
+	}
+	sort.Slice(reply.Fields, func(i, j int) bool {
+		return reply.Fields[i].OrdinalPosition < reply.Fields[j].OrdinalPosition
+	})
+
+	uniqueIndexSet := collection.NewSet()
+	log := console.NewColorConsole()
+	for indexName, each := range table.UniqueIndex {
+		sort.Slice(each, func(i, j int) bool {
+			if each[i].Index != nil {
+				return each[i].Index.SeqInIndex < each[j].Index.SeqInIndex
+			}
+			return false
+		})
+
+		if len(each) == 1 {
+			one := each[0]
+			if one.Name == table.PrimaryKey.Name {
+				log.Warning("duplicate unique index with primary key, %s", one.Name)
+				continue
+			}
+		}
+
+		var list []*Field
+		var uniqueJoin []string
+		for _, c := range each {
+			list = append(list, fieldM[c.Name])
+			uniqueJoin = append(uniqueJoin, c.Name)
+		}
+
+		uniqueKey := strings.Join(uniqueJoin, ",")
+		if uniqueIndexSet.Contains(uniqueKey) {
+			log.Warning("duplicate unique index, %s", uniqueKey)
+			continue
+		}
+
+		reply.UniqueIndex[indexName] = list
+	}
+
+	for indexName, each := range table.NormalIndex {
+		var list []*Field
+		for _, c := range each {
+			list = append(list, fieldM[c.Name])
+		}
+
+		sort.Slice(list, func(i, j int) bool {
+			return list[i].SeqInIndex < list[j].SeqInIndex
+		})
+
+		reply.NormalIndex[indexName] = list
 	}
 
 	return &reply, nil
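Note: a small, self-contained sketch of the parser API after this change — Parse now also reports UniqueIndex and NormalIndex keyed by index name, with fields ordered by SeqInIndex. The DDL below is adapted from the test in the next file; the print statements are purely illustrative.

package main

import (
	"fmt"

	"github.com/tal-tech/go-zero/tools/goctl/model/sql/parser"
)

func main() {
	ddl := "CREATE TABLE `test_user` (" +
		"`id` bigint NOT NULL AUTO_INCREMENT," +
		"`mobile` varchar(255) NOT NULL," +
		"`class` bigint NOT NULL," +
		"`name` varchar(255) NOT NULL," +
		"PRIMARY KEY (`id`)," +
		"UNIQUE KEY `mobile_unique` (`mobile`)," +
		"UNIQUE KEY `class_name_unique` (`class`,`name`)," +
		"KEY `name_index` (`name`))"

	table, err := parser.Parse(ddl)
	if err != nil {
		panic(err)
	}

	fmt.Println("primary key:", table.PrimaryKey.Name.Source())
	for indexName, fields := range table.UniqueIndex {
		for _, f := range fields {
			// e.g. class_name_unique: class (1), name (2)
			fmt.Println("unique", indexName, f.Name.Source(), f.SeqInIndex)
		}
	}
	for indexName, fields := range table.NormalIndex {
		for _, f := range fields {
			fmt.Println("normal", indexName, f.Name.Source())
		}
	}
}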
@@ -1,10 +1,12 @@
 package parser
 
 import (
+	"sort"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/tal-tech/go-zero/tools/goctl/model/sql/model"
+	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
 )
 
 func TestParsePlainText(t *testing.T) {
@@ -18,68 +20,158 @@ func TestParseSelect(t *testing.T) {
 }
 
 func TestParseCreateTable(t *testing.T) {
-	table, err := Parse("CREATE TABLE `user_snake` (\n `id` bigint(10) NOT NULL AUTO_INCREMENT,\n `name` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户名称',\n `password` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '用户密码',\n `mobile` varchar(255) COLLATE utf8mb4_general_ci NOT NULL DEFAULT '' COMMENT '手机号',\n `gender` char(5) COLLATE utf8mb4_general_ci NOT NULL COMMENT '男|女|未公开',\n `nickname` varchar(255) COLLATE utf8mb4_general_ci DEFAULT '' COMMENT '用户昵称',\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `name_index` (`name`),\n KEY `mobile_index` (`mobile`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;")
+	table, err := Parse("CREATE TABLE `test_user` (\n `id` bigint NOT NULL AUTO_INCREMENT,\n `mobile` varchar(255) COLLATE utf8mb4_bin NOT NULL,\n `class` bigint NOT NULL,\n `name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,\n `create_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP,\n `update_time` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n PRIMARY KEY (`id`),\n UNIQUE KEY `mobile_unique` (`mobile`),\n UNIQUE KEY `class_name_unique` (`class`,`name`),\n KEY `create_index` (`create_time`),\n KEY `name_index` (`name`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;")
 	assert.Nil(t, err)
-	assert.Equal(t, "user_snake", table.Name.Source())
+	assert.Equal(t, "test_user", table.Name.Source())
 	assert.Equal(t, "id", table.PrimaryKey.Name.Source())
 	assert.Equal(t, true, table.ContainsTime())
+	assert.Equal(t, true, func() bool {
+		mobileUniqueIndex, ok := table.UniqueIndex["mobile_unique"]
+		if !ok {
+			return false
+		}
+
+		classNameUniqueIndex, ok := table.UniqueIndex["class_name_unique"]
+		if !ok {
+			return false
+		}
+
+		equal := func(f1, f2 []*Field) bool {
+			sort.Slice(f1, func(i, j int) bool {
+				return f1[i].Name.Source() < f1[j].Name.Source()
+			})
+			sort.Slice(f2, func(i, j int) bool {
+				return f2[i].Name.Source() < f2[j].Name.Source()
+			})
+
+			if len(f2) != len(f2) {
+				return false
+			}
+
+			for index, f := range f1 {
+				if f1[index].Name.Source() != f.Name.Source() {
+					return false
+				}
+			}
+			return true
+		}
+
+		if !equal(mobileUniqueIndex, []*Field{
+			{
+				Name: stringx.From("mobile"),
+				DataBaseType: "varchar",
+				DataType: "string",
+				SeqInIndex: 1,
+			},
+		}) {
+			return false
+		}
+
+		return equal(classNameUniqueIndex, []*Field{
+			{
+				Name: stringx.From("class"),
+				DataBaseType: "bigint",
+				DataType: "int64",
+				SeqInIndex: 1,
+			},
+			{
+				Name: stringx.From("name"),
+				DataBaseType: "varchar",
+				DataType: "string",
+				SeqInIndex: 2,
+			},
+		})
+	}())
 }
 
 func TestConvertColumn(t *testing.T) {
-	_, err := ConvertColumn("user", "user", []*model.Column{
-		{
-			Name: "id",
-			DataType: "bigint",
-			Key: "",
-			Extra: "",
-			Comment: "",
-		},
-	})
-	assert.NotNil(t, err)
-	assert.Contains(t, err.Error(), "missing primary key")
-	_, err = ConvertColumn("user", "user", []*model.Column{
-		{
-			Name: "id",
-			DataType: "bigint",
-			Key: "PRI",
-			Extra: "",
-			Comment: "",
-		},
-		{
-			Name: "mobile",
-			DataType: "varchar",
-			Key: "PRI",
-			Extra: "",
-			Comment: "手机号",
-		},
-	})
-	assert.NotNil(t, err)
-	assert.Contains(t, err.Error(), "only one primary key expected")
-
-	table, err := ConvertColumn("user", "user", []*model.Column{
-		{
-			Name: "id",
-			DataType: "bigint",
-			Key: "PRI",
-			Extra: "auto_increment",
-			Comment: "",
-		},
-		{
-			Name: "mobile",
-			DataType: "varchar",
-			Key: "UNI",
-			Extra: "",
-			Comment: "手机号",
-		},
-	})
-	assert.Nil(t, err)
-	assert.True(t, table.PrimaryKey.AutoIncrement && table.PrimaryKey.IsPrimaryKey)
-	assert.Equal(t, "id", table.PrimaryKey.Name.Source())
-	for _, item := range table.Fields {
-		if item.Name.Source() == "mobile" {
-			assert.True(t, item.IsUniqueKey)
-			break
-		}
-	}
+	t.Run("missingPrimaryKey", func(t *testing.T) {
+		columnData := model.ColumnData{
+			Db: "user",
+			Table: "user",
+			Columns: []*model.Column{
+				{
+					DbColumn: &model.DbColumn{
+						Name: "id",
+						DataType: "bigint",
+					},
+				},
+			},
+		}
+		_, err := columnData.Convert()
+		assert.NotNil(t, err)
+		assert.Contains(t, err.Error(), "missing primary key")
+	})
+
+	t.Run("jointPrimaryKey", func(t *testing.T) {
+		columnData := model.ColumnData{
+			Db: "user",
+			Table: "user",
+			Columns: []*model.Column{
+				{
+					DbColumn: &model.DbColumn{
+						Name: "id",
+						DataType: "bigint",
+					},
+					Index: &model.DbIndex{
+						IndexName: "PRIMARY",
+					},
+				},
+				{
+					DbColumn: &model.DbColumn{
+						Name: "mobile",
+						DataType: "varchar",
+						Comment: "手机号",
+					},
+					Index: &model.DbIndex{
+						IndexName: "PRIMARY",
+					},
+				},
+			},
+		}
+		_, err := columnData.Convert()
+		assert.NotNil(t, err)
+		assert.Contains(t, err.Error(), "joint primary key is not supported")
+	})
+
+	t.Run("normal", func(t *testing.T) {
+		columnData := model.ColumnData{
+			Db: "user",
+			Table: "user",
+			Columns: []*model.Column{
+				{
+					DbColumn: &model.DbColumn{
+						Name: "id",
+						DataType: "bigint",
+						Extra: "auto_increment",
+					},
+					Index: &model.DbIndex{
+						IndexName: "PRIMARY",
+						SeqInIndex: 1,
+					},
+				},
+				{
+					DbColumn: &model.DbColumn{
+						Name: "mobile",
+						DataType: "varchar",
+						Comment: "手机号",
+					},
+					Index: &model.DbIndex{
+						IndexName: "mobile_unique",
+						SeqInIndex: 1,
+					},
+				},
+			},
+		}
+
+		table, err := columnData.Convert()
+		assert.Nil(t, err)
+		assert.True(t, table.PrimaryKey.Index.IndexName == "PRIMARY" && table.PrimaryKey.Name == "id")
+		for _, item := range table.Columns {
+			if item.Name == "mobile" {
+				assert.True(t, item.Index.NonUnique == 0)
+				break
+			}
+		}
+	})
 }
@@ -36,7 +36,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}({{.in}
 	{{if .withCache}}{{.cacheKey}}
 	var resp {{.upperStartCamelObject}}
 	err := m.QueryRowIndex(&resp, {{.cacheKeyVariable}}, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
-		query := fmt.Sprintf("select %s from %s where {{.originalField}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table)
+		query := fmt.Sprintf("select %s from %s where {{.originalField}} limit 1", {{.lowerStartCamelObject}}Rows, m.table)
 		if err := conn.QueryRow(&resp, query, {{.lowerStartCamelField}}); err != nil {
 			return nil, err
 		}
@@ -51,7 +51,7 @@ func (m *default{{.upperStartCamelObject}}Model) FindOneBy{{.upperField}}({{.in}
 		return nil, err
 	}
 }{{else}}var resp {{.upperStartCamelObject}}
-	query := fmt.Sprintf("select %s from %s where {{.originalField}} = ? limit 1", {{.lowerStartCamelObject}}Rows, m.table )
+	query := fmt.Sprintf("select %s from %s where {{.originalField}} limit 1", {{.lowerStartCamelObject}}Rows, m.table )
 	err := m.conn.QueryRow(&resp, query, {{.lowerStartCamelField}})
 	switch err {
 	case nil:
@@ -2,6 +2,7 @@ package model
 
 import (
 	"database/sql"
+	"encoding/json"
 	"fmt"
 	"testing"
 	"time"
@@ -20,11 +21,13 @@ func TestStudentModel(t *testing.T) {
 		testTable = "`student`"
 		testUpdateName = "gozero1"
 		testRowsAffected int64 = 1
-		testInsertID int64 = 1
+		testInsertId int64 = 1
+		class = "一年级1班"
 	)
 
 	var data Student
-	data.ID = testInsertID
+	data.Id = testInsertId
+	data.Class = class
 	data.Name = "gozero"
 	data.Age = sql.NullInt64{
 		Int64: 1,
@@ -42,15 +45,15 @@ func TestStudentModel(t *testing.T) {
 
 	err := mockStudent(func(mock sqlmock.Sqlmock) {
 		mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
-			WithArgs(data.Name, data.Age, data.Score).
-			WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected))
-	}, func(m StudentModel) {
+			WithArgs(data.Class, data.Name, data.Age, data.Score).
+			WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
+	}, func(m StudentModel, redis *redis.Redis) {
 		r, err := m.Insert(data)
 		assert.Nil(t, err)
 
-		lastInsertID, err := r.LastInsertId()
+		lastInsertId, err := r.LastInsertId()
 		assert.Nil(t, err)
-		assert.Equal(t, testInsertID, lastInsertID)
+		assert.Equal(t, testInsertId, lastInsertId)
 
 		rowsAffected, err := r.RowsAffected()
 		assert.Nil(t, err)
@@ -60,41 +63,84 @@ func TestStudentModel(t *testing.T) {
 
 	err = mockStudent(func(mock sqlmock.Sqlmock) {
 		mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
-			WithArgs(testInsertID).
-			WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertID, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
-	}, func(m StudentModel) {
-		result, err := m.FindOne(testInsertID)
+			WithArgs(testInsertId).
+			WillReturnRows(sqlmock.NewRows([]string{"id", "class", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Class, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
+	}, func(m StudentModel, redis *redis.Redis) {
+		result, err := m.FindOne(testInsertId)
 		assert.Nil(t, err)
 		assert.Equal(t, *result, data)
+
+		var resp Student
+		val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId))
+		assert.Nil(t, err)
+		err = json.Unmarshal([]byte(val), &resp)
+		assert.Nil(t, err)
+		assert.Equal(t, resp.Name, data.Name)
 	})
 	assert.Nil(t, err)
 
 	err = mockStudent(func(mock sqlmock.Sqlmock) {
-		mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(testUpdateName, data.Age, data.Score, testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected))
-	}, func(m StudentModel) {
+		mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.Class, testUpdateName, data.Age, data.Score, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
+	}, func(m StudentModel, redis *redis.Redis) {
 		data.Name = testUpdateName
 		err := m.Update(data)
 		assert.Nil(t, err)
+
+		val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId))
+		assert.Nil(t, err)
+		assert.Equal(t, "", val)
+	})
+	assert.Nil(t, err)
+
+	data.Name = testUpdateName
+	err = mockStudent(func(mock sqlmock.Sqlmock) {
+		mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
+			WithArgs(testInsertId).
+			WillReturnRows(sqlmock.NewRows([]string{"id", "class", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Class, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
+	}, func(m StudentModel, redis *redis.Redis) {
+		result, err := m.FindOne(testInsertId)
+		assert.Nil(t, err)
+		assert.Equal(t, *result, data)
+
+		var resp Student
+		val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId))
+		assert.Nil(t, err)
+		err = json.Unmarshal([]byte(val), &resp)
+		assert.Nil(t, err)
+		assert.Equal(t, testUpdateName, data.Name)
 	})
 	assert.Nil(t, err)
 
 	err = mockStudent(func(mock sqlmock.Sqlmock) {
 		mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
-			WithArgs(testInsertID).
-			WillReturnRows(sqlmock.NewRows([]string{"id", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertID, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
-	}, func(m StudentModel) {
-		result, err := m.FindOne(testInsertID)
+			WithArgs(class, testUpdateName).
+			WillReturnRows(sqlmock.NewRows([]string{"id", "class", "name", "age", "score", "create_time", "update_time"}).AddRow(testInsertId, data.Class, data.Name, data.Age, data.Score, testTimeValue, testTimeValue))
+	}, func(m StudentModel, redis *redis.Redis) {
+		result, err := m.FindOneByClassName(class, testUpdateName)
 		assert.Nil(t, err)
 		assert.Equal(t, *result, data)
+
+		val, err := redis.Get(fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, class, testUpdateName))
+		assert.Nil(t, err)
+		assert.Equal(t, "1", val)
 	})
 	assert.Nil(t, err)
 
 	err = mockStudent(func(mock sqlmock.Sqlmock) {
-		mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected))
-	}, func(m StudentModel) {
-		err := m.Delete(testInsertID)
+		mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
+	}, func(m StudentModel, redis *redis.Redis) {
+		err = m.Delete(testInsertId, class, testUpdateName)
 		assert.Nil(t, err)
+
+		val, err := redis.Get(fmt.Sprintf("%s%v", cacheStudentIdPrefix, testInsertId))
+		assert.Nil(t, err)
+		assert.Equal(t, "", val)
+
+		val, err = redis.Get(fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, class, testUpdateName))
+		assert.Nil(t, err)
+		assert.Equal(t, "", val)
 	})
 
 	assert.Nil(t, err)
 }
 
@@ -109,11 +155,11 @@ func TestUserModel(t *testing.T) {
 		testGender = "男"
 		testNickname = "test_nickname"
 		testRowsAffected int64 = 1
-		testInsertID int64 = 1
+		testInsertId int64 = 1
 	)
 
 	var data User
-	data.ID = testInsertID
+	data.ID = testInsertId
 	data.User = testUser
 	data.Name = "gozero"
 	data.Password = testPassword
@@ -126,14 +172,14 @@ func TestUserModel(t *testing.T) {
 	err := mockUser(func(mock sqlmock.Sqlmock) {
 		mock.ExpectExec(fmt.Sprintf("insert into %s", testTable)).
 			WithArgs(data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname).
-			WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected))
+			WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
 	}, func(m UserModel) {
 		r, err := m.Insert(data)
 		assert.Nil(t, err)
 
-		lastInsertID, err := r.LastInsertId()
+		lastInsertId, err := r.LastInsertId()
 		assert.Nil(t, err)
-		assert.Equal(t, testInsertID, lastInsertID)
+		assert.Equal(t, testInsertId, lastInsertId)
 
 		rowsAffected, err := r.RowsAffected()
 		assert.Nil(t, err)
@@ -143,17 +189,17 @@ func TestUserModel(t *testing.T) {
 
 	err = mockUser(func(mock sqlmock.Sqlmock) {
 		mock.ExpectQuery(fmt.Sprintf("select (.+) from %s", testTable)).
-			WithArgs(testInsertID).
-			WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertID, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
+			WithArgs(testInsertId).
+			WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
 	}, func(m UserModel) {
-		result, err := m.FindOne(testInsertID)
+		result, err := m.FindOne(testInsertId)
 		assert.Nil(t, err)
 		assert.Equal(t, *result, data)
 	})
 	assert.Nil(t, err)
 
 	err = mockUser(func(mock sqlmock.Sqlmock) {
-		mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected))
+		mock.ExpectExec(fmt.Sprintf("update %s", testTable)).WithArgs(data.User, testUpdateName, data.Password, data.Mobile, data.Gender, data.Nickname, testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
 	}, func(m UserModel) {
 		data.Name = testUpdateName
 		err := m.Update(data)
@@ -163,26 +209,26 @@ func TestUserModel(t *testing.T) {
 
 	err = mockUser(func(mock sqlmock.Sqlmock) {
 		mock.ExpectQuery(fmt.Sprintf("select (.+) from %s ", testTable)).
-			WithArgs(testInsertID).
-			WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertID, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
+			WithArgs(testInsertId).
+			WillReturnRows(sqlmock.NewRows([]string{"id", "user", "name", "password", "mobile", "gender", "nickname", "create_time", "update_time"}).AddRow(testInsertId, data.User, data.Name, data.Password, data.Mobile, data.Gender, data.Nickname, testTimeValue, testTimeValue))
 	}, func(m UserModel) {
-		result, err := m.FindOne(testInsertID)
+		result, err := m.FindOne(testInsertId)
 		assert.Nil(t, err)
 		assert.Equal(t, *result, data)
 	})
 	assert.Nil(t, err)
 
 	err = mockUser(func(mock sqlmock.Sqlmock) {
-		mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertID).WillReturnResult(sqlmock.NewResult(testInsertID, testRowsAffected))
+		mock.ExpectExec(fmt.Sprintf("delete from %s where `id` = ?", testTable)).WithArgs(testInsertId).WillReturnResult(sqlmock.NewResult(testInsertId, testRowsAffected))
 	}, func(m UserModel) {
-		err := m.Delete(testInsertID)
+		err := m.Delete(testInsertId)
 		assert.Nil(t, err)
 	})
 	assert.Nil(t, err)
 }
 
 // with cache
-func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) error {
+func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel, r *redis.Redis)) error {
 	db, mock, err := sqlmock.New()
 	if err != nil {
 		return err
@@ -211,7 +257,9 @@ func mockStudent(mockFn func(mock sqlmock.Sqlmock), fn func(m StudentModel)) err
 			Weight: 100,
 		},
 	})
-	fn(m)
+	mock.ExpectBegin()
+	fn(m, r)
+	mock.ExpectCommit()
 	return nil
 }
 
@@ -19,16 +19,19 @@ var (
 	studentRowsExpectAutoSet = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
 	studentRowsWithPlaceHolder = strings.Join(stringx.Remove(studentFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
 
-	cacheStudentIDPrefix = "cache#Student#id#"
+	cacheStudentIdPrefix = "cache#student#id#"
+	cacheStudentClassNamePrefix = "cache#student#class#name#"
 )
 
 type (
-	// StudentModel defines a model for Student
+	// StudentModel only for test
 	StudentModel interface {
 		Insert(data Student) (sql.Result, error)
 		FindOne(id int64) (*Student, error)
+		FindOneByClassName(class string, name string) (*Student, error)
 		Update(data Student) error
-		Delete(id int64) error
+		// only for test
+		Delete(id int64, className, studentName string) error
 	}
 
 	defaultStudentModel struct {
@@ -36,9 +39,10 @@ type (
 		table string
 	}
 
-	// Student defines an data structure for mysql
+	// Student only for test
 	Student struct {
-		ID int64 `db:"id"`
+		Id int64 `db:"id"`
+		Class string `db:"class"`
 		Name string `db:"name"`
 		Age sql.NullInt64 `db:"age"`
 		Score sql.NullFloat64 `db:"score"`
@@ -47,7 +51,7 @@ type (
 	}
 )
 
-// NewStudentModel creates an instance for StudentModel
+// NewStudentModel only for test
 func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel {
 	return &defaultStudentModel{
 		CachedConn: sqlc.NewConn(conn, c),
@@ -56,16 +60,18 @@ func NewStudentModel(conn sqlx.SqlConn, c cache.CacheConf) StudentModel {
 }
 
 func (m *defaultStudentModel) Insert(data Student) (sql.Result, error) {
-	query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?)", m.table, studentRowsExpectAutoSet)
-	ret, err := m.ExecNoCache(query, data.Name, data.Age, data.Score)
+	studentClassNameKey := fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, data.Class, data.Name)
+	ret, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
+		query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?)", m.table, studentRowsExpectAutoSet)
+		return conn.Exec(query, data.Class, data.Name, data.Age, data.Score)
+	}, studentClassNameKey)
 	return ret, err
 }
 
 func (m *defaultStudentModel) FindOne(id int64) (*Student, error) {
-	studentIDKey := fmt.Sprintf("%s%v", cacheStudentIDPrefix, id)
+	studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
 	var resp Student
-	err := m.QueryRow(&resp, studentIDKey, func(conn sqlx.SqlConn, v interface{}) error {
+	err := m.QueryRow(&resp, studentIdKey, func(conn sqlx.SqlConn, v interface{}) error {
 		query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", studentRows, m.table)
 		return conn.QueryRow(v, query, id)
 	})
@@ -79,27 +85,47 @@ func (m *defaultStudentModel) FindOne(id int64) (*Student, error) {
 	}
 }
 
+func (m *defaultStudentModel) FindOneByClassName(class string, name string) (*Student, error) {
+	studentClassNameKey := fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, class, name)
+	var resp Student
+	err := m.QueryRowIndex(&resp, studentClassNameKey, m.formatPrimary, func(conn sqlx.SqlConn, v interface{}) (i interface{}, e error) {
+		query := fmt.Sprintf("select %s from %s where `class` = ? and `name` = ? limit 1", studentRows, m.table)
+		if err := conn.QueryRow(&resp, query, class, name); err != nil {
+			return nil, err
+		}
+		return resp.Id, nil
+	}, m.queryPrimary)
+	switch err {
+	case nil:
+		return &resp, nil
+	case sqlc.ErrNotFound:
+		return nil, ErrNotFound
+	default:
+		return nil, err
+	}
+}
+
 func (m *defaultStudentModel) Update(data Student) error {
-	studentIDKey := fmt.Sprintf("%s%v", cacheStudentIDPrefix, data.ID)
+	studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, data.Id)
 	_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
 		query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, studentRowsWithPlaceHolder)
-		return conn.Exec(query, data.Name, data.Age, data.Score, data.ID)
-	}, studentIDKey)
+		return conn.Exec(query, data.Class, data.Name, data.Age, data.Score, data.Id)
+	}, studentIdKey)
 	return err
 }
 
-func (m *defaultStudentModel) Delete(id int64) error {
-	studentIDKey := fmt.Sprintf("%s%v", cacheStudentIDPrefix, id)
+func (m *defaultStudentModel) Delete(id int64, className, studentName string) error {
+	studentIdKey := fmt.Sprintf("%s%v", cacheStudentIdPrefix, id)
+	studentClassNameKey := fmt.Sprintf("%s%v%v", cacheStudentClassNamePrefix, className, studentName)
 	_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
 		query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
 		return conn.Exec(query, id)
-	}, studentIDKey)
+	}, studentIdKey, studentClassNameKey)
 	return err
 }
 
 func (m *defaultStudentModel) formatPrimary(primary interface{}) string {
-	return fmt.Sprintf("%s%v", cacheStudentIDPrefix, primary)
+	return fmt.Sprintf("%s%v", cacheStudentIdPrefix, primary)
 }
 
 func (m *defaultStudentModel) queryPrimary(conn sqlx.SqlConn, v, primary interface{}) error {
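Note: a hedged sketch of the caller-facing change in the cached model above. StudentAccessor below mirrors the relevant slice of the generated StudentModel interface and Student stands in for the generated struct; neither is part of this commit, they exist only to keep the example self-contained.

package sketch

// StudentAccessor mirrors the generated interface shown in the diff above.
type Student struct {
	Id    int64
	Class string
	Name  string
}

type StudentAccessor interface {
	FindOneByClassName(class string, name string) (*Student, error)
	Delete(id int64, className, studentName string) error
}

// removeStudent shows why Delete now also takes the unique-key fields:
// the cached model can clear both the id-based key (cache#student#id#...)
// and the class+name key (cache#student#class#name#...) in a single call.
func removeStudent(m StudentAccessor, class, name string) error {
	st, err := m.FindOneByClassName(class, name)
	if err != nil {
		return err
	}
	return m.Delete(st.Id, st.Class, st.Name)
}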
@@ -13,10 +13,10 @@
 )
 
 var (
-	userFieldNames = builderx.FieldNames(&User{})
+	userFieldNames = builderx.RawFieldNames(&User{})
 	userRows = strings.Join(userFieldNames, ",")
-	userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), ",")
-	userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "id", "create_time", "update_time"), "=?,") + "=?"
+	userRowsExpectAutoSet = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), ",")
+	userRowsWithPlaceHolder = strings.Join(stringx.Remove(userFieldNames, "`id`", "`create_time`", "`update_time`"), "=?,") + "=?"
 )
 
 type (
@@ -25,8 +25,8 @@ type (
 		Insert(data User) (sql.Result, error)
 		FindOne(id int64) (*User, error)
 		FindOneByUser(user string) (*User, error)
-		FindOneByName(name string) (*User, error)
 		FindOneByMobile(mobile string) (*User, error)
+		FindOneByName(name string) (*User, error)
 		Update(data User) error
 		Delete(id int64) error
 	}
@@ -92,10 +92,10 @@ func (m *defaultUserModel) FindOneByUser(user string) (*User, error) {
 	}
 }
 
-func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
+func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) {
 	var resp User
-	query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table)
-	err := m.conn.QueryRow(&resp, query, name)
+	query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table)
+	err := m.conn.QueryRow(&resp, query, mobile)
 	switch err {
 	case nil:
 		return &resp, nil
@@ -106,10 +106,10 @@ func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
 	}
 }
 
-func (m *defaultUserModel) FindOneByMobile(mobile string) (*User, error) {
+func (m *defaultUserModel) FindOneByName(name string) (*User, error) {
 	var resp User
-	query := fmt.Sprintf("select %s from %s where `mobile` = ? limit 1", userRows, m.table)
-	err := m.conn.QueryRow(&resp, query, mobile)
+	query := fmt.Sprintf("select %s from %s where `name` = ? limit 1", userRows, m.table)
+	err := m.conn.QueryRow(&resp, query, name)
 	switch err {
 	case nil:
 		return &resp, nil
@@ -11,6 +11,7 @@ import (
 	"github.com/tal-tech/go-zero/core/mapping"
 )
 
+// ErrNotFound is the alias of sql.ErrNoRows
 var ErrNotFound = sql.ErrNoRows
 
 func desensitize(datasource string) string {
@@ -32,6 +32,11 @@ func (s String) Lower() string {
 	return strings.ToLower(s.source)
 }
 
+// Upper calls the strings.ToUpper
+func (s String) Upper() string {
+	return strings.ToUpper(s.source)
+}
+
 // ReplaceAll calls the strings.ReplaceAll
 func (s String) ReplaceAll(old, new string) string {
 	return strings.ReplaceAll(s.source, old, new)
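Note: a quick illustration of the stringx helpers touched in this commit. Upper is the new method added above; From, ToCamel and Source already exist and are what the parser changes rely on. The printed values are the expected outputs, shown as comments.

package main

import (
	"fmt"

	"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
)

func main() {
	s := stringx.From("user_snake")
	fmt.Println(s.Source())  // user_snake
	fmt.Println(s.ToCamel()) // UserSnake
	fmt.Println(s.Upper())   // USER_SNAKE
}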