feat: support ctx in sqlx/sqlc, listed in ROADMAP (#1535)

* feat: support ctx in sqlx/sqlc

* chore: update roadmap

* fix: context.Canceled should be acceptable

* use %w to wrap errors

* chore: remove unused vars
This commit is contained in:
Kevin Wan
2022-02-16 19:31:43 +08:00
committed by GitHub
parent 7c63676be4
commit 607bae27fa
12 changed files with 458 additions and 152 deletions

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"strconv"
@@ -17,12 +18,40 @@ type mockedConn struct {
execErr error
}
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...interface{}) (sql.Result, error) {
c.query = query
c.args = args
return nil, c.execErr
}
func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
panic("implement me")
}
func (c *mockedConn) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
panic("implement me")
}
func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
panic("should not called")
}
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return c.ExecCtx(context.Background(), query, args...)
}
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
panic("should not called")
}

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"testing"
@@ -16,7 +17,7 @@ func TestUnmarshalRowBool(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.True(t, value)
@@ -29,7 +30,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value bool
assert.NotNil(t, query(db, func(rows *sql.Rows) error {
assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select value from users where user=?", "anyone"))
})
@@ -41,7 +42,7 @@ func TestUnmarshalRowInt(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, 2, value)
@@ -54,7 +55,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int8(3), value)
@@ -67,7 +68,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.Equal(t, int16(4), value)
@@ -80,7 +81,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.Equal(t, int32(5), value)
@@ -93,7 +94,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, int64(6), value)
@@ -106,7 +107,7 @@ func TestUnmarshalRowUint(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint(2), value)
@@ -119,7 +120,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint8(3), value)
@@ -132,7 +133,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(4), value)
@@ -145,7 +146,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint32(5), value)
@@ -158,7 +159,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, uint16(6), value)
@@ -171,7 +172,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float32(7), value)
@@ -184,7 +185,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, float64(8), value)
@@ -198,7 +199,7 @@ func TestUnmarshalRowString(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -215,7 +216,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name)
@@ -233,7 +234,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
assert.Equal(t, "liao", value.Name)
@@ -251,7 +252,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.NotNil(t, query(db, func(rows *sql.Rows) error {
assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"))
})
@@ -264,7 +265,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -278,7 +279,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -292,7 +293,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -306,7 +307,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -320,7 +321,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -334,7 +335,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -348,7 +349,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -362,7 +363,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -376,7 +377,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -390,7 +391,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -404,7 +405,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -418,7 +419,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -432,7 +433,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -446,7 +447,7 @@ func TestUnmarshalRowsString(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -462,7 +463,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*bool
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -478,7 +479,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -494,7 +495,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -510,7 +511,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -526,7 +527,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -542,7 +543,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*int64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -558,7 +559,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -574,7 +575,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint8
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -590,7 +591,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint16
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -606,7 +607,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -622,7 +623,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*uint64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -638,7 +639,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float32
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -654,7 +655,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*float64
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -670,7 +671,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
var value []*string
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select value from users where user=?", "anyone"))
assert.EqualValues(t, expect, value)
@@ -699,7 +700,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
@@ -739,7 +740,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
@@ -773,7 +774,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
@@ -814,7 +815,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone"))
@@ -856,7 +857,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, value from users where user=?", "anyone"))
@@ -890,7 +891,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
@@ -923,7 +924,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
@@ -956,7 +957,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age from users where user=?", "anyone"))
@@ -976,7 +977,7 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
User string `db:"user"`
Age int `db:"age"`
}
assert.Nil(t, query(db, func(rows *sql.Rows) error {
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(&r, rows, false)
}, "select age from users where user=?", "anyone"))
assert.Empty(t, r.User)
@@ -1027,7 +1028,7 @@ func TestUnmarshalRowError(t *testing.T) {
User string `db:"user"`
Age int `db:"age"`
}
test.validate(query(db, func(rows *sql.Rows) error {
test.validate(query(context.Background(), db, func(rows *sql.Rows) error {
scanner := mockedScanner{
colErr: test.colErr,
scanErr: test.scanErr,

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"github.com/zeromicro/go-zero/core/breaker"
@@ -14,11 +15,17 @@ type (
// Session stands for raw connections or transaction sessions
Session interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (StmtSession, error)
PrepareCtx(ctx context.Context, query string) (StmtSession, error)
QueryRow(v interface{}, query string, args ...interface{}) error
QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
QueryRowPartial(v interface{}, query string, args ...interface{}) error
QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
QueryRows(v interface{}, query string, args ...interface{}) error
QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
QueryRowsPartial(v interface{}, query string, args ...interface{}) error
QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
}
// SqlConn only stands for raw connections, so Transact method can be called.
@@ -27,7 +34,8 @@ type (
// RawDB is for other ORM to operate with, use it with caution.
// Notice: don't close it.
RawDB() (*sql.DB, error)
Transact(func(session Session) error) error
Transact(fn func(Session) error) error
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
}
// SqlOption defines the method to customize a sql connection.
@@ -37,10 +45,15 @@ type (
StmtSession interface {
Close() error
Exec(args ...interface{}) (sql.Result, error)
ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
QueryRow(v interface{}, args ...interface{}) error
QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
QueryRowPartial(v interface{}, args ...interface{}) error
QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
QueryRows(v interface{}, args ...interface{}) error
QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
QueryRowsPartial(v interface{}, args ...interface{}) error
QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
}
// thread-safe
@@ -58,7 +71,9 @@ type (
sessionConn interface {
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
statement struct {
@@ -68,7 +83,9 @@ type (
stmtConn interface {
Exec(args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
Query(args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
}
)
@@ -112,6 +129,11 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
}
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
return db.ExecCtx(context.Background(), q, args...)
}
func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
result sql.Result, err error) {
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = db.connProv()
@@ -120,7 +142,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
return err
}
result, err = exec(conn, q, args...)
result, err = exec(ctx, conn, q, args...)
return err
}, db.acceptable)
@@ -128,6 +150,10 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
}
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
return db.PrepareCtx(context.Background(), query)
}
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
err = db.brk.DoWithAcceptable(func() error {
var conn *sql.DB
conn, err = db.connProv()
@@ -136,7 +162,7 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
return err
}
st, err := conn.Prepare(query)
st, err := conn.PrepareContext(ctx, query)
if err != nil {
return err
}
@@ -152,25 +178,45 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
}
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return db.QueryRowCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return db.QueryRowPartialCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
q string, args ...interface{}) error {
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return db.QueryRowsCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return db.queryRows(func(rows *sql.Rows) error {
return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
}
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
q string, args ...interface{}) error {
return db.queryRows(ctx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
@@ -180,13 +226,19 @@ func (db *commonSqlConn) RawDB() (*sql.DB, error) {
}
func (db *commonSqlConn) Transact(fn func(Session) error) error {
return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
return fn(session)
})
}
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
return db.brk.DoWithAcceptable(func() error {
return transact(db, db.beginTx, fn)
return transact(ctx, db, db.beginTx, fn)
}, db.acceptable)
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
if db.accept == nil {
return ok
}
@@ -194,7 +246,8 @@ func (db *commonSqlConn) acceptable(err error) bool {
return ok || db.accept(err)
}
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
var qerr error
return db.brk.DoWithAcceptable(func() error {
conn, err := db.connProv()
@@ -203,7 +256,7 @@ func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args
return err
}
return query(conn, func(rows *sql.Rows) error {
return query(ctx, conn, func(rows *sql.Rows) error {
qerr = scanner(rows)
return qerr
}, q, args...)
@@ -217,29 +270,49 @@ func (s statement) Close() error {
}
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
return execStmt(s.stmt, s.query, args...)
return s.ExecCtx(context.Background(), args...)
}
func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
return execStmt(ctx, s.stmt, s.query, args...)
}
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return s.QueryRowCtx(context.Background(), v, args...)
}
func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, s.query, args...)
}
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return s.QueryRowPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, s.query, args...)
}
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return s.QueryRowsCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, s.query, args...)
}
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
return queryStmt(s.stmt, func(rows *sql.Rows) error {
return s.QueryRowsPartialCtx(context.Background(), v, args...)
}
func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, s.query, args...)
}

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"time"
@@ -18,64 +19,65 @@ func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold)
}
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(q, args...)
result, err := conn.ExecContext(ctx, q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
}
return result, err
}
func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
stmt, err := format(q, args...)
if err != nil {
return nil, err
}
startTime := timex.Now()
result, err := conn.Exec(args...)
result, err := conn.ExecContext(ctx, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
}
return result, err
}
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(q, args...)
rows, err := conn.QueryContext(ctx, q, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql query: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
return err
}
defer rows.Close()
@@ -83,22 +85,23 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
return scanner(rows)
}
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
q string, args ...interface{}) error {
stmt, err := format(q, args...)
if err != nil {
return err
}
startTime := timex.Now()
rows, err := conn.Query(args...)
rows, err := conn.QueryContext(ctx, args...)
duration := timex.Since(startTime)
if duration > slowThreshold.Load() {
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
} else {
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
}
if err != nil {
logSqlError(stmt, err)
logSqlError(ctx, stmt, err)
return err
}
defer rows.Close()

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"testing"
@@ -57,7 +58,7 @@ func TestStmt_exec(t *testing.T) {
test := test
fns := []func(args ...interface{}) (sql.Result, error){
func(args ...interface{}) (sql.Result, error) {
return exec(&mockedSessionConn{
return exec(context.Background(), &mockedSessionConn{
lastInsertId: test.lastInsertId,
rowsAffected: test.rowsAffected,
err: test.err,
@@ -65,7 +66,7 @@ func TestStmt_exec(t *testing.T) {
}, test.query, args...)
},
func(args ...interface{}) (sql.Result, error) {
return execStmt(&mockedStmtConn{
return execStmt(context.Background(), &mockedStmtConn{
lastInsertId: test.lastInsertId,
rowsAffected: test.rowsAffected,
err: test.err,
@@ -137,7 +138,7 @@ func TestStmt_query(t *testing.T) {
test := test
fns := []func(args ...interface{}) error{
func(args ...interface{}) error {
return query(&mockedSessionConn{
return query(context.Background(), &mockedSessionConn{
err: test.err,
delay: test.delay,
}, func(rows *sql.Rows) error {
@@ -145,7 +146,7 @@ func TestStmt_query(t *testing.T) {
}, test.query, args...)
},
func(args ...interface{}) error {
return queryStmt(&mockedStmtConn{
return queryStmt(context.Background(), &mockedStmtConn{
err: test.err,
delay: test.delay,
}, func(rows *sql.Rows) error {
@@ -185,6 +186,10 @@ type mockedSessionConn struct {
}
func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
return m.ExecContext(context.Background(), query, args...)
}
func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond)
}
@@ -195,6 +200,10 @@ func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result,
}
func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
return m.QueryContext(context.Background(), query, args...)
}
func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond)
}
@@ -214,6 +223,10 @@ type mockedStmtConn struct {
}
func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
return m.ExecContext(context.Background(), args...)
}
func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) {
if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond)
}
@@ -224,6 +237,10 @@ func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
}
func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
return m.QueryContext(context.Background(), args...)
}
func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) {
if m.delay {
time.Sleep(defaultSlowThreshold + time.Millisecond)
}

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"fmt"
)
@@ -26,11 +27,19 @@ func NewSessionFromTx(tx *sql.Tx) Session {
}
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
return exec(t.Tx, q, args...)
return t.ExecCtx(context.Background(), q, args...)
}
func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) {
return exec(ctx, t.Tx, q, args...)
}
func (t txSession) Prepare(q string) (StmtSession, error) {
stmt, err := t.Tx.Prepare(q)
return t.PrepareCtx(context.Background(), q)
}
func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) {
stmt, err := t.Tx.PrepareContext(ctx, q)
if err != nil {
return nil, err
}
@@ -42,25 +51,43 @@ func (t txSession) Prepare(q string) (StmtSession, error) {
}
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return t.QueryRowCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, true)
}, q, args...)
}
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return t.QueryRowPartialCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRow(v, rows, false)
}, q, args...)
}
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return t.QueryRowsCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, true)
}, q, args...)
}
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return query(t.Tx, func(rows *sql.Rows) error {
return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
}
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
args ...interface{}) error {
return query(ctx, t.Tx, func(rows *sql.Rows) error {
return unmarshalRows(v, rows, false)
}, q, args...)
}
@@ -76,17 +103,19 @@ func begin(db *sql.DB) (trans, error) {
}, nil
}
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
fn func(context.Context, Session) error) (err error) {
conn, err := db.connProv()
if err != nil {
db.onError(err)
return err
}
return transactOnConn(conn, b, fn)
return transactOnConn(ctx, conn, b, fn)
}
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
fn func(context.Context, Session) error) (err error) {
var tx trans
tx, err = b(conn)
if err != nil {
@@ -96,18 +125,18 @@ func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err err
defer func() {
if p := recover(); p != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
} else {
err = fmt.Errorf("recoveer from %#v", p)
}
} else if err != nil {
if e := tx.Rollback(); e != nil {
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
}
} else {
err = tx.Commit()
}
}()
return fn(tx)
return fn(ctx, tx)
}

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"database/sql"
"errors"
"testing"
@@ -26,26 +27,50 @@ func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
return nil
}
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
return nil
}
func (mt *mockTx) Rollback() error {
mt.status |= mockRollback
return nil
@@ -59,18 +84,20 @@ func beginMock(mock *mockTx) beginnable {
func TestTransactCommit(t *testing.T) {
mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error {
return nil
})
err := transactOnConn(context.Background(), nil, beginMock(mock),
func(context.Context, Session) error {
return nil
})
assert.Equal(t, mockCommit, mock.status)
assert.Nil(t, err)
}
func TestTransactRollback(t *testing.T) {
mock := &mockTx{}
err := transactOnConn(nil, beginMock(mock), func(Session) error {
return errors.New("rollback")
})
err := transactOnConn(context.Background(), nil, beginMock(mock),
func(context.Context, Session) error {
return errors.New("rollback")
})
assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err)
}

View File

@@ -1,6 +1,7 @@
package sqlx
import (
"context"
"fmt"
"strconv"
"strings"
@@ -109,9 +110,9 @@ func logInstanceError(datasource string, err error) {
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
}
func logSqlError(stmt string, err error) {
func logSqlError(ctx context.Context, stmt string, err error) {
if err != nil && err != ErrNotFound {
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
}
}