feat: support ctx in sqlx/sqlc, listed in ROADMAP (#1535)

* feat: support ctx in sqlx/sqlc

* chore: update roadmap

* fix: context.Canceled should be acceptable

* use %w to wrap errors

* chore: remove unused vars
This commit is contained in:
Kevin Wan
2022-02-16 19:31:43 +08:00
committed by GitHub
parent 7c63676be4
commit 607bae27fa
12 changed files with 458 additions and 152 deletions

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"time"
@@ -18,64 +19,65 @@ func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold)
}
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(q, args...)
result, err := conn.ExecContext(ctx, q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
}
return result, err
}
func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(args...)
result, err := conn.ExecContext(ctx, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
}
return result, err
}
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(q, args...)
rows, err := conn.QueryContext(ctx, q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql query: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
return err
}
defer rows.Close()
@@ -83,22 +85,23 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
return scanner(rows)
}
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(args...)
rows, err := conn.QueryContext(ctx, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
return err
}
defer rows.Close()