feat: support using session to execute statements in transaction (#3252)

This commit is contained in:
Kevin Wan
2023-05-17 22:15:24 +08:00
committed by GitHub
parent f0bdfb928f
commit bff5b81ad9
11 changed files with 526 additions and 126 deletions

4
.gitignore vendored
View File

@@ -14,9 +14,10 @@
**/.idea **/.idea
**/.DS_Store **/.DS_Store
**/logs **/logs
**/adhoc
**/coverage.txt
# for test purpose # for test purpose
**/adhoc
go.work go.work
go.work.sum go.work.sum
@@ -27,4 +28,3 @@ go.work.sum
# vim auto backup file # vim auto backup file
*~ *~
!OWNERS !OWNERS
coverage.txt

View File

@@ -226,3 +226,15 @@ func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
return cc.db.TransactCtx(ctx, fn) return cc.db.TransactCtx(ctx, fn)
} }
// WithSession returns a new CachedConn with given session.
// If query from session, the uncommitted data might be returned.
// Don't query for the uncommitted data, you should just use it,
// and don't use the cache for the uncommitted data.
// Not recommend to use cache within transactions due to consistency problem.
func (cc CachedConn) WithSession(session sqlx.Session) CachedConn {
return CachedConn{
db: sqlx.NewSqlConnFromSession(session),
cache: cc.cache,
}
}

View File

@@ -15,6 +15,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fx" "github.com/zeromicro/go-zero/core/fx"
@@ -24,6 +25,8 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest" "github.com/zeromicro/go-zero/core/stores/redis/redistest"
"github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/internal/dbtest"
) )
func init() { func init() {
@@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) {
var value string var value string
err := c.GetCache("any", &value) err := c.GetCache("any", &value)
assert.Equal(t, ErrNotFound, err) assert.Equal(t, ErrNotFound, err)
r.Set("any", `"value"`) _ = r.Set("any", `"value"`)
err = c.GetCache("any", &value) err = c.GetCache("any", &value)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "value", value) assert.Equal(t, "value", value)
@@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) {
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit)) assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
} }
func TestCachedConn_DelCache(t *testing.T) {
r := redistest.CreateRedis(t)
const (
key = "user"
value = "any"
)
assert.NoError(t, r.Set(key, value))
c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
err := c.DelCache(key)
assert.Nil(t, err)
val, err := r.Get(key)
assert.Nil(t, err)
assert.Empty(t, val)
}
func TestCachedConnQueryRow(t *testing.T) { func TestCachedConnQueryRow(t *testing.T) {
r := redistest.CreateRedis(t) r := redistest.CreateRedis(t)
@@ -543,6 +564,125 @@ func TestNewConnWithCache(t *testing.T) {
assert.True(t, conn.execValue) assert.True(t, conn.execValue)
} }
func TestCachedConn_WithSession(t *testing.T) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
r := redistest.CreateRedis(t)
conn := CachedConn{
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
conn = conn.WithSession(session)
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.Error(t, conn.Transact(func(session sqlx.Session) error {
conn = conn.WithSession(session)
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "bar")
return err
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
var val string
conn = conn.WithSession(session)
err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
return conn.QueryRow(v, "any")
})
assert.Equal(t, "2", val)
return err
}))
val, err := r.Get("foo")
assert.NoError(t, err)
assert.Equal(t, `"2"`, val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
var val string
conn = conn.WithSession(session)
assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
return conn.QueryRow(v, "any")
}))
assert.Equal(t, "2", val)
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
return err
}))
val, err := r.Get("foo")
assert.NoError(t, err)
assert.Empty(t, val)
})
}
func resetStats() { func resetStats() {
atomic.StoreUint64(&stats.Total, 0) atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0) atomic.StoreUint64(&stats.Hit, 0)
@@ -554,35 +694,35 @@ type dummySqlConn struct {
queryRow func(any, string, ...any) error queryRow func(any, string, ...any) error
} }
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) { func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil return nil, nil
} }
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) { func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
return nil, nil return nil, nil
} }
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error { func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error { func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error { func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
return nil return nil
} }
func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) { func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil return nil, nil
} }
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) { func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
return nil, nil return nil, nil
} }
@@ -597,15 +737,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args .
return nil return nil
} }
func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error { func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil return nil
} }
func (d dummySqlConn) QueryRows(v any, query string, args ...any) error { func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
return nil return nil
} }
func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error { func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil return nil
} }

