print entire sql statements in logx if necessary (#704)
This commit is contained in:
@@ -2,6 +2,7 @@ package sqlx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
@@ -45,24 +46,6 @@ func escape(input string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatForPrint(query string, args ...interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, arg := range args {
|
||||
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteByte('[')
|
||||
b.WriteString(strings.Join(vals, ", "))
|
||||
b.WriteByte(']')
|
||||
|
||||
return strings.Join([]string{query, b.String()}, " ")
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
@@ -72,36 +55,50 @@ func format(query string, args ...interface{}) (string, error) {
|
||||
var b strings.Builder
|
||||
argIndex := 0
|
||||
|
||||
for _, ch := range query {
|
||||
if ch == '?' {
|
||||
bytes := len(query)
|
||||
for i := 0; i < bytes; i++ {
|
||||
ch := query[i]
|
||||
switch ch {
|
||||
case '?':
|
||||
if argIndex >= numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
|
||||
}
|
||||
|
||||
arg := args[argIndex]
|
||||
writeValue(&b, args[argIndex])
|
||||
argIndex++
|
||||
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
b.WriteByte('1')
|
||||
} else {
|
||||
b.WriteByte('0')
|
||||
case '$':
|
||||
var j int
|
||||
for j = i + 1; j < bytes; j++ {
|
||||
char := query[j]
|
||||
if char < '0' || '9' < char {
|
||||
break
|
||||
}
|
||||
case string:
|
||||
b.WriteByte('\'')
|
||||
b.WriteString(escape(v))
|
||||
b.WriteByte('\'')
|
||||
default:
|
||||
b.WriteString(mapping.Repr(v))
|
||||
}
|
||||
} else {
|
||||
b.WriteRune(ch)
|
||||
if j > i+1 {
|
||||
index, err := strconv.Atoi(query[i+1 : j])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// index starts from 1 for pg
|
||||
if index > argIndex {
|
||||
argIndex = index
|
||||
}
|
||||
index--
|
||||
if index < 0 || numArgs <= index {
|
||||
return "", fmt.Errorf("error: wrong index %d in sql", index)
|
||||
}
|
||||
|
||||
writeValue(&b, args[index])
|
||||
i = j - 1
|
||||
}
|
||||
default:
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
}
|
||||
|
||||
if argIndex < numArgs {
|
||||
return "", fmt.Errorf("error: %d ? in sql, but more arguments provided", argIndex)
|
||||
return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
@@ -117,3 +114,20 @@ func logSqlError(stmt string, err error) {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func writeValue(buf *strings.Builder, arg interface{}) {
|
||||
switch v := arg.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
buf.WriteByte('1')
|
||||
} else {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
case string:
|
||||
buf.WriteByte('\'')
|
||||
buf.WriteString(escape(v))
|
||||
buf.WriteByte('\'')
|
||||
default:
|
||||
buf.WriteString(mapping.Repr(v))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user