* fix #1806

* chore: refine error text
This commit is contained in:
Kevin Wan
2022-04-27 00:01:31 +08:00
committed by GitHub
parent 5c9fae7e62
commit 5bcee4cf7c
4 changed files with 57 additions and 8 deletions

View File

@@ -275,7 +275,7 @@ func Infov(v interface{}) {
infoAnySync(v)
}
// Must checks if err is nil, otherwise logs the err and exits.
// Must checks if err is nil, otherwise logs the error and exits.
func Must(err error) {
if err != nil {
msg := formatWithCaller(err.Error(), 3)

View File

@@ -2,6 +2,7 @@ package sqlx
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
@@ -10,6 +11,8 @@ import (
"github.com/zeromicro/go-zero/core/mapping"
)
var errUnbalancedEscape = errors.New("no char after escape char")
func desensitize(datasource string) string {
// remove account
pos := strings.LastIndex(datasource, "@")
@@ -95,6 +98,30 @@ func format(query string, args ...interface{}) (string, error) {
writeValue(&b, args[index])
i = j - 1
}
case '\'', '"', '`':
b.WriteByte(ch)
for j := i + 1; j < bytes; j++ {
cur := query[j]
b.WriteByte(cur)
switch cur {
case '\\':
j++
if j >= bytes {
return "", errUnbalancedEscape
}
b.WriteByte(query[j])
case '\'', '"', '`':
if cur == ch {
i = j
goto end
}
}
}
end:
break
default:
b.WriteByte(ch)
}

View File

@@ -97,6 +97,30 @@ func TestFormat(t *testing.T) {
args: []interface{}{"133", false},
hasErr: true,
},
{
name: "select with date",
query: "select * from user where date='2006-01-02 15:04:05' and name=:1",
args: []interface{}{"foo"},
expect: "select * from user where date='2006-01-02 15:04:05' and name='foo'",
},
{
name: "select with date and escape",
query: `select * from user where date=' 2006-01-02 15:04:05 \'' and name=:1`,
args: []interface{}{"foo"},
expect: `select * from user where date=' 2006-01-02 15:04:05 \'' and name='foo'`,
},
{
name: "select with date and bad arg",
query: `select * from user where date='2006-01-02 15:04:05 \'' and name=:a`,
args: []interface{}{"foo"},
hasErr: true,
},
{
name: "select with date and escape error",
query: `select * from user where date='2006-01-02 15:04:05 \`,
args: []interface{}{"foo"},
hasErr: true,
},
}
for _, test := range tests {
@@ -108,6 +132,7 @@ func TestFormat(t *testing.T) {
if test.hasErr {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
assert.Equal(t, test.expect, actual)
}
})

View File

@@ -69,7 +69,6 @@ func Parse(filename, database string) ([]*Table, error) {
}
nameOriginals := parseNameOriginal(tables)
indexNameGen := func(column ...string) string {
return strings.Join(column, "_")
}
@@ -77,14 +76,12 @@ func Parse(filename, database string) ([]*Table, error) {
prefix := filepath.Base(filename)
var list []*Table
for indexTable, e := range tables {
columns := e.Columns
var (
primaryColumn string
primaryColumnSet = collection.NewSet()
primaryColumn string
uniqueKeyMap = make(map[string][]string)
normalKeyMap = make(map[string][]string)
uniqueKeyMap = make(map[string][]string)
normalKeyMap = make(map[string][]string)
columns = e.Columns
)
for _, column := range columns {