feat: support using session to execute statements in transaction (#3252)

This commit is contained in:
Kevin Wan
2023-05-17 22:15:24 +08:00
committed by GitHub
parent f0bdfb928f
commit bff5b81ad9
11 changed files with 526 additions and 126 deletions

4
.gitignore vendored
View File

@@ -14,9 +14,10 @@
**/.idea
**/.DS_Store
**/logs
**/adhoc
**/coverage.txt
# for test purpose
**/adhoc
go.work
go.work.sum
@@ -27,4 +28,3 @@ go.work.sum
# vim auto backup file
*~
!OWNERS
coverage.txt

View File

@@ -226,3 +226,15 @@ func (cc CachedConn) Transact(fn func(sqlx.Session) error) error {
func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
return cc.db.TransactCtx(ctx, fn)
}
// WithSession returns a new CachedConn with given session.
// If query from session, the uncommitted data might be returned.
// Don't query for the uncommitted data, you should just use it,
// and don't use the cache for the uncommitted data.
// Not recommend to use cache within transactions due to consistency problem.
func (cc CachedConn) WithSession(session sqlx.Session) CachedConn {
return CachedConn{
db: sqlx.NewSqlConnFromSession(session),
cache: cc.cache,
}
}

View File

@@ -15,6 +15,7 @@ import (
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/fx"
@@ -24,6 +25,8 @@ import (
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/syncx"
"github.com/zeromicro/go-zero/internal/dbtest"
)
func init() {
@@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) {
var value string
err := c.GetCache("any", &value)
assert.Equal(t, ErrNotFound, err)
r.Set("any", `"value"`)
_ = r.Set("any", `"value"`)
err = c.GetCache("any", &value)
assert.Nil(t, err)
assert.Equal(t, "value", value)
@@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) {
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
}
func TestCachedConn_DelCache(t *testing.T) {
r := redistest.CreateRedis(t)
const (
key = "user"
value = "any"
)
assert.NoError(t, r.Set(key, value))
c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
err := c.DelCache(key)
assert.Nil(t, err)
val, err := r.Get(key)
assert.Nil(t, err)
assert.Empty(t, val)
}
func TestCachedConnQueryRow(t *testing.T) {
r := redistest.CreateRedis(t)
@@ -543,6 +564,125 @@ func TestNewConnWithCache(t *testing.T) {
assert.True(t, conn.execValue)
}
func TestCachedConn_WithSession(t *testing.T) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
r := redistest.CreateRedis(t)
conn := CachedConn{
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
conn = conn.WithSession(session)
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.Error(t, conn.Transact(func(session sqlx.Session) error {
conn = conn.WithSession(session)
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "bar")
return err
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
var val string
conn = conn.WithSession(session)
err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
return conn.QueryRow(v, "any")
})
assert.Equal(t, "2", val)
return err
}))
val, err := r.Get("foo")
assert.NoError(t, err)
assert.Equal(t, `"2"`, val)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectCommit()
r := redistest.CreateRedis(t)
conn := CachedConn{
db: sqlx.NewSqlConnFromDB(db),
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
}
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
var val string
conn = conn.WithSession(session)
assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
return conn.QueryRow(v, "any")
}))
assert.Equal(t, "2", val)
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
return conn.Exec("any")
}, "foo")
return err
}))
val, err := r.Get("foo")
assert.NoError(t, err)
assert.Empty(t, val)
})
}
func resetStats() {
atomic.StoreUint64(&stats.Total, 0)
atomic.StoreUint64(&stats.Hit, 0)
@@ -554,35 +694,35 @@ type dummySqlConn struct {
queryRow func(any, string, ...any) error
}
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
return nil, nil
}
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
return nil
}
func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) {
func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
return nil, nil
}
@@ -597,15 +737,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args .
return nil
}
func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRows(v any, query string, args ...any) error {
func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
return nil
}
func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error {
func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/internal/dbtest"
)
type mockedConn struct {
@@ -81,7 +81,7 @@ func (c *mockedConn) Transact(func(session Session) error) error {
}
func TestBulkInserter(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`)
assert.Nil(t, err)
@@ -98,7 +98,7 @@ func TestBulkInserter(t *testing.T) {
}
func TestBulkInserterSuffix(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+
`(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`)
@@ -119,7 +119,7 @@ func TestBulkInserterSuffix(t *testing.T) {
}
func TestBulkInserterBadStatement(t *testing.T) {
runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var conn mockedConn
_, err := NewBulkInserter(&conn, "foo")
assert.NotNil(t, err)
@@ -144,19 +144,3 @@ func TestBulkInserter_Update(t *testing.T) {
assert.NotNil(t, inserter.UpdateStmt("foo"))
assert.NotNil(t, inserter.Insert("foo", "bar"))
}
func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

View File

@@ -0,0 +1,14 @@
package sqlx
import (
"database/sql"
"errors"
)
var (
// ErrNotFound is an alias of sql.ErrNoRows
ErrNotFound = sql.ErrNoRows
errCantNestTx = errors.New("cannot nest transactions")
errNoRawDBFromTx = errors.New("cannot get raw db from transaction")
)

View File

@@ -8,11 +8,11 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/internal/dbtest"
)
func TestUnmarshalRowBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -25,7 +25,7 @@ func TestUnmarshalRowBool(t *testing.T) {
}
func TestUnmarshalRowBoolNotSettable(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -37,7 +37,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) {
}
func TestUnmarshalRowInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -50,7 +50,7 @@ func TestUnmarshalRowInt(t *testing.T) {
}
func TestUnmarshalRowInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -63,7 +63,7 @@ func TestUnmarshalRowInt8(t *testing.T) {
}
func TestUnmarshalRowInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -76,7 +76,7 @@ func TestUnmarshalRowInt16(t *testing.T) {
}
func TestUnmarshalRowInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -89,7 +89,7 @@ func TestUnmarshalRowInt32(t *testing.T) {
}
func TestUnmarshalRowInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -102,7 +102,7 @@ func TestUnmarshalRowInt64(t *testing.T) {
}
func TestUnmarshalRowUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -115,7 +115,7 @@ func TestUnmarshalRowUint(t *testing.T) {
}
func TestUnmarshalRowUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -128,7 +128,7 @@ func TestUnmarshalRowUint8(t *testing.T) {
}
func TestUnmarshalRowUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -141,7 +141,7 @@ func TestUnmarshalRowUint16(t *testing.T) {
}
func TestUnmarshalRowUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -154,7 +154,7 @@ func TestUnmarshalRowUint32(t *testing.T) {
}
func TestUnmarshalRowUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -167,7 +167,7 @@ func TestUnmarshalRowUint64(t *testing.T) {
}
func TestUnmarshalRowFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -180,7 +180,7 @@ func TestUnmarshalRowFloat32(t *testing.T) {
}
func TestUnmarshalRowFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -193,7 +193,7 @@ func TestUnmarshalRowFloat64(t *testing.T) {
}
func TestUnmarshalRowString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
const expect = "hello"
rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -212,7 +212,7 @@ func TestUnmarshalRowStruct(t *testing.T) {
Age int
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -230,7 +230,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -248,7 +248,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
Name string `db:"name"`
})
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -259,7 +259,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) {
}
func TestUnmarshalRowsBool(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []bool{true, false}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -273,7 +273,7 @@ func TestUnmarshalRowsBool(t *testing.T) {
}
func TestUnmarshalRowsInt(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -287,7 +287,7 @@ func TestUnmarshalRowsInt(t *testing.T) {
}
func TestUnmarshalRowsInt8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -301,7 +301,7 @@ func TestUnmarshalRowsInt8(t *testing.T) {
}
func TestUnmarshalRowsInt16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -315,7 +315,7 @@ func TestUnmarshalRowsInt16(t *testing.T) {
}
func TestUnmarshalRowsInt32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -329,7 +329,7 @@ func TestUnmarshalRowsInt32(t *testing.T) {
}
func TestUnmarshalRowsInt64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []int64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -343,7 +343,7 @@ func TestUnmarshalRowsInt64(t *testing.T) {
}
func TestUnmarshalRowsUint(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -357,7 +357,7 @@ func TestUnmarshalRowsUint(t *testing.T) {
}
func TestUnmarshalRowsUint8(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint8{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -371,7 +371,7 @@ func TestUnmarshalRowsUint8(t *testing.T) {
}
func TestUnmarshalRowsUint16(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint16{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -385,7 +385,7 @@ func TestUnmarshalRowsUint16(t *testing.T) {
}
func TestUnmarshalRowsUint32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -399,7 +399,7 @@ func TestUnmarshalRowsUint32(t *testing.T) {
}
func TestUnmarshalRowsUint64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []uint64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -413,7 +413,7 @@ func TestUnmarshalRowsUint64(t *testing.T) {
}
func TestUnmarshalRowsFloat32(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []float32{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -427,7 +427,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) {
}
func TestUnmarshalRowsFloat64(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []float64{2, 3}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -441,7 +441,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) {
}
func TestUnmarshalRowsString(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []string{"hello", "world"}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -457,7 +457,7 @@ func TestUnmarshalRowsString(t *testing.T) {
func TestUnmarshalRowsBoolPtr(t *testing.T) {
yes := true
no := false
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*bool{&yes, &no}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -473,7 +473,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) {
func TestUnmarshalRowsIntPtr(t *testing.T) {
two := 2
three := 3
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -489,7 +489,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) {
func TestUnmarshalRowsInt8Ptr(t *testing.T) {
two := int8(2)
three := int8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -505,7 +505,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) {
func TestUnmarshalRowsInt16Ptr(t *testing.T) {
two := int16(2)
three := int16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -521,7 +521,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) {
func TestUnmarshalRowsInt32Ptr(t *testing.T) {
two := int32(2)
three := int32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -537,7 +537,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) {
func TestUnmarshalRowsInt64Ptr(t *testing.T) {
two := int64(2)
three := int64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*int64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -553,7 +553,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) {
func TestUnmarshalRowsUintPtr(t *testing.T) {
two := uint(2)
three := uint(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -569,7 +569,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) {
func TestUnmarshalRowsUint8Ptr(t *testing.T) {
two := uint8(2)
three := uint8(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint8{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -585,7 +585,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) {
func TestUnmarshalRowsUint16Ptr(t *testing.T) {
two := uint16(2)
three := uint16(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint16{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -601,7 +601,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) {
func TestUnmarshalRowsUint32Ptr(t *testing.T) {
two := uint32(2)
three := uint32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -617,7 +617,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) {
func TestUnmarshalRowsUint64Ptr(t *testing.T) {
two := uint64(2)
three := uint64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*uint64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -633,7 +633,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) {
func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
two := float32(2)
three := float32(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*float32{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -649,7 +649,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) {
func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
two := float64(2)
three := float64(3)
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*float64{&two, &three}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -665,7 +665,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) {
func TestUnmarshalRowsStringPtr(t *testing.T) {
hello := "hello"
world := "world"
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
expect := []*string{&hello, &world}
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -697,7 +697,7 @@ func TestUnmarshalRowsStruct(t *testing.T) {
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -736,7 +736,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) {
NullString sql.NullString `db:"value"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "value"}).AddRow(
"first", "firstnullstring").AddRow("second", nil)
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -771,7 +771,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -812,7 +812,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) {
Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -854,7 +854,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T)
*Embed
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -888,7 +888,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) {
Age int64
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -921,7 +921,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -954,7 +954,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error {
@@ -969,7 +969,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) {
}
func TestCommonSqlConn_QueryRowOptional(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -1019,7 +1019,7 @@ func TestUnmarshalRowError(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs(
"anyone").WillReturnRows(rs)
@@ -1091,7 +1091,7 @@ func TestAnonymousStructPr(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{
"name",
"age",
@@ -1139,7 +1139,7 @@ func TestAnonymousStructPrError(t *testing.T) {
Name string `db:"name"`
}
runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{
"name",
"age",
@@ -1154,7 +1154,7 @@ func TestAnonymousStructPrError(t *testing.T) {
WithArgs("anyone").WillReturnRows(rs)
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age,grade,discipline,class_name,score from users where user=?",
}, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone"))
if len(value) > 0 {
assert.Equal(t, value[0].score, 0)
@@ -1162,22 +1162,6 @@ func TestAnonymousStructPrError(t *testing.T) {
})
}
func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
logx.Disable()
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer db.Close()
fn(db, mock)
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
type mockedScanner struct {
colErr error
scanErr error

View File

@@ -11,9 +11,6 @@ import (
// spanName is used to identify the span name for the SQL execution.
const spanName = "sql"
// ErrNotFound is an alias of sql.ErrNoRows
var ErrNotFound = sql.ErrNoRows
type (
// Session stands for raw connections or transaction sessions
Session interface {
@@ -131,6 +128,13 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn {
return conn
}
// NewSqlConnFromSession returns a SqlConn with the given session.
func NewSqlConnFromSession(session Session) SqlConn {
return txConn{
Session: session,
}
}
func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) {
return db.ExecCtx(context.Background(), q, args...)
}

View File

@@ -55,7 +55,7 @@ func TestSqlConn(t *testing.T) {
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB
var err error
db, mock, err = sqlmock.New()

View File

@@ -15,11 +15,27 @@ type (
Rollback() error
}
txConn struct {
Session
}
txSession struct {
*sql.Tx
}
)
func (s txConn) RawDB() (*sql.DB, error) {
return nil, errNoRawDBFromTx
}
func (s txConn) Transact(_ func(Session) error) error {
return errCantNestTx
}
func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session) error) error {
return errCantNestTx
}
// NewSessionFromTx returns a Session with the given sql.Tx.
// Use it with caution, it's provided for other ORM to interact with.
func NewSessionFromTx(tx *sql.Tx) Session {

View File

@@ -6,7 +6,10 @@ import (
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/internal/dbtest"
)
const (
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
return nil
}
func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err)
}
func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
assert.NoError(t, conn.Transact(func(session Session) error {
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return nil, errors.New("foo")
},
beginTx: begin,
onError: func(ctx context.Context, err error) {},
brk: breaker.NewBreaker(),
}
assert.Error(t, conn.Transact(func(session Session) error {
return nil
}))
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
_, err := conn.RawDB()
assert.Equal(t, errNoRawDBFromTx, err)
assert.Equal(t, errCantNestTx, conn.Transact(nil))
assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
return errors.New("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback().WillReturnError(errors.New("foo"))
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic(errors.New("foo"))
}))
})
}
func TestTxSession(t *testing.T) {
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
res, err := conn.Exec("any")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
_, err = conn.Exec("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
assert.NotNil(t, stmt)
mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
_, err = conn.Prepare("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRow(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRow(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRowPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRows(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRows(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRowsPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowsPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
}
func TestTxRollback(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
return c.QueryRow(&val, "foo")
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
if _, err := c.Exec("any"); err != nil {
return err
}
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
return nil
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
assert.Equal(t, "bar", val)
return nil
})
assert.NoError(t, err)
})
}
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
sess := NewSessionFromTx(tx)
conn := NewSqlConnFromSession(sess)
f(conn, mock)
})
}

37
internal/dbtest/sql.go Normal file
View File

@@ -0,0 +1,37 @@
package dbtest
import (
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
)
// RunTest runs a test function with a mock database.
func RunTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer func() {
_ = db.Close()
}()
fn(db, mock)
if err = mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
// RunTxTest runs a test function with a mock database in a transaction.
func RunTxTest(t *testing.T, f func(tx *sql.Tx, mock sqlmock.Sqlmock)) {
RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
tx, err := db.Begin()
if assert.NoError(t, err) {
f(tx, mock)
}
})
}