support postgresql (#583)

support postgresql
This commit is contained in:
Kevin Wan
2021-03-27 17:14:32 +08:00
committed by GitHub
parent 9e6c2ba2c0
commit bd623aaac3
5 changed files with 59 additions and 44 deletions

View File

@@ -24,6 +24,7 @@ type (
ResultHandler func(sql.Result, error) ResultHandler func(sql.Result, error)
// A BulkInserter is used to batch insert records. // A BulkInserter is used to batch insert records.
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
BulkInserter struct { BulkInserter struct {
executor *executors.PeriodicalExecutor executor *executors.PeriodicalExecutor
inserter *dbInserter inserter *dbInserter

View File

@@ -12,14 +12,10 @@ import (
const slowThreshold = time.Millisecond * 500 const slowThreshold = time.Millisecond * 500
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) { func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(q, args...) result, err := conn.Exec(q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := formatForPrint(q, args)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else { } else {
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
} }
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) { func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
stmt := fmt.Sprint(args...)
startTime := timex.Now() startTime := timex.Now()
result, err := conn.Exec(args...) result, err := conn.Exec(args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else { } else {
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
} }
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error { func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now() startTime := timex.Now()
rows, err := conn.Query(q, args...) rows, err := conn.Query(q, args...)
duration := timex.Since(startTime) duration := timex.Since(startTime)
stmt := fmt.Sprint(args...)
if duration > slowThreshold { if duration > slowThreshold {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt) logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else { } else {

View File

@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
name string name string
args []interface{} args []interface{}
delay bool delay bool
formatError bool
hasError bool hasError bool
err error err error
lastInsertId int64 lastInsertId int64
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
lastInsertId: 1, lastInsertId: 1,
rowsAffected: 2, rowsAffected: 2,
}, },
{
name: "wrong format",
args: []interface{}{1, 2},
formatError: true,
hasError: true,
},
{ {
name: "exec error", name: "exec error",
args: []interface{}{1}, args: []interface{}{1},
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
}, },
} }
for i, fn := range fns { for _, fn := range fns {
i := i
fn := fn fn := fn
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
res, err := fn(test.args...) res, err := fn(test.args...)
if i == 0 && test.formatError { if test.hasError {
assert.NotNil(t, err)
return
}
if !test.formatError && test.hasError {
assert.NotNil(t, err) assert.NotNil(t, err)
return return
} }
@@ -100,23 +88,16 @@ func TestStmt_exec(t *testing.T) {
func TestStmt_query(t *testing.T) { func TestStmt_query(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args []interface{} args []interface{}
delay bool delay bool
formatError bool hasError bool
hasError bool err error
err error
}{ }{
{ {
name: "normal", name: "normal",
args: []interface{}{1}, args: []interface{}{1},
}, },
{
name: "wrong format",
args: []interface{}{1, 2},
formatError: true,
hasError: true,
},
{ {
name: "query error", name: "query error",
args: []interface{}{1}, args: []interface{}{1},
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
}, },
} }
for i, fn := range fns { for _, fn := range fns {
i := i
fn := fn fn := fn
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
t.Parallel() t.Parallel()
err := fn(test.args...) err := fn(test.args...)
if i == 0 && test.formatError { if test.hasError {
assert.NotNil(t, err)
return
}
if !test.formatError && test.hasError {
assert.NotNil(t, err) assert.NotNil(t, err)
return return
} }

View File

@@ -45,6 +45,24 @@ func escape(input string) string {
return b.String() return b.String()
} }
func formatForPrint(query string, args ...interface{}) string {
if len(args) == 0 {
return query
}
var vals []string
for _, arg := range args {
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
}
var b strings.Builder
b.WriteByte('[')
b.WriteString(strings.Join(vals, ", "))
b.WriteByte(']')
return strings.Join([]string{query, b.String()}, " ")
}
func format(query string, args ...interface{}) (string, error) { func format(query string, args ...interface{}) (string, error) {
numArgs := len(args) numArgs := len(args)
if numArgs == 0 { if numArgs == 0 {

View File

@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
datasource = desensitize(datasource) datasource = desensitize(datasource)
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)")) assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
} }
func TestFormatForPrint(t *testing.T) {
tests := []struct {
name string
query string
args []interface{}
expect string
}{
{
name: "no args",
query: "select user, name from table where id=?",
expect: `select user, name from table where id=?`,
},
{
name: "one arg",
query: "select user, name from table where id=?",
args: []interface{}{"kevin"},
expect: `select user, name from table where id=? ["kevin"]`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actual := formatForPrint(test.query, test.args...)
assert.Equal(t, test.expect, actual)
})
}
}