From bff5b81ad952f5a3068fae9daa2d64ec2f8bed15 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Wed, 17 May 2023 22:15:24 +0800 Subject: [PATCH] feat: support using session to execute statements in transaction (#3252) --- .gitignore | 4 +- core/stores/sqlc/cachedsql.go | 12 ++ core/stores/sqlc/cachedsql_test.go | 164 ++++++++++++++++-- core/stores/sqlx/bulkinserter_test.go | 24 +-- core/stores/sqlx/errors.go | 14 ++ core/stores/sqlx/orm_test.go | 136 +++++++-------- core/stores/sqlx/sqlconn.go | 10 +- core/stores/sqlx/sqlconn_test.go | 2 +- core/stores/sqlx/tx.go | 16 ++ core/stores/sqlx/tx_test.go | 233 ++++++++++++++++++++++++-- internal/dbtest/sql.go | 37 ++++ 11 files changed, 526 insertions(+), 126 deletions(-) create mode 100644 core/stores/sqlx/errors.go create mode 100644 internal/dbtest/sql.go diff --git a/.gitignore b/.gitignore index b38171f7..c306dab2 100644 --- a/.gitignore +++ b/.gitignore @@ -14,9 +14,10 @@ **/.idea **/.DS_Store **/logs +**/adhoc +**/coverage.txt # for test purpose -**/adhoc go.work go.work.sum @@ -27,4 +28,3 @@ go.work.sum # vim auto backup file *~ !OWNERS -coverage.txt diff --git a/core/stores/sqlc/cachedsql.go b/core/stores/sqlc/cachedsql.go index 52e5ee04..27692636 100644 --- a/core/stores/sqlc/cachedsql.go +++ b/core/stores/sqlc/cachedsql.go @@ -226,3 +226,15 @@ func (cc CachedConn) Transact(fn func(sqlx.Session) error) error { func (cc CachedConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { return cc.db.TransactCtx(ctx, fn) } + +// WithSession returns a new CachedConn with given session. +// If query from session, the uncommitted data might be returned. +// Don't query for the uncommitted data, you should just use it, +// and don't use the cache for the uncommitted data. +// Not recommend to use cache within transactions due to consistency problem. +func (cc CachedConn) WithSession(session sqlx.Session) CachedConn { + return CachedConn{ + db: sqlx.NewSqlConnFromSession(session), + cache: cc.cache, + } +} diff --git a/core/stores/sqlc/cachedsql_test.go b/core/stores/sqlc/cachedsql_test.go index 5e97af26..064199cb 100644 --- a/core/stores/sqlc/cachedsql_test.go +++ b/core/stores/sqlc/cachedsql_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/DATA-DOG/go-sqlmock" "github.com/alicebob/miniredis/v2" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/fx" @@ -24,6 +25,8 @@ import ( "github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/stores/redis/redistest" "github.com/zeromicro/go-zero/core/stores/sqlx" + "github.com/zeromicro/go-zero/core/syncx" + "github.com/zeromicro/go-zero/internal/dbtest" ) func init() { @@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) { var value string err := c.GetCache("any", &value) assert.Equal(t, ErrNotFound, err) - r.Set("any", `"value"`) + _ = r.Set("any", `"value"`) err = c.GetCache("any", &value) assert.Nil(t, err) assert.Equal(t, "value", value) @@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) { assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit)) } +func TestCachedConn_DelCache(t *testing.T) { + r := redistest.CreateRedis(t) + + const ( + key = "user" + value = "any" + ) + assert.NoError(t, r.Set(key, value)) + + c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30)) + err := c.DelCache(key) + assert.Nil(t, err) + + val, err := r.Get(key) + assert.Nil(t, err) + assert.Empty(t, val) +} + func TestCachedConnQueryRow(t *testing.T) { r := redistest.CreateRedis(t) @@ -543,6 +564,125 @@ func TestNewConnWithCache(t *testing.T) { assert.True(t, conn.execValue) } +func TestCachedConn_WithSession(t *testing.T) { + dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) { + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + + r := redistest.CreateRedis(t) + conn := CachedConn{ + cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows), + } + conn = conn.WithSession(sqlx.NewSessionFromTx(tx)) + res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) { + return conn.Exec("any") + }, "foo") + assert.NoError(t, err) + last, err := res.LastInsertId() + assert.NoError(t, err) + assert.Equal(t, int64(2), last) + affected, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(3), affected) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + mock.ExpectCommit() + + r := redistest.CreateRedis(t) + conn := CachedConn{ + db: sqlx.NewSqlConnFromDB(db), + cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows), + } + assert.NoError(t, conn.Transact(func(session sqlx.Session) error { + conn = conn.WithSession(session) + res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) { + return conn.Exec("any") + }, "foo") + assert.NoError(t, err) + last, err := res.LastInsertId() + assert.NoError(t, err) + assert.Equal(t, int64(2), last) + affected, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(3), affected) + return nil + })) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("any").WillReturnError(errors.New("foo")) + mock.ExpectRollback() + + r := redistest.CreateRedis(t) + conn := CachedConn{ + db: sqlx.NewSqlConnFromDB(db), + cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows), + } + assert.Error(t, conn.Transact(func(session sqlx.Session) error { + conn = conn.WithSession(session) + _, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) { + return conn.Exec("any") + }, "bar") + return err + })) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + mock.ExpectCommit() + + r := redistest.CreateRedis(t) + conn := CachedConn{ + db: sqlx.NewSqlConnFromDB(db), + cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows), + } + assert.NoError(t, conn.Transact(func(session sqlx.Session) error { + var val string + conn = conn.WithSession(session) + err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error { + return conn.QueryRow(v, "any") + }) + assert.Equal(t, "2", val) + return err + })) + val, err := r.Get("foo") + assert.NoError(t, err) + assert.Equal(t, `"2"`, val) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2)) + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + mock.ExpectCommit() + + r := redistest.CreateRedis(t) + conn := CachedConn{ + db: sqlx.NewSqlConnFromDB(db), + cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows), + } + assert.NoError(t, conn.Transact(func(session sqlx.Session) error { + var val string + conn = conn.WithSession(session) + assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error { + return conn.QueryRow(v, "any") + })) + assert.Equal(t, "2", val) + _, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) { + return conn.Exec("any") + }, "foo") + return err + })) + val, err := r.Get("foo") + assert.NoError(t, err) + assert.Empty(t, val) + }) +} + func resetStats() { atomic.StoreUint64(&stats.Total, 0) atomic.StoreUint64(&stats.Hit, 0) @@ -554,35 +694,35 @@ type dummySqlConn struct { queryRow func(any, string, ...any) error } -func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) { +func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) { return nil, nil } -func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) { +func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) { return nil, nil } -func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error { +func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } -func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error { +func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } -func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error { +func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } -func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error { +func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error { return nil } -func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) { +func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) { return nil, nil } -func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) { +func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) { return nil, nil } @@ -597,15 +737,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args . return nil } -func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error { +func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error { return nil } -func (d dummySqlConn) QueryRows(v any, query string, args ...any) error { +func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error { return nil } -func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error { +func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error { return nil } diff --git a/core/stores/sqlx/bulkinserter_test.go b/core/stores/sqlx/bulkinserter_test.go index 52348252..ae4bca1b 100644 --- a/core/stores/sqlx/bulkinserter_test.go +++ b/core/stores/sqlx/bulkinserter_test.go @@ -9,7 +9,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/internal/dbtest" ) type mockedConn struct { @@ -81,7 +81,7 @@ func (c *mockedConn) Transact(func(session Session) error) error { } func TestBulkInserter(t *testing.T) { - runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { var conn mockedConn inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES(?, ?, ?)`) assert.Nil(t, err) @@ -98,7 +98,7 @@ func TestBulkInserter(t *testing.T) { } func TestBulkInserterSuffix(t *testing.T) { - runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { var conn mockedConn inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom, user, count) VALUES`+ `(?, ?, ?) ON DUPLICATE KEY UPDATE is_overtime=VALUES(is_overtime)`) @@ -119,7 +119,7 @@ func TestBulkInserterSuffix(t *testing.T) { } func TestBulkInserterBadStatement(t *testing.T) { - runSqlTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { var conn mockedConn _, err := NewBulkInserter(&conn, "foo") assert.NotNil(t, err) @@ -144,19 +144,3 @@ func TestBulkInserter_Update(t *testing.T) { assert.NotNil(t, inserter.UpdateStmt("foo")) assert.NotNil(t, inserter.Insert("foo", "bar")) } - -func runSqlTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { - logx.Disable() - - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer db.Close() - - fn(db, mock) - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} diff --git a/core/stores/sqlx/errors.go b/core/stores/sqlx/errors.go new file mode 100644 index 00000000..efe0b159 --- /dev/null +++ b/core/stores/sqlx/errors.go @@ -0,0 +1,14 @@ +package sqlx + +import ( + "database/sql" + "errors" +) + +var ( + // ErrNotFound is an alias of sql.ErrNoRows + ErrNotFound = sql.ErrNoRows + + errCantNestTx = errors.New("cannot nest transactions") + errNoRawDBFromTx = errors.New("cannot get raw db from transaction") +) diff --git a/core/stores/sqlx/orm_test.go b/core/stores/sqlx/orm_test.go index 2efd04d6..4aab1bba 100644 --- a/core/stores/sqlx/orm_test.go +++ b/core/stores/sqlx/orm_test.go @@ -8,11 +8,11 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/internal/dbtest" ) func TestUnmarshalRowBool(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -25,7 +25,7 @@ func TestUnmarshalRowBool(t *testing.T) { } func TestUnmarshalRowBoolNotSettable(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -37,7 +37,7 @@ func TestUnmarshalRowBoolNotSettable(t *testing.T) { } func TestUnmarshalRowInt(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -50,7 +50,7 @@ func TestUnmarshalRowInt(t *testing.T) { } func TestUnmarshalRowInt8(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -63,7 +63,7 @@ func TestUnmarshalRowInt8(t *testing.T) { } func TestUnmarshalRowInt16(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -76,7 +76,7 @@ func TestUnmarshalRowInt16(t *testing.T) { } func TestUnmarshalRowInt32(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -89,7 +89,7 @@ func TestUnmarshalRowInt32(t *testing.T) { } func TestUnmarshalRowInt64(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -102,7 +102,7 @@ func TestUnmarshalRowInt64(t *testing.T) { } func TestUnmarshalRowUint(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -115,7 +115,7 @@ func TestUnmarshalRowUint(t *testing.T) { } func TestUnmarshalRowUint8(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -128,7 +128,7 @@ func TestUnmarshalRowUint8(t *testing.T) { } func TestUnmarshalRowUint16(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("4") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -141,7 +141,7 @@ func TestUnmarshalRowUint16(t *testing.T) { } func TestUnmarshalRowUint32(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -154,7 +154,7 @@ func TestUnmarshalRowUint32(t *testing.T) { } func TestUnmarshalRowUint64(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("6") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -167,7 +167,7 @@ func TestUnmarshalRowUint64(t *testing.T) { } func TestUnmarshalRowFloat32(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("7") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -180,7 +180,7 @@ func TestUnmarshalRowFloat32(t *testing.T) { } func TestUnmarshalRowFloat64(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -193,7 +193,7 @@ func TestUnmarshalRowFloat64(t *testing.T) { } func TestUnmarshalRowString(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { const expect = "hello" rs := sqlmock.NewRows([]string{"value"}).FromCSVString(expect) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -212,7 +212,7 @@ func TestUnmarshalRowStruct(t *testing.T) { Age int }) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -230,7 +230,7 @@ func TestUnmarshalRowStructWithTags(t *testing.T) { Name string `db:"name"` }) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -248,7 +248,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { Name string `db:"name"` }) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name"}).FromCSVString("liao") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -259,7 +259,7 @@ func TestUnmarshalRowStructWithTagsWrongColumns(t *testing.T) { } func TestUnmarshalRowsBool(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []bool{true, false} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -273,7 +273,7 @@ func TestUnmarshalRowsBool(t *testing.T) { } func TestUnmarshalRowsInt(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []int{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -287,7 +287,7 @@ func TestUnmarshalRowsInt(t *testing.T) { } func TestUnmarshalRowsInt8(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []int8{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -301,7 +301,7 @@ func TestUnmarshalRowsInt8(t *testing.T) { } func TestUnmarshalRowsInt16(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []int16{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -315,7 +315,7 @@ func TestUnmarshalRowsInt16(t *testing.T) { } func TestUnmarshalRowsInt32(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []int32{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -329,7 +329,7 @@ func TestUnmarshalRowsInt32(t *testing.T) { } func TestUnmarshalRowsInt64(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []int64{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -343,7 +343,7 @@ func TestUnmarshalRowsInt64(t *testing.T) { } func TestUnmarshalRowsUint(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []uint{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -357,7 +357,7 @@ func TestUnmarshalRowsUint(t *testing.T) { } func TestUnmarshalRowsUint8(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []uint8{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -371,7 +371,7 @@ func TestUnmarshalRowsUint8(t *testing.T) { } func TestUnmarshalRowsUint16(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []uint16{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -385,7 +385,7 @@ func TestUnmarshalRowsUint16(t *testing.T) { } func TestUnmarshalRowsUint32(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []uint32{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -399,7 +399,7 @@ func TestUnmarshalRowsUint32(t *testing.T) { } func TestUnmarshalRowsUint64(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []uint64{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -413,7 +413,7 @@ func TestUnmarshalRowsUint64(t *testing.T) { } func TestUnmarshalRowsFloat32(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []float32{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -427,7 +427,7 @@ func TestUnmarshalRowsFloat32(t *testing.T) { } func TestUnmarshalRowsFloat64(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []float64{2, 3} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -441,7 +441,7 @@ func TestUnmarshalRowsFloat64(t *testing.T) { } func TestUnmarshalRowsString(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []string{"hello", "world"} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -457,7 +457,7 @@ func TestUnmarshalRowsString(t *testing.T) { func TestUnmarshalRowsBoolPtr(t *testing.T) { yes := true no := false - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*bool{&yes, &no} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("1\n0") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -473,7 +473,7 @@ func TestUnmarshalRowsBoolPtr(t *testing.T) { func TestUnmarshalRowsIntPtr(t *testing.T) { two := 2 three := 3 - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*int{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -489,7 +489,7 @@ func TestUnmarshalRowsIntPtr(t *testing.T) { func TestUnmarshalRowsInt8Ptr(t *testing.T) { two := int8(2) three := int8(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*int8{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -505,7 +505,7 @@ func TestUnmarshalRowsInt8Ptr(t *testing.T) { func TestUnmarshalRowsInt16Ptr(t *testing.T) { two := int16(2) three := int16(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*int16{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -521,7 +521,7 @@ func TestUnmarshalRowsInt16Ptr(t *testing.T) { func TestUnmarshalRowsInt32Ptr(t *testing.T) { two := int32(2) three := int32(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*int32{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -537,7 +537,7 @@ func TestUnmarshalRowsInt32Ptr(t *testing.T) { func TestUnmarshalRowsInt64Ptr(t *testing.T) { two := int64(2) three := int64(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*int64{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -553,7 +553,7 @@ func TestUnmarshalRowsInt64Ptr(t *testing.T) { func TestUnmarshalRowsUintPtr(t *testing.T) { two := uint(2) three := uint(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*uint{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -569,7 +569,7 @@ func TestUnmarshalRowsUintPtr(t *testing.T) { func TestUnmarshalRowsUint8Ptr(t *testing.T) { two := uint8(2) three := uint8(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*uint8{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -585,7 +585,7 @@ func TestUnmarshalRowsUint8Ptr(t *testing.T) { func TestUnmarshalRowsUint16Ptr(t *testing.T) { two := uint16(2) three := uint16(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*uint16{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -601,7 +601,7 @@ func TestUnmarshalRowsUint16Ptr(t *testing.T) { func TestUnmarshalRowsUint32Ptr(t *testing.T) { two := uint32(2) three := uint32(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*uint32{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -617,7 +617,7 @@ func TestUnmarshalRowsUint32Ptr(t *testing.T) { func TestUnmarshalRowsUint64Ptr(t *testing.T) { two := uint64(2) three := uint64(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*uint64{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -633,7 +633,7 @@ func TestUnmarshalRowsUint64Ptr(t *testing.T) { func TestUnmarshalRowsFloat32Ptr(t *testing.T) { two := float32(2) three := float32(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*float32{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -649,7 +649,7 @@ func TestUnmarshalRowsFloat32Ptr(t *testing.T) { func TestUnmarshalRowsFloat64Ptr(t *testing.T) { two := float64(2) three := float64(3) - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*float64{&two, &three} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("2\n3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -665,7 +665,7 @@ func TestUnmarshalRowsFloat64Ptr(t *testing.T) { func TestUnmarshalRowsStringPtr(t *testing.T) { hello := "hello" world := "world" - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { expect := []*string{&hello, &world} rs := sqlmock.NewRows([]string{"value"}).FromCSVString("hello\nworld") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -697,7 +697,7 @@ func TestUnmarshalRowsStruct(t *testing.T) { Age int64 } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -736,7 +736,7 @@ func TestUnmarshalRowsStructWithNullStringType(t *testing.T) { NullString sql.NullString `db:"value"` } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "value"}).AddRow( "first", "firstnullstring").AddRow("second", nil) mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -771,7 +771,7 @@ func TestUnmarshalRowsStructWithTags(t *testing.T) { Name string `db:"name"` } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -812,7 +812,7 @@ func TestUnmarshalRowsStructAndEmbeddedAnonymousStructWithTags(t *testing.T) { Embed } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -854,7 +854,7 @@ func TestUnmarshalRowsStructAndEmbeddedStructPtrAnonymousWithTags(t *testing.T) *Embed } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age", "value"}).FromCSVString("first,2,3\nsecond,3,4") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -888,7 +888,7 @@ func TestUnmarshalRowsStructPtr(t *testing.T) { Age int64 } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -921,7 +921,7 @@ func TestUnmarshalRowsStructWithTagsPtr(t *testing.T) { Name string `db:"name"` } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -954,7 +954,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) { Name string `db:"name"` } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("first,2\nsecond,3") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) assert.Nil(t, query(context.Background(), db, func(rows *sql.Rows) error { @@ -969,7 +969,7 @@ func TestUnmarshalRowsStructWithTagsPtrWithInnerPtr(t *testing.T) { } func TestCommonSqlConn_QueryRowOptional(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs) @@ -1019,7 +1019,7 @@ func TestUnmarshalRowError(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{"age"}).FromCSVString("5") mock.ExpectQuery("select (.+) from users where user=?").WithArgs( "anyone").WillReturnRows(rs) @@ -1091,7 +1091,7 @@ func TestAnonymousStructPr(t *testing.T) { Name string `db:"name"` } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{ "name", "age", @@ -1139,7 +1139,7 @@ func TestAnonymousStructPrError(t *testing.T) { Name string `db:"name"` } - runOrmTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { rs := sqlmock.NewRows([]string{ "name", "age", @@ -1154,7 +1154,7 @@ func TestAnonymousStructPrError(t *testing.T) { WithArgs("anyone").WillReturnRows(rs) assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error { return unmarshalRows(&value, rows, true) - }, "select name, age,grade,discipline,class_name,score from users where user=?", + }, "select name, age, grade, discipline, class_name, score from users where user=?", "anyone")) if len(value) > 0 { assert.Equal(t, value[0].score, 0) @@ -1162,22 +1162,6 @@ func TestAnonymousStructPrError(t *testing.T) { }) } -func runOrmTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { - logx.Disable() - - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - defer db.Close() - - fn(db, mock) - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - type mockedScanner struct { colErr error scanErr error diff --git a/core/stores/sqlx/sqlconn.go b/core/stores/sqlx/sqlconn.go index 7945ce1f..e1297251 100644 --- a/core/stores/sqlx/sqlconn.go +++ b/core/stores/sqlx/sqlconn.go @@ -11,9 +11,6 @@ import ( // spanName is used to identify the span name for the SQL execution. const spanName = "sql" -// ErrNotFound is an alias of sql.ErrNoRows -var ErrNotFound = sql.ErrNoRows - type ( // Session stands for raw connections or transaction sessions Session interface { @@ -131,6 +128,13 @@ func NewSqlConnFromDB(db *sql.DB, opts ...SqlOption) SqlConn { return conn } +// NewSqlConnFromSession returns a SqlConn with the given session. +func NewSqlConnFromSession(session Session) SqlConn { + return txConn{ + Session: session, + } +} + func (db *commonSqlConn) Exec(q string, args ...any) (result sql.Result, err error) { return db.ExecCtx(context.Background(), q, args...) } diff --git a/core/stores/sqlx/sqlconn_test.go b/core/stores/sqlx/sqlconn_test.go index 9ceb36a9..cf1a148e 100644 --- a/core/stores/sqlx/sqlconn_test.go +++ b/core/stores/sqlx/sqlconn_test.go @@ -55,7 +55,7 @@ func TestSqlConn(t *testing.T) { } func buildConn() (mock sqlmock.Sqlmock, err error) { - connManager.GetResource(mockedDatasource, func() (io.Closer, error) { + _, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) { var db *sql.DB var err error db, mock, err = sqlmock.New() diff --git a/core/stores/sqlx/tx.go b/core/stores/sqlx/tx.go index e0d25a17..d983077c 100644 --- a/core/stores/sqlx/tx.go +++ b/core/stores/sqlx/tx.go @@ -15,11 +15,27 @@ type ( Rollback() error } + txConn struct { + Session + } + txSession struct { *sql.Tx } ) +func (s txConn) RawDB() (*sql.DB, error) { + return nil, errNoRawDBFromTx +} + +func (s txConn) Transact(_ func(Session) error) error { + return errCantNestTx +} + +func (s txConn) TransactCtx(_ context.Context, _ func(context.Context, Session) error) error { + return errCantNestTx +} + // NewSessionFromTx returns a Session with the given sql.Tx. // Use it with caution, it's provided for other ORM to interact with. func NewSessionFromTx(tx *sql.Tx) Session { diff --git a/core/stores/sqlx/tx_test.go b/core/stores/sqlx/tx_test.go index 9dd11a10..685f4494 100644 --- a/core/stores/sqlx/tx_test.go +++ b/core/stores/sqlx/tx_test.go @@ -6,7 +6,10 @@ import ( "errors" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/breaker" + "github.com/zeromicro/go-zero/internal/dbtest" ) const ( @@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error { return nil } -func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) { +func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) { return nil, nil } -func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) { +func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) { return nil, nil } -func (mt *mockTx) Prepare(query string) (StmtSession, error) { +func (mt *mockTx) Prepare(_ string) (StmtSession, error) { return nil, nil } -func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) { +func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) { return nil, nil } -func (mt *mockTx) QueryRow(v any, q string, args ...any) error { +func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error { +func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error { +func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error { +func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRows(v any, q string, args ...any) error { +func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error { +func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error { +func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error { return nil } -func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error { +func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error { return nil } @@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) { assert.Equal(t, mockRollback, mock.status) assert.NotNil(t, err) } + +func TestTxExceptions(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectCommit() + conn := NewSqlConnFromDB(db) + assert.NoError(t, conn.Transact(func(session Session) error { + return nil + })) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + conn := &commonSqlConn{ + connProv: func() (*sql.DB, error) { + return nil, errors.New("foo") + }, + beginTx: begin, + onError: func(ctx context.Context, err error) {}, + brk: breaker.NewBreaker(), + } + assert.Error(t, conn.Transact(func(session Session) error { + return nil + })) + }) + + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + _, err := conn.RawDB() + assert.Equal(t, errNoRawDBFromTx, err) + assert.Equal(t, errCantNestTx, conn.Transact(nil)) + assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil)) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + conn := NewSqlConnFromDB(db) + assert.Error(t, conn.Transact(func(session Session) error { + return errors.New("foo") + })) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectRollback().WillReturnError(errors.New("foo")) + conn := NewSqlConnFromDB(db) + assert.Error(t, conn.Transact(func(session Session) error { + panic("foo") + })) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectRollback() + conn := NewSqlConnFromDB(db) + assert.Error(t, conn.Transact(func(session Session) error { + panic(errors.New("foo")) + })) + }) +} + +func TestTxSession(t *testing.T) { + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + res, err := conn.Exec("any") + assert.NoError(t, err) + last, err := res.LastInsertId() + assert.NoError(t, err) + assert.Equal(t, int64(2), last) + affected, err := res.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(3), affected) + + mock.ExpectExec("any").WillReturnError(errors.New("foo")) + _, err = conn.Exec("any") + assert.Equal(t, "foo", err.Error()) + }) + + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + mock.ExpectPrepare("any") + stmt, err := conn.Prepare("any") + assert.NoError(t, err) + assert.NotNil(t, stmt) + + mock.ExpectPrepare("any").WillReturnError(errors.New("foo")) + _, err = conn.Prepare("any") + assert.Equal(t, "foo", err.Error()) + }) + + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"col"}).AddRow("foo") + mock.ExpectQuery("any").WillReturnRows(rows) + var val string + err := conn.QueryRow(&val, "any") + assert.NoError(t, err) + assert.Equal(t, "foo", val) + + mock.ExpectQuery("any").WillReturnError(errors.New("foo")) + err = conn.QueryRow(&val, "any") + assert.Equal(t, "foo", err.Error()) + }) + + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"col"}).AddRow("foo") + mock.ExpectQuery("any").WillReturnRows(rows) + var val string + err := conn.QueryRowPartial(&val, "any") + assert.NoError(t, err) + assert.Equal(t, "foo", val) + + mock.ExpectQuery("any").WillReturnError(errors.New("foo")) + err = conn.QueryRowPartial(&val, "any") + assert.Equal(t, "foo", err.Error()) + }) + + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + var val []string + err := conn.QueryRows(&val, "any") + assert.NoError(t, err) + assert.Equal(t, []string{"foo", "bar"}, val) + + mock.ExpectQuery("any").WillReturnError(errors.New("foo")) + err = conn.QueryRows(&val, "any") + assert.Equal(t, "foo", err.Error()) + }) + + runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) { + rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar") + mock.ExpectQuery("any").WillReturnRows(rows) + var val []string + err := conn.QueryRowsPartial(&val, "any") + assert.NoError(t, err) + assert.Equal(t, []string{"foo", "bar"}, val) + + mock.ExpectQuery("any").WillReturnError(errors.New("foo")) + err = conn.QueryRowsPartial(&val, "any") + assert.Equal(t, "foo", err.Error()) + }) +} + +func TestTxRollback(t *testing.T) { + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + mock.ExpectQuery("foo").WillReturnError(errors.New("foo")) + mock.ExpectRollback() + + conn := NewSqlConnFromDB(db) + err := conn.Transact(func(session Session) error { + c := NewSqlConnFromSession(session) + _, err := c.Exec("any") + assert.NoError(t, err) + + var val string + return c.QueryRow(&val, "foo") + }) + assert.Error(t, err) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("any").WillReturnError(errors.New("foo")) + mock.ExpectRollback() + + conn := NewSqlConnFromDB(db) + err := conn.Transact(func(session Session) error { + c := NewSqlConnFromSession(session) + if _, err := c.Exec("any"); err != nil { + return err + } + + var val string + assert.NoError(t, c.QueryRow(&val, "foo")) + return nil + }) + assert.Error(t, err) + }) + + dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3)) + mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar")) + mock.ExpectCommit() + + conn := NewSqlConnFromDB(db) + err := conn.Transact(func(session Session) error { + c := NewSqlConnFromSession(session) + _, err := c.Exec("any") + assert.NoError(t, err) + + var val string + assert.NoError(t, c.QueryRow(&val, "foo")) + assert.Equal(t, "bar", val) + return nil + }) + assert.NoError(t, err) + }) +} + +func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) { + dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) { + sess := NewSessionFromTx(tx) + conn := NewSqlConnFromSession(sess) + f(conn, mock) + }) +} diff --git a/internal/dbtest/sql.go b/internal/dbtest/sql.go new file mode 100644 index 00000000..2e1be298 --- /dev/null +++ b/internal/dbtest/sql.go @@ -0,0 +1,37 @@ +package dbtest + +import ( + "database/sql" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +// RunTest runs a test function with a mock database. +func RunTest(t *testing.T, fn func(db *sql.DB, mock sqlmock.Sqlmock)) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer func() { + _ = db.Close() + }() + + fn(db, mock) + + if err = mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + +// RunTxTest runs a test function with a mock database in a transaction. +func RunTxTest(t *testing.T, f func(tx *sql.Tx, mock sqlmock.Sqlmock)) { + RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) { + mock.ExpectBegin() + tx, err := db.Begin() + if assert.NoError(t, err) { + f(tx, mock) + } + }) +}