feat: support ctx in sqlx/sqlc, listed in ROADMAP (#1535)
* feat: support ctx in sqlx/sqlc * chore: update roadmap * fix: context.Canceled should be acceptable * use %w to wrap errors * chore: remove unused vars
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
@@ -18,19 +19,27 @@ var (
|
||||
ErrNotFound = sqlx.ErrNotFound
|
||||
|
||||
// can't use one SingleFlight per conn, because multiple conns may share the same cache key.
|
||||
exclusiveCalls = syncx.NewSingleFlight()
|
||||
stats = cache.NewStat("sqlc")
|
||||
singleFlights = syncx.NewSingleFlight()
|
||||
stats = cache.NewStat("sqlc")
|
||||
)
|
||||
|
||||
type (
|
||||
// ExecFn defines the sql exec method.
|
||||
ExecFn func(conn sqlx.SqlConn) (sql.Result, error)
|
||||
// ExecCtxFn defines the sql exec method.
|
||||
ExecCtxFn func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error)
|
||||
// IndexQueryFn defines the query method that based on unique indexes.
|
||||
IndexQueryFn func(conn sqlx.SqlConn, v interface{}) (interface{}, error)
|
||||
// IndexQueryCtxFn defines the query method that based on unique indexes.
|
||||
IndexQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error)
|
||||
// PrimaryQueryFn defines the query method that based on primary keys.
|
||||
PrimaryQueryFn func(conn sqlx.SqlConn, v, primary interface{}) error
|
||||
// PrimaryQueryCtxFn defines the query method that based on primary keys.
|
||||
PrimaryQueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v, primary interface{}) error
|
||||
// QueryFn defines the query method.
|
||||
QueryFn func(conn sqlx.SqlConn, v interface{}) error
|
||||
// QueryCtxFn defines the query method.
|
||||
QueryCtxFn func(ctx context.Context, conn sqlx.SqlConn, v interface{}) error
|
||||
|
||||
// A CachedConn is a DB connection with cache capability.
|
||||
CachedConn struct {
|
||||
@@ -41,7 +50,7 @@ type (
|
||||
|
||||
// NewConn returns a CachedConn with a redis cluster cache.
|
||||
func NewConn(db sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) CachedConn {
|
||||
cc := cache.New(c, exclusiveCalls, stats, sql.ErrNoRows, opts...)
|
||||
cc := cache.New(c, singleFlights, stats, sql.ErrNoRows, opts...)
|
||||
return NewConnWithCache(db, cc)
|
||||
}
|
||||
|
||||
@@ -55,28 +64,46 @@ func NewConnWithCache(db sqlx.SqlConn, c cache.Cache) CachedConn {
|
||||
|
||||
// NewNodeConn returns a CachedConn with a redis node cache.
|
||||
func NewNodeConn(db sqlx.SqlConn, rds *redis.Redis, opts ...cache.Option) CachedConn {
|
||||
c := cache.NewNode(rds, exclusiveCalls, stats, sql.ErrNoRows, opts...)
|
||||
c := cache.NewNode(rds, singleFlights, stats, sql.ErrNoRows, opts...)
|
||||
return NewConnWithCache(db, c)
|
||||
}
|
||||
|
||||
// DelCache deletes cache with keys.
|
||||
func (cc CachedConn) DelCache(keys ...string) error {
|
||||
return cc.cache.Del(keys...)
|
||||
return cc.DelCacheCtx(context.Background(), keys...)
|
||||
}
|
||||
|
||||
// DelCacheCtx deletes cache with keys.
|
||||
func (cc CachedConn) DelCacheCtx(ctx context.Context, keys ...string) error {
|
||||
return cc.cache.DelCtx(ctx, keys...)
|
||||
}
|
||||
|
||||
// GetCache unmarshals cache with given key into v.
|
||||
func (cc CachedConn) GetCache(key string, v interface{}) error {
|
||||
return cc.cache.Get(key, v)
|
||||
return cc.GetCacheCtx(context.Background(), key, v)
|
||||
}
|
||||
|
||||
// GetCacheCtx unmarshals cache with given key into v.
|
||||
func (cc CachedConn) GetCacheCtx(ctx context.Context, key string, v interface{}) error {
|
||||
return cc.cache.GetCtx(ctx, key, v)
|
||||
}
|
||||
|
||||
// Exec runs given exec on given keys, and returns execution result.
|
||||
func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
|
||||
res, err := exec(cc.db)
|
||||
execCtx := func(_ context.Context, conn sqlx.SqlConn) (sql.Result, error) {
|
||||
return exec(conn)
|
||||
}
|
||||
return cc.ExecCtx(context.Background(), execCtx, keys...)
|
||||
}
|
||||
|
||||
// ExecCtx runs given exec on given keys, and returns execution result.
|
||||
func (cc CachedConn) ExecCtx(ctx context.Context, exec ExecCtxFn, keys ...string) (sql.Result, error) {
|
||||
res, err := exec(ctx, cc.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := cc.DelCache(keys...); err != nil {
|
||||
if err := cc.DelCacheCtx(ctx, keys...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -85,31 +112,61 @@ func (cc CachedConn) Exec(exec ExecFn, keys ...string) (sql.Result, error) {
|
||||
|
||||
// ExecNoCache runs exec with given sql statement, without affecting cache.
|
||||
func (cc CachedConn) ExecNoCache(q string, args ...interface{}) (sql.Result, error) {
|
||||
return cc.db.Exec(q, args...)
|
||||
return cc.ExecNoCacheCtx(context.Background(), q, args...)
|
||||
}
|
||||
|
||||
// ExecNoCacheCtx runs exec with given sql statement, without affecting cache.
|
||||
func (cc CachedConn) ExecNoCacheCtx(ctx context.Context, q string, args ...interface{}) (
|
||||
sql.Result, error) {
|
||||
return cc.db.ExecCtx(ctx, q, args...)
|
||||
}
|
||||
|
||||
// QueryRow unmarshals into v with given key and query func.
|
||||
func (cc CachedConn) QueryRow(v interface{}, key string, query QueryFn) error {
|
||||
return cc.cache.Take(v, key, func(v interface{}) error {
|
||||
return query(cc.db, v)
|
||||
queryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) error {
|
||||
return query(conn, v)
|
||||
}
|
||||
return cc.QueryRowCtx(context.Background(), v, key, queryCtx)
|
||||
}
|
||||
|
||||
// QueryRowCtx unmarshals into v with given key and query func.
|
||||
func (cc CachedConn) QueryRowCtx(ctx context.Context, v interface{}, key string, query QueryCtxFn) error {
|
||||
return cc.cache.TakeCtx(ctx, v, key, func(v interface{}) error {
|
||||
return query(ctx, cc.db, v)
|
||||
})
|
||||
}
|
||||
|
||||
// QueryRowIndex unmarshals into v with given key.
|
||||
func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary interface{}) string,
|
||||
indexQuery IndexQueryFn, primaryQuery PrimaryQueryFn) error {
|
||||
indexQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v interface{}) (interface{}, error) {
|
||||
return indexQuery(conn, v)
|
||||
}
|
||||
primaryQueryCtx := func(_ context.Context, conn sqlx.SqlConn, v, primary interface{}) error {
|
||||
return primaryQuery(conn, v, primary)
|
||||
}
|
||||
|
||||
return cc.QueryRowIndexCtx(context.Background(), v, key, keyer, indexQueryCtx, primaryQueryCtx)
|
||||
}
|
||||
|
||||
// QueryRowIndexCtx unmarshals into v with given key.
|
||||
func (cc CachedConn) QueryRowIndexCtx(ctx context.Context, v interface{}, key string,
|
||||
keyer func(primary interface{}) string, indexQuery IndexQueryCtxFn,
|
||||
primaryQuery PrimaryQueryCtxFn) error {
|
||||
var primaryKey interface{}
|
||||
var found bool
|
||||
|
||||
if err := cc.cache.TakeWithExpire(&primaryKey, key, func(val interface{}, expire time.Duration) (err error) {
|
||||
primaryKey, err = indexQuery(cc.db, v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if err := cc.cache.TakeWithExpireCtx(ctx, &primaryKey, key,
|
||||
func(val interface{}, expire time.Duration) (err error) {
|
||||
primaryKey, err = indexQuery(ctx, cc.db, v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
found = true
|
||||
return cc.cache.SetWithExpire(keyer(primaryKey), v, expire+cacheSafeGapBetweenIndexAndPrimary)
|
||||
}); err != nil {
|
||||
found = true
|
||||
return cc.cache.SetWithExpireCtx(ctx, keyer(primaryKey), v,
|
||||
expire+cacheSafeGapBetweenIndexAndPrimary)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -117,28 +174,54 @@ func (cc CachedConn) QueryRowIndex(v interface{}, key string, keyer func(primary
|
||||
return nil
|
||||
}
|
||||
|
||||
return cc.cache.Take(v, keyer(primaryKey), func(v interface{}) error {
|
||||
return primaryQuery(cc.db, v, primaryKey)
|
||||
return cc.cache.TakeCtx(ctx, v, keyer(primaryKey), func(v interface{}) error {
|
||||
return primaryQuery(ctx, cc.db, v, primaryKey)
|
||||
})
|
||||
}
|
||||
|
||||
// QueryRowNoCache unmarshals into v with given statement.
|
||||
func (cc CachedConn) QueryRowNoCache(v interface{}, q string, args ...interface{}) error {
|
||||
return cc.db.QueryRow(v, q, args...)
|
||||
return cc.QueryRowNoCacheCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
// QueryRowNoCacheCtx unmarshals into v with given statement.
|
||||
func (cc CachedConn) QueryRowNoCacheCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return cc.db.QueryRowCtx(ctx, v, q, args...)
|
||||
}
|
||||
|
||||
// QueryRowsNoCache unmarshals into v with given statement.
|
||||
// It doesn't use cache, because it might cause consistency problem.
|
||||
func (cc CachedConn) QueryRowsNoCache(v interface{}, q string, args ...interface{}) error {
|
||||
return cc.db.QueryRows(v, q, args...)
|
||||
return cc.QueryRowsNoCacheCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
// QueryRowsNoCacheCtx unmarshals into v with given statement.
|
||||
// It doesn't use cache, because it might cause consistency problem.
|
||||
func (cc CachedConn) QueryRowsNoCacheCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return cc.db.QueryRowsCtx(ctx, v, q, args...)
|
||||
}
|
||||
|
||||
// SetCache sets v into cache with given key.
|
||||
func (cc CachedConn) SetCache(key string, v interface{}) error {
|
||||
return cc.cache.Set(key, v)
|
||||
func (cc CachedConn) SetCache(key string, val interface{}) error {
|
||||
return cc.SetCacheCtx(context.Background(), key, val)
|
||||
}
|
||||
|
||||
// SetCacheCtx sets v into cache with given key.
|
||||
func (cc CachedConn) SetCacheCtx(ctx context.Context, key string, val interface{}) error {
|
||||
return cc.cache.SetCtx(ctx, key, val)
|
||||
}
|
||||
|
||||
// Transact runs given fn in transaction mode.
|
||||
func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
|
||||
return cc.db.Transact(fn)
|
||||
fnCtx := func(_ context.Context, session sqlx.Session) error {
|
||||
return fn(session)
|
||||
}
|
||||
return cc.TransactCtx(context.Background(), fnCtx)
|
||||
}
|
||||
|
||||
// TransactCtx runs given fn in transaction mode.
|
||||
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
||||
return cc.db.TransactCtx(ctx, fn)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -568,7 +569,7 @@ func TestNewConnWithCache(t *testing.T) {
|
||||
defer clean()
|
||||
|
||||
var conn trackedConn
|
||||
c := NewConnWithCache(&conn, cache.NewNode(r, exclusiveCalls, stats, sql.ErrNoRows))
|
||||
c := NewConnWithCache(&conn, cache.NewNode(r, singleFlights, stats, sql.ErrNoRows))
|
||||
_, err = c.ExecNoCache("delete from user_table where id='kevin'")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, conn.execValue)
|
||||
@@ -585,6 +586,30 @@ type dummySqlConn struct {
|
||||
queryRow func(interface{}, string, ...interface{}) error
|
||||
}
|
||||
|
||||
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -594,6 +619,10 @@ func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRow(v interface{}, query string, args ...interface{}) error {
|
||||
return d.QueryRowCtx(context.Background(), v, query, args...)
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowCtx(_ context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
if d.queryRow != nil {
|
||||
return d.queryRow(v, query, args...)
|
||||
}
|
||||
@@ -628,13 +657,21 @@ type trackedConn struct {
|
||||
}
|
||||
|
||||
func (c *trackedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return c.ExecCtx(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
c.execValue = true
|
||||
return c.dummySqlConn.Exec(query, args...)
|
||||
return c.dummySqlConn.ExecCtx(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) QueryRows(v interface{}, query string, args ...interface{}) error {
|
||||
return c.QueryRowsCtx(context.Background(), v, query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
c.queryRowsValue = true
|
||||
return c.dummySqlConn.QueryRows(v, query, args...)
|
||||
return c.dummySqlConn.QueryRowsCtx(ctx, v, query, args...)
|
||||
}
|
||||
|
||||
func (c *trackedConn) RawDB() (*sql.DB, error) {
|
||||
@@ -642,6 +679,12 @@ func (c *trackedConn) RawDB() (*sql.DB, error) {
|
||||
}
|
||||
|
||||
func (c *trackedConn) Transact(fn func(session sqlx.Session) error) error {
|
||||
c.transactValue = true
|
||||
return c.dummySqlConn.Transact(fn)
|
||||
return c.TransactCtx(context.Background(), func(_ context.Context, session sqlx.Session) error {
|
||||
return fn(session)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *trackedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
||||
c.transactValue = true
|
||||
return c.dummySqlConn.TransactCtx(ctx, fn)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strconv"
|
||||
@@ -17,12 +18,40 @@ type mockedConn struct {
|
||||
execErr error
|
||||
}
|
||||
|
||||
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
c.query = query
|
||||
c.args = args
|
||||
return nil, c.execErr
|
||||
}
|
||||
|
||||
func (c *mockedConn) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockedConn) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockedConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
func (c *mockedConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return c.ExecCtx(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (c *mockedConn) Prepare(query string) (StmtSession, error) {
|
||||
panic("should not called")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
@@ -16,7 +17,7 @@ func TestUnmarshalRowBool(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.True(t, value)
|
||||
@@ -29,7 +30,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value bool
|
||||
assert.NotNil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
})
|
||||
@@ -41,7 +42,7 @@ func TestUnmarshalRowInt(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, 2, value)
|
||||
@@ -54,7 +55,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, int8(3), value)
|
||||
@@ -67,7 +68,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.Equal(t, int16(4), value)
|
||||
@@ -80,7 +81,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.Equal(t, int32(5), value)
|
||||
@@ -93,7 +94,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, int64(6), value)
|
||||
@@ -106,7 +107,7 @@ func TestUnmarshalRowUint(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint(2), value)
|
||||
@@ -119,7 +120,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint8(3), value)
|
||||
@@ -132,7 +133,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint16(4), value)
|
||||
@@ -145,7 +146,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint32(5), value)
|
||||
@@ -158,7 +159,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, uint16(6), value)
|
||||
@@ -171,7 +172,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, float32(7), value)
|
||||
@@ -184,7 +185,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, float64(8), value)
|
||||
@@ -198,7 +199,7 @@ func TestUnmarshalRowString(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -215,7 +216,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
@@ -233,7 +234,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
assert.Equal(t, "liao", value.Name)
|
||||
@@ -251,7 +252,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
|
||||
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
assert.NotNil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.NotNil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
})
|
||||
@@ -264,7 +265,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -278,7 +279,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -292,7 +293,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -306,7 +307,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -320,7 +321,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -334,7 +335,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -348,7 +349,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -362,7 +363,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -376,7 +377,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -390,7 +391,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -404,7 +405,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -418,7 +419,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -432,7 +433,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -446,7 +447,7 @@ func TestUnmarshalRowsString(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -462,7 +463,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*bool
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -478,7 +479,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -494,7 +495,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -510,7 +511,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -526,7 +527,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -542,7 +543,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*int64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -558,7 +559,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -574,7 +575,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint8
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -590,7 +591,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint16
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -606,7 +607,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -622,7 +623,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*uint64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -638,7 +639,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float32
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -654,7 +655,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*float64
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -670,7 +671,7 @@ func TestUnmarshalRowsStringPtr(t *testing.T) {
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
|
||||
var value []*string
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select value from users where user=?", "anyone"))
|
||||
assert.EqualValues(t, expect, value)
|
||||
@@ -699,7 +700,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
@@ -739,7 +740,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
|
||||
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
|
||||
"first", "firstnullstring").AddRow("second", nil)
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
@@ -773,7 +774,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
@@ -814,7 +815,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
|
||||
@@ -856,7 +857,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age, value from users where user=?", "anyone"))
|
||||
|
||||
@@ -890,7 +891,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
@@ -923,7 +924,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
@@ -956,7 +957,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
|
||||
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
|
||||
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(&value, rows, true)
|
||||
}, "select name, age from users where user=?", "anyone"))
|
||||
|
||||
@@ -976,7 +977,7 @@ func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
|
||||
User string `db:"user"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
assert.Nil(t, query(db, func(rows *sql.Rows) error {
|
||||
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(&r, rows, false)
|
||||
}, "select age from users where user=?", "anyone"))
|
||||
assert.Empty(t, r.User)
|
||||
@@ -1027,7 +1028,7 @@ func TestUnmarshalRowError(t *testing.T) {
|
||||
User string `db:"user"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
test.validate(query(db, func(rows *sql.Rows) error {
|
||||
test.validate(query(context.Background(), db, func(rows *sql.Rows) error {
|
||||
scanner := mockedScanner{
|
||||
colErr: test.colErr,
|
||||
scanErr: test.scanErr,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
@@ -14,11 +15,17 @@ type (
|
||||
// Session stands for raw connections or transaction sessions
|
||||
Session interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (StmtSession, error)
|
||||
PrepareCtx(ctx context.Context, query string) (StmtSession, error)
|
||||
QueryRow(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
|
||||
QueryRowPartial(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
|
||||
QueryRows(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
|
||||
QueryRowsPartial(v interface{}, query string, args ...interface{}) error
|
||||
QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error
|
||||
}
|
||||
|
||||
// SqlConn only stands for raw connections, so Transact method can be called.
|
||||
@@ -27,7 +34,8 @@ type (
|
||||
// RawDB is for other ORM to operate with, use it with caution.
|
||||
// Notice: don't close it.
|
||||
RawDB() (*sql.DB, error)
|
||||
Transact(func(session Session) error) error
|
||||
Transact(fn func(Session) error) error
|
||||
TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error
|
||||
}
|
||||
|
||||
// SqlOption defines the method to customize a sql connection.
|
||||
@@ -37,10 +45,15 @@ type (
|
||||
StmtSession interface {
|
||||
Close() error
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||
QueryRow(v interface{}, args ...interface{}) error
|
||||
QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error
|
||||
QueryRowPartial(v interface{}, args ...interface{}) error
|
||||
QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
|
||||
QueryRows(v interface{}, args ...interface{}) error
|
||||
QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error
|
||||
QueryRowsPartial(v interface{}, args ...interface{}) error
|
||||
QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error
|
||||
}
|
||||
|
||||
// thread-safe
|
||||
@@ -58,7 +71,9 @@ type (
|
||||
|
||||
sessionConn interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
statement struct {
|
||||
@@ -68,7 +83,9 @@ type (
|
||||
|
||||
stmtConn interface {
|
||||
Exec(args ...interface{}) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error)
|
||||
Query(args ...interface{}) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -112,6 +129,11 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result, err error) {
|
||||
return db.ExecCtx(context.Background(), q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) ExecCtx(ctx context.Context, q string, args ...interface{}) (
|
||||
result sql.Result, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
@@ -120,7 +142,7 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
|
||||
return err
|
||||
}
|
||||
|
||||
result, err = exec(conn, q, args...)
|
||||
result, err = exec(ctx, conn, q, args...)
|
||||
return err
|
||||
}, db.acceptable)
|
||||
|
||||
@@ -128,6 +150,10 @@ func (db *commonSqlConn) Exec(q string, args ...interface{}) (result sql.Result,
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
return db.PrepareCtx(context.Background(), query)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) PrepareCtx(ctx context.Context, query string) (stmt StmtSession, err error) {
|
||||
err = db.brk.DoWithAcceptable(func() error {
|
||||
var conn *sql.DB
|
||||
conn, err = db.connProv()
|
||||
@@ -136,7 +162,7 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
st, err := conn.Prepare(query)
|
||||
st, err := conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -152,25 +178,45 @@ func (db *commonSqlConn) Prepare(query string) (stmt StmtSession, err error) {
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return db.QueryRowCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return db.queryRows(ctx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return db.QueryRowPartialCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowPartialCtx(ctx context.Context, v interface{},
|
||||
q string, args ...interface{}) error {
|
||||
return db.queryRows(ctx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return db.QueryRowsCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowsCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return db.queryRows(ctx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return db.queryRows(func(rows *sql.Rows) error {
|
||||
return db.QueryRowsPartialCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) QueryRowsPartialCtx(ctx context.Context, v interface{},
|
||||
q string, args ...interface{}) error {
|
||||
return db.queryRows(ctx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
@@ -180,13 +226,19 @@ func (db *commonSqlConn) RawDB() (*sql.DB, error) {
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) Transact(fn func(Session) error) error {
|
||||
return db.TransactCtx(context.Background(), func(_ context.Context, session Session) error {
|
||||
return fn(session)
|
||||
})
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Context, Session) error) error {
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
return transact(db, db.beginTx, fn)
|
||||
return transact(ctx, db, db.beginTx, fn)
|
||||
}, db.acceptable)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) acceptable(err error) bool {
|
||||
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone
|
||||
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
|
||||
if db.accept == nil {
|
||||
return ok
|
||||
}
|
||||
@@ -194,7 +246,8 @@ func (db *commonSqlConn) acceptable(err error) bool {
|
||||
return ok || db.accept(err)
|
||||
}
|
||||
|
||||
func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,
|
||||
q string, args ...interface{}) error {
|
||||
var qerr error
|
||||
return db.brk.DoWithAcceptable(func() error {
|
||||
conn, err := db.connProv()
|
||||
@@ -203,7 +256,7 @@ func (db *commonSqlConn) queryRows(scanner func(*sql.Rows) error, q string, args
|
||||
return err
|
||||
}
|
||||
|
||||
return query(conn, func(rows *sql.Rows) error {
|
||||
return query(ctx, conn, func(rows *sql.Rows) error {
|
||||
qerr = scanner(rows)
|
||||
return qerr
|
||||
}, q, args...)
|
||||
@@ -217,29 +270,49 @@ func (s statement) Close() error {
|
||||
}
|
||||
|
||||
func (s statement) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(s.stmt, s.query, args...)
|
||||
return s.ExecCtx(context.Background(), args...)
|
||||
}
|
||||
|
||||
func (s statement) ExecCtx(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(ctx, s.stmt, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRow(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return s.QueryRowCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowCtx(ctx context.Context, v interface{}, args ...interface{}) error {
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return s.QueryRowPartialCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRows(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return s.QueryRowsCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsCtx(ctx context.Context, v interface{}, args ...interface{}) error {
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartial(v interface{}, args ...interface{}) error {
|
||||
return queryStmt(s.stmt, func(rows *sql.Rows) error {
|
||||
return s.QueryRowsPartialCtx(context.Background(), v, args...)
|
||||
}
|
||||
|
||||
func (s statement) QueryRowsPartialCtx(ctx context.Context, v interface{}, args ...interface{}) error {
|
||||
return queryStmt(ctx, s.stmt, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, s.query, args...)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
@@ -18,64 +19,65 @@ func SetSlowThreshold(threshold time.Duration) {
|
||||
slowThreshold.Set(threshold)
|
||||
}
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
func exec(ctx context.Context, conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
result, err := conn.ExecContext(ctx, q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql exec: %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Infof("sql exec: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
logSqlError(ctx, stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
func execStmt(ctx context.Context, conn stmtConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
result, err := conn.ExecContext(ctx, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql execStmt: %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Infof("sql execStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
logSqlError(ctx, stmt, err)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
func query(ctx context.Context, conn sessionConn, scanner func(*sql.Rows) error,
|
||||
q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
rows, err := conn.QueryContext(ctx, q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql query: %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Infof("sql query: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
logSqlError(ctx, stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
@@ -83,22 +85,23 @@ func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...in
|
||||
return scanner(rows)
|
||||
}
|
||||
|
||||
func queryStmt(conn stmtConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
func queryStmt(ctx context.Context, conn stmtConn, scanner func(*sql.Rows) error,
|
||||
q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(args...)
|
||||
rows, err := conn.QueryContext(ctx, args...)
|
||||
duration := timex.Since(startTime)
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Slowf("[SQL] queryStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("sql queryStmt: %s", stmt)
|
||||
logx.WithContext(ctx).WithDuration(duration).Infof("sql queryStmt: %s", stmt)
|
||||
}
|
||||
if err != nil {
|
||||
logSqlError(stmt, err)
|
||||
logSqlError(ctx, stmt, err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
@@ -57,7 +58,7 @@ func TestStmt_exec(t *testing.T) {
|
||||
test := test
|
||||
fns := []func(args ...interface{}) (sql.Result, error){
|
||||
func(args ...interface{}) (sql.Result, error) {
|
||||
return exec(&mockedSessionConn{
|
||||
return exec(context.Background(), &mockedSessionConn{
|
||||
lastInsertId: test.lastInsertId,
|
||||
rowsAffected: test.rowsAffected,
|
||||
err: test.err,
|
||||
@@ -65,7 +66,7 @@ func TestStmt_exec(t *testing.T) {
|
||||
}, test.query, args...)
|
||||
},
|
||||
func(args ...interface{}) (sql.Result, error) {
|
||||
return execStmt(&mockedStmtConn{
|
||||
return execStmt(context.Background(), &mockedStmtConn{
|
||||
lastInsertId: test.lastInsertId,
|
||||
rowsAffected: test.rowsAffected,
|
||||
err: test.err,
|
||||
@@ -137,7 +138,7 @@ func TestStmt_query(t *testing.T) {
|
||||
test := test
|
||||
fns := []func(args ...interface{}) error{
|
||||
func(args ...interface{}) error {
|
||||
return query(&mockedSessionConn{
|
||||
return query(context.Background(), &mockedSessionConn{
|
||||
err: test.err,
|
||||
delay: test.delay,
|
||||
}, func(rows *sql.Rows) error {
|
||||
@@ -145,7 +146,7 @@ func TestStmt_query(t *testing.T) {
|
||||
}, test.query, args...)
|
||||
},
|
||||
func(args ...interface{}) error {
|
||||
return queryStmt(&mockedStmtConn{
|
||||
return queryStmt(context.Background(), &mockedStmtConn{
|
||||
err: test.err,
|
||||
delay: test.delay,
|
||||
}, func(rows *sql.Rows) error {
|
||||
@@ -185,6 +186,10 @@ type mockedSessionConn struct {
|
||||
}
|
||||
|
||||
func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return m.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (m *mockedSessionConn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
if m.delay {
|
||||
time.Sleep(defaultSlowThreshold + time.Millisecond)
|
||||
}
|
||||
@@ -195,6 +200,10 @@ func (m *mockedSessionConn) Exec(query string, args ...interface{}) (sql.Result,
|
||||
}
|
||||
|
||||
func (m *mockedSessionConn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return m.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (m *mockedSessionConn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
if m.delay {
|
||||
time.Sleep(defaultSlowThreshold + time.Millisecond)
|
||||
}
|
||||
@@ -214,6 +223,10 @@ type mockedStmtConn struct {
|
||||
}
|
||||
|
||||
func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return m.ExecContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
func (m *mockedStmtConn) ExecContext(_ context.Context, _ ...interface{}) (sql.Result, error) {
|
||||
if m.delay {
|
||||
time.Sleep(defaultSlowThreshold + time.Millisecond)
|
||||
}
|
||||
@@ -224,6 +237,10 @@ func (m *mockedStmtConn) Exec(args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func (m *mockedStmtConn) Query(args ...interface{}) (*sql.Rows, error) {
|
||||
return m.QueryContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
func (m *mockedStmtConn) QueryContext(_ context.Context, _ ...interface{}) (*sql.Rows, error) {
|
||||
if m.delay {
|
||||
time.Sleep(defaultSlowThreshold + time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
@@ -26,11 +27,19 @@ func NewSessionFromTx(tx *sql.Tx) Session {
|
||||
}
|
||||
|
||||
func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(t.Tx, q, args...)
|
||||
return t.ExecCtx(context.Background(), q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (sql.Result, error) {
|
||||
return exec(ctx, t.Tx, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
stmt, err := t.Tx.Prepare(q)
|
||||
return t.PrepareCtx(context.Background(), q)
|
||||
}
|
||||
|
||||
func (t txSession) PrepareCtx(ctx context.Context, q string) (StmtSession, error) {
|
||||
stmt, err := t.Tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -42,25 +51,43 @@ func (t txSession) Prepare(q string) (StmtSession, error) {
|
||||
}
|
||||
|
||||
func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowPartialCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRow(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowsCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, true)
|
||||
}, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return query(t.Tx, func(rows *sql.Rows) error {
|
||||
return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
|
||||
}
|
||||
|
||||
func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
|
||||
args ...interface{}) error {
|
||||
return query(ctx, t.Tx, func(rows *sql.Rows) error {
|
||||
return unmarshalRows(v, rows, false)
|
||||
}, q, args...)
|
||||
}
|
||||
@@ -76,17 +103,19 @@ func begin(db *sql.DB) (trans, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
|
||||
func transact(ctx context.Context, db *commonSqlConn, b beginnable,
|
||||
fn func(context.Context, Session) error) (err error) {
|
||||
conn, err := db.connProv()
|
||||
if err != nil {
|
||||
db.onError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
return transactOnConn(conn, b, fn)
|
||||
return transactOnConn(ctx, conn, b, fn)
|
||||
}
|
||||
|
||||
func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
|
||||
func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
|
||||
fn func(context.Context, Session) error) (err error) {
|
||||
var tx trans
|
||||
tx, err = b(conn)
|
||||
if err != nil {
|
||||
@@ -96,18 +125,18 @@ func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err err
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
|
||||
err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
|
||||
} else {
|
||||
err = fmt.Errorf("recoveer from %#v", p)
|
||||
}
|
||||
} else if err != nil {
|
||||
if e := tx.Rollback(); e != nil {
|
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
|
||||
err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
|
||||
}
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(tx)
|
||||
return fn(ctx, tx)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
@@ -26,26 +27,50 @@ func (mt *mockTx) Exec(q string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRow(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRows(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Rollback() error {
|
||||
mt.status |= mockRollback
|
||||
return nil
|
||||
@@ -59,18 +84,20 @@ func beginMock(mock *mockTx) beginnable {
|
||||
|
||||
func TestTransactCommit(t *testing.T) {
|
||||
mock := &mockTx{}
|
||||
err := transactOnConn(nil, beginMock(mock), func(Session) error {
|
||||
return nil
|
||||
})
|
||||
err := transactOnConn(context.Background(), nil, beginMock(mock),
|
||||
func(context.Context, Session) error {
|
||||
return nil
|
||||
})
|
||||
assert.Equal(t, mockCommit, mock.status)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTransactRollback(t *testing.T) {
|
||||
mock := &mockTx{}
|
||||
err := transactOnConn(nil, beginMock(mock), func(Session) error {
|
||||
return errors.New("rollback")
|
||||
})
|
||||
err := transactOnConn(context.Background(), nil, beginMock(mock),
|
||||
func(context.Context, Session) error {
|
||||
return errors.New("rollback")
|
||||
})
|
||||
assert.Equal(t, mockRollback, mock.status)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -109,9 +110,9 @@ func logInstanceError(datasource string, err error) {
|
||||
logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
|
||||
}
|
||||
|
||||
func logSqlError(stmt string, err error) {
|
||||
func logSqlError(ctx context.Context, stmt string, err error) {
|
||||
if err != nil && err != ErrNotFound {
|
||||
logx.Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
|
||||
)
|
||||
|
||||
// DocCommand generate markdown doc file
|
||||
// DocCommand generate Markdown doc file
|
||||
func DocCommand(c *cli.Context) error {
|
||||
dir := c.String("dir")
|
||||
if len(dir) == 0 {
|
||||
@@ -45,7 +45,7 @@ func DocCommand(c *cli.Context) error {
|
||||
for _, p := range files {
|
||||
api, err := parser.Parse(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse file: %s, err: %s", p, err.Error())
|
||||
return fmt.Errorf("parse file: %s, err: %w", p, err)
|
||||
}
|
||||
|
||||
api.Service = api.Service.JoinPrefix()
|
||||
|
||||
@@ -164,12 +164,12 @@ func writeFile(pkgs []*ast.Package, verbose bool) error {
|
||||
w := bytes.NewBuffer(nil)
|
||||
err := format.Node(w, fset, file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[rewriteImport] format file %s error: %+v", filename, err)
|
||||
return fmt.Errorf("[rewriteImport] format file %s error: %w", filename, err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(filename, w.Bytes(), os.ModePerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("[rewriteImport] write file %s error: %+v", filename, err)
|
||||
return fmt.Errorf("[rewriteImport] write file %s error: %w", filename, err)
|
||||
}
|
||||
if verbose {
|
||||
console.Success("[OK] migrated %q successfully", filepath.Base(filename))
|
||||
|
||||
Reference in New Issue
Block a user