print entire sql statements in logx if necessary (#704)

This commit is contained in:
Kevin Wan
2021-05-20 16:14:44 +08:00
committed by GitHub
parent 73906f996d
commit aaa39e17a3
5 changed files with 154 additions and 68 deletions

View File

@@ -56,7 +56,8 @@ type (
} }
statement struct { statement struct {
stmt *sql.Stmt query string
stmt *sql.Stmt
} }
stmtConn interface { stmtConn interface {
@@ -111,7 +112,8 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
} }
stmt = statement{ stmt = statement{
stmt: st, query: query,
stmt: st,
} }
return nil return nil
}, db.acceptable) }, db.acceptable)
@@ -181,29 +183,29 @@ func (s statement) Close() error {
} }
func (s statement) Exec(args ...interface{}) (sql.Result, error) { func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, args...) return execStmt(s.stmt, s.query, args...)
} }
func (s statement) QueryRow(v interface{}, args ...interface{}) error { func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true) return unmarshalRow(v, rows, true)
}, args...) }, s.query, args...)
} }
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error { func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false) return unmarshalRow(v, rows, false)
}, args...) }, s.query, args...)
} }
func (s statement) QueryRows(v interface{}, args ...interface{}) error { func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true) return unmarshalRows(v, rows, true)
}, args...) }, s.query, args...)
} }
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error { func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error { return queryStmt(s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false) return unmarshalRows(v, rows, false)
}, args...) }, s.query, args...)
} }

View File

@@ -2,7 +2,6 @@ package sqlx
import ( import (
"database/sql" "database/sql"
"fmt"
"time" "time"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
@@ -12,10 +11,14 @@ import (
const slowThreshold = time.Millisecond * 500 const slowThreshold = time.Millisecond * 500
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(q, args...) result, err := conn.Exec(q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := formatForPrint(q, args)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else { } else {
@@ -28,11 +31,15 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
return result, err return result, err
} }
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) { func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(args...) result, err := conn.Exec(args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else { } else {
@@ -46,10 +53,14 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
} }
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now() startTime := timex.Now()
rows, err := conn.Query(q, args...) rows, err := conn.Query(q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else { } else {
@@ -64,8 +75,12 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
return scanner(rows) return scanner(rows)
} }
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, args ...interface{}) error { func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt := fmt.Sprint(args...) stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now() startTime := timex.Now()
rows, err := conn.Query(args...) rows, err := conn.Query(args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)

View File

@@ -14,6 +14,7 @@ var errMockedPlaceholder = errors.New("placeholder")
func TestStmt_exec(t *testing.T) { func TestStmt_exec(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
query string
args []interface{} args []interface{}
delay bool delay bool
hasError bool hasError bool
@@ -23,18 +24,28 @@ func TestStmt_exec(t *testing.T) {
}{ }{
{ {
name: "normal", name: "normal",
query: "select user from users where id=?",
args: []interface{}{1}, args: []interface{}{1},
lastInsertId: 1, lastInsertId: 1,
rowsAffected: 2, rowsAffected: 2,
}, },
{ {
name: "exec error", name: "exec error",
query: "select user from users where id=?",
args: []interface{}{1},
hasError: true,
err: errors.New("exec"),
},
{
name: "exec more args error",
query: "select user from users where id=? and name=?",
args: []interface{}{1}, args: []interface{}{1},
hasError: true, hasError: true,
err: errors.New("exec"), err: errors.New("exec"),
}, },
{ {
name: "slowcall", name: "slowcall",
query: "select user from users where id=?",
args: []interface{}{1}, args: []interface{}{1},
delay: true, delay: true,
lastInsertId: 1, lastInsertId: 1,
@@ -51,7 +62,7 @@ func TestStmt_exec(t *testing.T) {
rowsAffected: test.rowsAffected, rowsAffected: test.rowsAffected,
err: test.err, err: test.err,
delay: test.delay, delay: test.delay,
}, "select user from users where id=?", args...) }, test.query, args...)
}, },
func(args ...interface{}) (sql.Result, error) { func(args ...interface{}) (sql.Result, error) {
return execStmt(&mockedStmtConn{ return execStmt(&mockedStmtConn{
@@ -59,7 +70,7 @@ func TestStmt_exec(t *testing.T) {
rowsAffected: test.rowsAffected, rowsAffected: test.rowsAffected,
err: test.err, err: test.err,
delay: test.delay, delay: test.delay,
}, args...) }, test.query, args...)
}, },
} }
@@ -89,23 +100,34 @@ func TestStmt_exec(t *testing.T) {
func TestStmt_query(t *testing.T) { func TestStmt_query(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
query string
args []interface{} args []interface{}
delay bool delay bool
hasError bool hasError bool
err error err error
}{ }{
{ {
name: "normal", name: "normal",
args: []interface{}{1}, query: "select user from users where id=?",
args: []interface{}{1},
}, },
{ {
name: "query error", name: "query error",
query: "select user from users where id=?",
args: []interface{}{1},
hasError: true,
err: errors.New("exec"),
},
{
name: "query more args error",
query: "select user from users where id=? and name=?",
args: []interface{}{1}, args: []interface{}{1},
hasError: true, hasError: true,
err: errors.New("exec"), err: errors.New("exec"),
}, },
{ {
name: "slowcall", name: "slowcall",
query: "select user from users where id=?",
args: []interface{}{1}, args: []interface{}{1},
delay: true, delay: true,
}, },
@@ -120,7 +142,7 @@ func TestStmt_query(t *testing.T) {
delay: test.delay, delay: test.delay,
}, func(rows *sql.Rows) error { }, func(rows *sql.Rows) error {
return nil return nil
}, "select user from users where id=?", args...) }, test.query, args...)
}, },
func(args ...interface{}) error { func(args ...interface{}) error {
return queryStmt(&mockedStmtConn{ return queryStmt(&mockedStmtConn{
@@ -128,7 +150,7 @@ func TestStmt_query(t *testing.T) {
delay: test.delay, delay: test.delay,
}, func(rows *sql.Rows) error { }, func(rows *sql.Rows) error {
return nil return nil
}, args...) }, test.query, args...)
}, },
} }
@@ -143,7 +165,7 @@ func TestStmt_query(t *testing.T) {
return return
} }
assert.Equal(t, errMockedPlaceholder, err) assert.NotNil(t, err)
}) })
} }
} }

