print entire sql statements in logx if necessary (#704)

This commit is contained in:
Kevin Wan
2021-05-20 16:14:44 +08:00
committed by GitHub
parent 73906f996d
commit aaa39e17a3
5 changed files with 154 additions and 68 deletions

View File

@@ -2,6 +2,7 @@ package sqlx
import (
"fmt"
"strconv"
"strings"
"github.com/tal-tech/go-zero/core/logx"
@@ -45,24 +46,6 @@ func escape(input string) string {
return b.String()
}
func formatForPrint(query string, args ...interface{}) string {
if len(args) == 0 {
return query
}
var vals []string
for _, arg := range args {
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
}
var b strings.Builder
b.WriteByte('[')
b.WriteString(strings.Join(vals, ", "))
b.WriteByte(']')
return strings.Join([]string{query, b.String()}, " ")
}
func format(query string, args ...interface{}) (string, error) {
numArgs := len(args)
if numArgs == 0 {
@@ -72,36 +55,50 @@ func format(query string, args ...interface{}) (string, error) {
var b strings.Builder
argIndex := 0
for _, ch := range query {
if ch == '?' {
bytes := len(query)
for i := 0; i < bytes; i++ {
ch := query[i]
switch ch {
case '?':
if argIndex >= numArgs {
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
}
arg := args[argIndex]
writeValue(&b, args[argIndex])
argIndex++
switch v := arg.(type) {
case bool:
if v {
b.WriteByte('1')
} else {
b.WriteByte('0')
case '$':
var j int
for j = i + 1; j < bytes; j++ {
char := query[j]
if char < '0' || '9' < char {
break
}
case string:
b.WriteByte('\'')
b.WriteString(escape(v))
b.WriteByte('\'')
default:
b.WriteString(mapping.Repr(v))
}
} else {
b.WriteRune(ch)
if j > i+1 {
index, err := strconv.Atoi(query[i+1 : j])
if err != nil {
return "", err
}
// index starts from 1 for pg
if index > argIndex {
argIndex = index
}
index--
if index < 0 || numArgs <= index {
return "", fmt.Errorf("error: wrong index %d in sql", index)
}
writeValue(&b, args[index])
i = j - 1
}
default:
b.WriteByte(ch)
}
}
if argIndex < numArgs {
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
}
return b.String(), nil
@@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
}
}
func writeValue(buf *strings.Builder, arg interface{}) {
switch v := arg.(type) {
case bool:
if v {
buf.WriteByte('1')
} else {
buf.WriteByte('0')
}
case string:
buf.WriteByte('\'')
buf.WriteString(escape(v))
buf.WriteByte('\'')
default:
buf.WriteString(mapping.Repr(v))
}
}