View File

@@ -9,7 +9,7 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/internal/dbtest"
) )
type mockedConn struct { type mockedConn struct {
@@ -81,7 +81,7 @@ func (c *mockedConn) Transact(func(session Session) error) error {
} }
func TestBulkInserter(t *testing.T) { func TestBulkInserter(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`) inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
assert.Nil(t, err) assert.Nil(t, err)
@@ -98,7 +98,7 @@ func TestBulkInserter(t *testing.T) {
} }
func TestBulkInserterSuffix(t *testing.T) { func TestBulkInserterSuffix(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+ inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`) `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
@@ -119,7 +119,7 @@ func TestBulkInserterSuffix(t *testing.T) {
} }
func TestBulkInserterBadStatement(t *testing.T) { func TestBulkInserterBadStatement(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn var conn mockedConn
_, err := NewBulkInserter(&conn, "foo") _, err := NewBulkInserter(&conn, "foo")
assert.NotNil(t, err) assert.NotNil(t, err)
@@ -144,19 +144,3 @@ func TestBulkInserter_Update(t *testing.T) {
assert.NotNil(t, inserter.UpdateStmt("foo")) assert.NotNil(t, inserter.UpdateStmt("foo"))
assert.NotNil(t, inserter.Insert("foo", "bar")) assert.NotNil(t, inserter.Insert("foo", "bar"))
} }
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

View File

@@ -0,0 +1,14 @@
package sqlx
import (
"database/sql"
"errors"
)
var (
// ErrNotFound is an alias of sql.ErrNoRows
ErrNotFound = sql.ErrNoRows
errCantNestTx = errors.New("cannot nest transactions")
errNoRawDBFromTx = errors.New("cannot get raw db from transaction")
)

View File

@@ -8,11 +8,11 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/internal/dbtest"
) )
func TestUnmarshalRowBool(t *testing.T) { func TestUnmarshalRowBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -25,7 +25,7 @@ func TestUnmarshalRowBool(t *testing.T) {
} }
func TestUnmarshalRowBoolNotSettable(t *testing.T) { func TestUnmarshalRowBoolNotSettable(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -37,7 +37,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
} }
func TestUnmarshalRowInt(t *testing.T) { func TestUnmarshalRowInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -50,7 +50,7 @@ func TestUnmarshalRowInt(t *testing.T) {
} }
func TestUnmarshalRowInt8(t *testing.T) { func TestUnmarshalRowInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -63,7 +63,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
} }
func TestUnmarshalRowInt16(t *testing.T) { func TestUnmarshalRowInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -76,7 +76,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
} }
func TestUnmarshalRowInt32(t *testing.T) { func TestUnmarshalRowInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -89,7 +89,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
} }
func TestUnmarshalRowInt64(t *testing.T) { func TestUnmarshalRowInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -102,7 +102,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
} }
func TestUnmarshalRowUint(t *testing.T) { func TestUnmarshalRowUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -115,7 +115,7 @@ func TestUnmarshalRowUint(t *testing.T) {
} }
func TestUnmarshalRowUint8(t *testing.T) { func TestUnmarshalRowUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -128,7 +128,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
} }
func TestUnmarshalRowUint16(t *testing.T) { func TestUnmarshalRowUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -141,7 +141,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
} }
func TestUnmarshalRowUint32(t *testing.T) { func TestUnmarshalRowUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -154,7 +154,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
} }
func TestUnmarshalRowUint64(t *testing.T) { func TestUnmarshalRowUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -167,7 +167,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
} }
func TestUnmarshalRowFloat32(t *testing.T) { func TestUnmarshalRowFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -180,7 +180,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
} }
func TestUnmarshalRowFloat64(t *testing.T) { func TestUnmarshalRowFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -193,7 +193,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
} }
func TestUnmarshalRowString(t *testing.T) { func TestUnmarshalRowString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
const expect = "hello" const expect = "hello"
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect) rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -212,7 +212,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
Age int Age int
}) })
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -230,7 +230,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
}) })
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -248,7 +248,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
}) })
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao") rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -259,7 +259,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
} }
func TestUnmarshalRowsBool(t *testing.T) { func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []bool{true, false} expect := []bool{true, false}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -273,7 +273,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
} }
func TestUnmarshalRowsInt(t *testing.T) { func TestUnmarshalRowsInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int{2, 3} expect := []int{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -287,7 +287,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
} }
func TestUnmarshalRowsInt8(t *testing.T) { func TestUnmarshalRowsInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int8{2, 3} expect := []int8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -301,7 +301,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
} }
func TestUnmarshalRowsInt16(t *testing.T) { func TestUnmarshalRowsInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int16{2, 3} expect := []int16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -315,7 +315,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
} }
func TestUnmarshalRowsInt32(t *testing.T) { func TestUnmarshalRowsInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int32{2, 3} expect := []int32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -329,7 +329,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
} }
func TestUnmarshalRowsInt64(t *testing.T) { func TestUnmarshalRowsInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int64{2, 3} expect := []int64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -343,7 +343,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
} }
func TestUnmarshalRowsUint(t *testing.T) { func TestUnmarshalRowsUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint{2, 3} expect := []uint{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -357,7 +357,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
} }
func TestUnmarshalRowsUint8(t *testing.T) { func TestUnmarshalRowsUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint8{2, 3} expect := []uint8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -371,7 +371,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
} }
func TestUnmarshalRowsUint16(t *testing.T) { func TestUnmarshalRowsUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint16{2, 3} expect := []uint16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -385,7 +385,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
} }
func TestUnmarshalRowsUint32(t *testing.T) { func TestUnmarshalRowsUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint32{2, 3} expect := []uint32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -399,7 +399,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
} }
func TestUnmarshalRowsUint64(t *testing.T) { func TestUnmarshalRowsUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint64{2, 3} expect := []uint64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -413,7 +413,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
} }
func TestUnmarshalRowsFloat32(t *testing.T) { func TestUnmarshalRowsFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []float32{2, 3} expect := []float32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -427,7 +427,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
} }
func TestUnmarshalRowsFloat64(t *testing.T) { func TestUnmarshalRowsFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []float64{2, 3} expect := []float64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -441,7 +441,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
} }
func TestUnmarshalRowsString(t *testing.T) { func TestUnmarshalRowsString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []string{"hello", "world"} expect := []string{"hello", "world"}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -457,7 +457,7 @@ func TestUnmarshalRowsString(t *testing.T) {
func TestUnmarshalRowsBoolPtr(t *testing.T) { func TestUnmarshalRowsBoolPtr(t *testing.T) {
yes := true yes := true
no := false no := false
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*bool{&yes, &no} expect := []*bool{&yes, &no}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -473,7 +473,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
func TestUnmarshalRowsIntPtr(t *testing.T) { func TestUnmarshalRowsIntPtr(t *testing.T) {
two := 2 two := 2
three := 3 three := 3
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int{&two, &three} expect := []*int{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -489,7 +489,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
func TestUnmarshalRowsInt8Ptr(t *testing.T) { func TestUnmarshalRowsInt8Ptr(t *testing.T) {
two := int8(2) two := int8(2)
three := int8(3) three := int8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int8{&two, &three} expect := []*int8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -505,7 +505,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
func TestUnmarshalRowsInt16Ptr(t *testing.T) { func TestUnmarshalRowsInt16Ptr(t *testing.T) {
two := int16(2) two := int16(2)
three := int16(3) three := int16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int16{&two, &three} expect := []*int16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -521,7 +521,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
func TestUnmarshalRowsInt32Ptr(t *testing.T) { func TestUnmarshalRowsInt32Ptr(t *testing.T) {
two := int32(2) two := int32(2)
three := int32(3) three := int32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int32{&two, &three} expect := []*int32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -537,7 +537,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
func TestUnmarshalRowsInt64Ptr(t *testing.T) { func TestUnmarshalRowsInt64Ptr(t *testing.T) {
two := int64(2) two := int64(2)
three := int64(3) three := int64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int64{&two, &three} expect := []*int64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -553,7 +553,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
func TestUnmarshalRowsUintPtr(t *testing.T) { func TestUnmarshalRowsUintPtr(t *testing.T) {
two := uint(2) two := uint(2)
three := uint(3) three := uint(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint{&two, &three} expect := []*uint{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -569,7 +569,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
func TestUnmarshalRowsUint8Ptr(t *testing.T) { func TestUnmarshalRowsUint8Ptr(t *testing.T) {
two := uint8(2) two := uint8(2)
three := uint8(3) three := uint8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint8{&two, &three} expect := []*uint8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -585,7 +585,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
func TestUnmarshalRowsUint16Ptr(t *testing.T) { func TestUnmarshalRowsUint16Ptr(t *testing.T) {
two := uint16(2) two := uint16(2)
three := uint16(3) three := uint16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint16{&two, &three} expect := []*uint16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -601,7 +601,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
func TestUnmarshalRowsUint32Ptr(t *testing.T) { func TestUnmarshalRowsUint32Ptr(t *testing.T) {
two := uint32(2) two := uint32(2)
three := uint32(3) three := uint32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint32{&two, &three} expect := []*uint32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -617,7 +617,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
func TestUnmarshalRowsUint64Ptr(t *testing.T) { func TestUnmarshalRowsUint64Ptr(t *testing.T) {
two := uint64(2) two := uint64(2)
three := uint64(3) three := uint64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint64{&two, &three} expect := []*uint64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -633,7 +633,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
func TestUnmarshalRowsFloat32Ptr(t *testing.T) { func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
two := float32(2) two := float32(2)
three := float32(3) three := float32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*float32{&two, &three} expect := []*float32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -649,7 +649,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
func TestUnmarshalRowsFloat64Ptr(t *testing.T) { func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
two := float64(2) two := float64(2)
three := float64(3) three := float64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*float64{&two, &three} expect := []*float64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -665,7 +665,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
func TestUnmarshalRowsStringPtr(t *testing.T) { func TestUnmarshalRowsStringPtr(t *testing.T) {
hello := "hello" hello := "hello"
world := "world" world := "world"
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*string{&hello, &world} expect := []*string{&hello, &world}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld") rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -697,7 +697,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
Age int64 Age int64
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -736,7 +736,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
NullString sql.NullString `db:"value"` NullString sql.NullString `db:"value"`
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow( rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil) "first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -771,7 +771,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -812,7 +812,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
Embed Embed
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -854,7 +854,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
*Embed *Embed
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -888,7 +888,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
Age int64 Age int64
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -921,7 +921,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -954,7 +954,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -969,7 +969,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
} }
func TestCommonSqlConn_QueryRowOptional(t *testing.T) { func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5") rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -1019,7 +1019,7 @@ func TestUnmarshalRowError(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5") rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs( mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
"anyone").WillReturnRows(rs) "anyone").WillReturnRows(rs)
@@ -1091,7 +1091,7 @@ func TestAnonymousStructPr(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{ rs := sqlmock.NewRows([]string{
"name", "name",
"age", "age",
@@ -1139,7 +1139,7 @@ func TestAnonymousStructPrError(t *testing.T) {
Name string `db:"name"` Name string `db:"name"`
} }
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{ rs := sqlmock.NewRows([]string{
"name", "name",
"age", "age",
@@ -1154,7 +1154,7 @@ func TestAnonymousStructPrError(t *testing.T) {
WithArgs("anyone").WillReturnRows(rs) WithArgs("anyone").WillReturnRows(rs)
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error { assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true) return unmarshalRows(&value, rows, true)
}, "select name, age,grade,discipline,class_name,score from users where user=?", }, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone")) "anyone"))
if len(value) > 0 { if len(value) > 0 {
assert.Equal(t, value[0].score, 0) assert.Equal(t, value[0].score, 0)
@@ -1162,22 +1162,6 @@ func TestAnonymousStructPrError(t *testing.T) {
}) })
} }
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
type mockedScanner struct { type mockedScanner struct {
colErr error colErr error
scanErr error scanErr error

View File

@@ -11,9 +11,6 @@ import (
// spanName is used to identify the span name for the SQL execution. // spanName is used to identify the span name for the SQL execution.
const spanName = "sql" const spanName = "sql"
// ErrNotFound is an alias of sql.ErrNoRows
var ErrNotFound = sql.ErrNoRows
type ( type (
// Session stands for raw connections or transaction sessions // Session stands for raw connections or transaction sessions
Session interface { Session interface {
@@ -131,6 +128,13 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
return conn return conn
} }
// NewSqlConnFromSession returns a SqlConn with the given session.
func NewSqlConnFromSession(session Session) SqlConn {
return txConn{
Session: session,
}
}
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) { func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
return db.ExecCtx(context.Background(), q, args...) return db.ExecCtx(context.Background(), q, args...)
} }

View File

@@ -55,7 +55,7 @@ func TestSqlConn(t *testing.T) {
} }
func buildConn() (mock sqlmock.Sqlmock, err error) { func buildConn() (mock sqlmock.Sqlmock, err error) {
connManager.GetResource(mockedDatasource, func() (io.Closer, error) { _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB var db *sql.DB
var err error var err error
db, mock, err = sqlmock.New() db, mock, err = sqlmock.New()

View File

@@ -15,11 +15,27 @@ type (
Rollback() error Rollback() error
} }
txConn struct {
Session
}
txSession struct { txSession struct {
*sql.Tx *sql.Tx
} }
) )
func (s txConn) RawDB() (*sql.DB, error) {
return nil, errNoRawDBFromTx
}
func (s txConn) Transact(_ func(Session) error) error {
return errCantNestTx
}
func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session) error) error {
return errCantNestTx
}
// NewSessionFromTx returns a Session with the given sql.Tx. // NewSessionFromTx returns a Session with the given sql.Tx.
// Use it with caution, it's provided for other ORM to interact with. // Use it with caution, it's provided for other ORM to interact with.
func NewSessionFromTx(tx *sql.Tx) Session { func NewSessionFromTx(tx *sql.Tx) Session {

View File

@@ -6,7 +6,10 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/internal/dbtest"
) )
const ( const (
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
return nil return nil
} }
func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) { func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil return nil, nil
} }
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) { func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil return nil, nil
} }
func (mt *mockTx) Prepare(query string) (StmtSession, error) { func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
return nil, nil return nil, nil
} }
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
return nil, nil return nil, nil
} }
func (mt *mockTx) QueryRow(v any, q string, args ...any) error { func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error { func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error { func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error { func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRows(v any, q string, args ...any) error { func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error { func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error { func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil return nil
} }
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error { func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil return nil
} }
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
assert.Equal(t, mockRollback, mock.status) assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err) assert.NotNil(t, err)
} }
func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
assert.NoError(t, conn.Transact(func(session Session) error {
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return nil, errors.New("foo")
},
beginTx: begin,
onError: func(ctx context.Context, err error) {},
brk: breaker.NewBreaker(),
}
assert.Error(t, conn.Transact(func(session Session) error {
return nil
}))
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
_, err := conn.RawDB()
assert.Equal(t, errNoRawDBFromTx, err)
assert.Equal(t, errCantNestTx, conn.Transact(nil))
assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
return errors.New("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback().WillReturnError(errors.New("foo"))
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic(errors.New("foo"))
}))
})
}
func TestTxSession(t *testing.T) {
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
res, err := conn.Exec("any")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
_, err = conn.Exec("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
assert.NotNil(t, stmt)
mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
_, err = conn.Prepare("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRow(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRow(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRowPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRows(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRows(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRowsPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowsPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
}
func TestTxRollback(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
return c.QueryRow(&val, "foo")
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
if _, err := c.Exec("any"); err != nil {
return err
}
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
return nil
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
assert.Equal(t, "bar", val)
return nil
})
assert.NoError(t, err)
})
}
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
sess := NewSessionFromTx(tx)
conn := NewSqlConnFromSession(sess)
f(conn, mock)
})
}

37
internal/dbtest/sql.go Normal file
View File

@@ -0,0 +1,37 @@
package dbtest
import (
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
// RunTest runs a test function with a mock database.
func RunTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer func() {
_ = db.Close()
}()
fn(db, mock)
if err = mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
// RunTxTest runs a test function with a mock database in a transaction.
func RunTxTest(t *testing.T, f func(tx *sql.Tx, mock sqlmock.Sqlmock)) {
RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
tx, err := db.Begin()
if assert.NoError(t, err) {
f(tx, mock)
}
})
}