View File

@@ -2,6 +2,7 @@ package sqlx
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"github.com/tal-tech/go-zero/core/logx" "github.com/tal-tech/go-zero/core/logx"
@@ -45,24 +46,6 @@ func escape(input string) string {
return b.String() return b.String()
} }
func formatForPrint(query string, args ...interface{}) string {
if len(args) == 0 {
return query
}
var vals []string
for _, arg := range args {
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
}
var b strings.Builder
b.WriteByte('[')
b.WriteString(strings.Join(vals, ", "))
b.WriteByte(']')
return strings.Join([]string{query, b.String()}, " ")
}
func format(query string, args ...interface{}) (string, error) { func format(query string, args ...interface{}) (string, error) {
numArgs := len(args) numArgs := len(args)
if numArgs == 0 { if numArgs == 0 {
@@ -72,36 +55,50 @@ func format(query string, args ...interface{}) (string, error) {
var b strings.Builder var b strings.Builder
argIndex := 0 argIndex := 0
for _, ch := range query { bytes := len(query)
if ch == '?' { for i := 0; i < bytes; i++ {
ch := query[i]
switch ch {
case '?':
if argIndex >= numArgs { if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex) return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
} }
arg := args[argIndex] writeValue(&b, args[argIndex])
argIndex++ argIndex++
case '$':
switch v := arg.(type) { var j int
case bool: for j = i + 1; j < bytes; j++ {
if v { char := query[j]
b.WriteByte('1') if char < '0' || '9' < char {
} else { break
b.WriteByte('0')
} }
case string:
b.WriteByte('\'')
b.WriteString(escape(v))
b.WriteByte('\'')
default:
b.WriteString(mapping.Repr(v))
} }
} else { if j > i+1 {
b.WriteRune(ch) index, err := strconv.Atoi(query[i+1 : j])
if err != nil {
return "", err
}
// index starts from 1 for pg
if index > argIndex {
argIndex = index
}
index--
if index < 0 || numArgs <= index {
return "", fmt.Errorf("error: wrong index %d in sql", index)
}
writeValue(&b, args[index])
i = j - 1
}
default:
b.WriteByte(ch)
} }
} }
if argIndex < numArgs { if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex) return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
} }
return b.String(), nil return b.String(), nil
@@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error()) logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
} }
} }
func writeValue(buf *strings.Builder, arg interface{}) {
switch v := arg.(type) {
case bool:
if v {
buf.WriteByte('1')
} else {
buf.WriteByte('0')
}
case string:
buf.WriteByte('\'')
buf.WriteString(escape(v))
buf.WriteByte('\'')
default:
buf.WriteString(mapping.Repr(v))
}
}

View File

@@ -29,30 +29,63 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)")) assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
} }
func TestFormatForPrint(t *testing.T) { func TestFormat(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
query string query string
args []interface{} args []interface{}
expect string expect string
hasErr bool
}{ }{
{ {
name: "no args", name: "mysql normal",
query: "select user, name from table where id=?", query: "select name, age from users where bool=? and phone=?",
expect: `select user, name from table where id=?`, args: []interface{}{true, "133"},
expect: "select name, age from users where bool=1 and phone='133'",
}, },
{ {
name: "one arg", name: "mysql normal",
query: "select user, name from table where id=?", query: "select name, age from users where bool=? and phone=?",
args: []interface{}{"kevin"}, args: []interface{}{false, "133"},
expect: `select user, name from table where id=? ["kevin"]`, expect: "select name, age from users where bool=0 and phone='133'",
},
{
name: "pg normal",
query: "select name, age from users where bool=$1 and phone=$2",
args: []interface{}{true, "133"},
expect: "select name, age from users where bool=1 and phone='133'",
},
{
name: "pg normal reverse",
query: "select name, age from users where bool=$2 and phone=$1",
args: []interface{}{"133", false},
expect: "select name, age from users where bool=0 and phone='133'",
},
{
name: "pg error not number",
query: "select name, age from users where bool=$a and phone=$1",
args: []interface{}{"133", false},
hasErr: true,
},
{
name: "pg error more args",
query: "select name, age from users where bool=$2 and phone=$1 and nickname=$3",
args: []interface{}{"133", false},
hasErr: true,
}, },
} }
for _, test := range tests { for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
actual := formatForPrint(test.query, test.args...) t.Parallel()
assert.Equal(t, test.expect, actual)
actual, err := format(test.query, test.args...)
if test.hasErr {
assert.NotNil(t, err)
} else {
assert.Equal(t, test.expect, actual)
}
}) })
} }
} }