feat: support ctx in sqlx/sqlc, listed in ROADMAP (#1535)
* feat: support ctx in sqlx/sqlc * chore: update roadmap * fix: context.Canceled should be acceptable * use %w to wrap errors * chore: remove unused vars
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
@@ -26,11 +27,19 @@ func NewSessionFromTx(tx *sql.Tx) Session {
|
||||
}
|
||||
|
||||
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(t.Tx, q, args...)
|
||||
return t.ExecCtx(context.Background(), q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(ctx, t.Tx, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
stmt, err := t.Tx.Prepare(q)
|
||||
return t.PrepareCtx(context.Background(), q)
|
||||
}
|
||||
|
||||
func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) {
|
||||
stmt, err := t.Tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -42,25 +51,43 @@ func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
}
|
||||
|
||||
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowPartialCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowsCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
@@ -76,17 +103,19 @@ func begin(db *sql.DB) (trans, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
|
||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
||||
fn func(context.Context, Session) error) (err error) {
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
db.onError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return transactOnConn(conn, b, fn)
|
||||
return transactOnConn(ctx, conn, b, fn)
|
||||
}
|
||||
|
||||
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
|
||||
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
|
||||
fn func(context.Context, Session) error) (err error) {
|
||||
var tx trans
|
||||
tx, err = b(conn)
|
||||
if err != nil {
|
||||
@@ -96,18 +125,18 @@ func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err err
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
|
||||
err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
|
||||
} else {
|
||||
err = fmt.Errorf("recoveer from %#v", p)
|
||||
}
|
||||
} else if err != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
|
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
|
||||
}
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(tx)
|
||||
return fn(ctx, tx)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user