@@ -275,7 +275,7 @@ func Infov(v interface{}) {
|
|||||||
infoAnySync(v)
|
infoAnySync(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Must checks if err is nil, otherwise logs the err and exits.
|
// Must checks if err is nil, otherwise logs the error and exits.
|
||||||
func Must(err error) {
|
func Must(err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := formatWithCaller(err.Error(), 3)
|
msg := formatWithCaller(err.Error(), 3)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package sqlx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -10,6 +11,8 @@ import (
|
|||||||
"github.com/zeromicro/go-zero/core/mapping"
|
"github.com/zeromicro/go-zero/core/mapping"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errUnbalancedEscape = errors.New("no char after escape char")
|
||||||
|
|
||||||
func desensitize(datasource string) string {
|
func desensitize(datasource string) string {
|
||||||
// remove account
|
// remove account
|
||||||
pos := strings.LastIndex(datasource, "@")
|
pos := strings.LastIndex(datasource, "@")
|
||||||
@@ -95,6 +98,30 @@ func format(query string, args ...interface{}) (string, error) {
|
|||||||
writeValue(&b, args[index])
|
writeValue(&b, args[index])
|
||||||
i = j - 1
|
i = j - 1
|
||||||
}
|
}
|
||||||
|
case '\'', '"', '`':
|
||||||
|
b.WriteByte(ch)
|
||||||
|
for j := i + 1; j < bytes; j++ {
|
||||||
|
cur := query[j]
|
||||||
|
b.WriteByte(cur)
|
||||||
|
|
||||||
|
switch cur {
|
||||||
|
case '\\':
|
||||||
|
j++
|
||||||
|
if j >= bytes {
|
||||||
|
return "", errUnbalancedEscape
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteByte(query[j])
|
||||||
|
case '\'', '"', '`':
|
||||||
|
if cur == ch {
|
||||||
|
i = j
|
||||||
|
goto end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
end:
|
||||||
|
break
|
||||||
default:
|
default:
|
||||||
b.WriteByte(ch)
|
b.WriteByte(ch)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -97,6 +97,30 @@ func TestFormat(t *testing.T) {
|
|||||||
args: []interface{}{"133", false},
|
args: []interface{}{"133", false},
|
||||||
hasErr: true,
|
hasErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "select with date",
|
||||||
|
query: "select * from user where date='2006-01-02 15:04:05' and name=:1",
|
||||||
|
args: []interface{}{"foo"},
|
||||||
|
expect: "select * from user where date='2006-01-02 15:04:05' and name='foo'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "select with date and escape",
|
||||||
|
query: `select * from user where date=' 2006-01-02 15:04:05 \'' and name=:1`,
|
||||||
|
args: []interface{}{"foo"},
|
||||||
|
expect: `select * from user where date=' 2006-01-02 15:04:05 \'' and name='foo'`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "select with date and bad arg",
|
||||||
|
query: `select * from user where date='2006-01-02 15:04:05 \'' and name=:a`,
|
||||||
|
args: []interface{}{"foo"},
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "select with date and escape error",
|
||||||
|
query: `select * from user where date='2006-01-02 15:04:05 \`,
|
||||||
|
args: []interface{}{"foo"},
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -108,6 +132,7 @@ func TestFormat(t *testing.T) {
|
|||||||
if test.hasErr {
|
if test.hasErr {
|
||||||
assert.NotNil(t, err)
|
assert.NotNil(t, err)
|
||||||
} else {
|
} else {
|
||||||
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, test.expect, actual)
|
assert.Equal(t, test.expect, actual)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ func Parse(filename, database string) ([]*Table, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
nameOriginals := parseNameOriginal(tables)
|
nameOriginals := parseNameOriginal(tables)
|
||||||
|
|
||||||
indexNameGen := func(column ...string) string {
|
indexNameGen := func(column ...string) string {
|
||||||
return strings.Join(column, "_")
|
return strings.Join(column, "_")
|
||||||
}
|
}
|
||||||
@@ -77,14 +76,12 @@ func Parse(filename, database string) ([]*Table, error) {
|
|||||||
prefix := filepath.Base(filename)
|
prefix := filepath.Base(filename)
|
||||||
var list []*Table
|
var list []*Table
|
||||||
for indexTable, e := range tables {
|
for indexTable, e := range tables {
|
||||||
columns := e.Columns
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
primaryColumn string
|
||||||
primaryColumnSet = collection.NewSet()
|
primaryColumnSet = collection.NewSet()
|
||||||
|
uniqueKeyMap = make(map[string][]string)
|
||||||
primaryColumn string
|
normalKeyMap = make(map[string][]string)
|
||||||
uniqueKeyMap = make(map[string][]string)
|
columns = e.Columns
|
||||||
normalKeyMap = make(map[string][]string)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, column := range columns {
|
for _, column := range columns {
|
||||||
|
|||||||
Reference in New Issue
Block a user