feat: support using session to execute statements in transaction (#3252)
This commit is contained in:
@@ -6,7 +6,10 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
|
||||
func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
|
||||
func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
|
||||
func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
|
||||
func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
|
||||
func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
|
||||
func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
|
||||
func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
|
||||
assert.Equal(t, mockRollback, mock.status)
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestTxExceptions(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectCommit()
|
||||
conn := NewSqlConnFromDB(db)
|
||||
assert.NoError(t, conn.Transact(func(session Session) error {
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
conn := &commonSqlConn{
|
||||
connProv: func() (*sql.DB, error) {
|
||||
return nil, errors.New("foo")
|
||||
},
|
||||
beginTx: begin,
|
||||
onError: func(ctx context.Context, err error) {},
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
assert.Error(t, conn.Transact(func(session Session) error {
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
_, err := conn.RawDB()
|
||||
assert.Equal(t, errNoRawDBFromTx, err)
|
||||
assert.Equal(t, errCantNestTx, conn.Transact(nil))
|
||||
assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
conn := NewSqlConnFromDB(db)
|
||||
assert.Error(t, conn.Transact(func(session Session) error {
|
||||
return errors.New("foo")
|
||||
}))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectRollback().WillReturnError(errors.New("foo"))
|
||||
conn := NewSqlConnFromDB(db)
|
||||
assert.Error(t, conn.Transact(func(session Session) error {
|
||||
panic("foo")
|
||||
}))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectRollback()
|
||||
conn := NewSqlConnFromDB(db)
|
||||
assert.Error(t, conn.Transact(func(session Session) error {
|
||||
panic(errors.New("foo"))
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestTxSession(t *testing.T) {
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
res, err := conn.Exec("any")
|
||||
assert.NoError(t, err)
|
||||
last, err := res.LastInsertId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), last)
|
||||
affected, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), affected)
|
||||
|
||||
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
|
||||
_, err = conn.Exec("any")
|
||||
assert.Equal(t, "foo", err.Error())
|
||||
})
|
||||
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectPrepare("any")
|
||||
stmt, err := conn.Prepare("any")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, stmt)
|
||||
|
||||
mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
|
||||
_, err = conn.Prepare("any")
|
||||
assert.Equal(t, "foo", err.Error())
|
||||
})
|
||||
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
var val string
|
||||
err := conn.QueryRow(&val, "any")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "foo", val)
|
||||
|
||||
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||
err = conn.QueryRow(&val, "any")
|
||||
assert.Equal(t, "foo", err.Error())
|
||||
})
|
||||
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
var val string
|
||||
err := conn.QueryRowPartial(&val, "any")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "foo", val)
|
||||
|
||||
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||
err = conn.QueryRowPartial(&val, "any")
|
||||
assert.Equal(t, "foo", err.Error())
|
||||
})
|
||||
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
var val []string
|
||||
err := conn.QueryRows(&val, "any")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"foo", "bar"}, val)
|
||||
|
||||
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||
err = conn.QueryRows(&val, "any")
|
||||
assert.Equal(t, "foo", err.Error())
|
||||
})
|
||||
|
||||
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
|
||||
mock.ExpectQuery("any").WillReturnRows(rows)
|
||||
var val []string
|
||||
err := conn.QueryRowsPartial(&val, "any")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"foo", "bar"}, val)
|
||||
|
||||
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
|
||||
err = conn.QueryRowsPartial(&val, "any")
|
||||
assert.Equal(t, "foo", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTxRollback(t *testing.T) {
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
conn := NewSqlConnFromDB(db)
|
||||
err := conn.Transact(func(session Session) error {
|
||||
c := NewSqlConnFromSession(session)
|
||||
_, err := c.Exec("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var val string
|
||||
return c.QueryRow(&val, "foo")
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
conn := NewSqlConnFromDB(db)
|
||||
err := conn.Transact(func(session Session) error {
|
||||
c := NewSqlConnFromSession(session)
|
||||
if _, err := c.Exec("any"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var val string
|
||||
assert.NoError(t, c.QueryRow(&val, "foo"))
|
||||
return nil
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
|
||||
mock.ExpectCommit()
|
||||
|
||||
conn := NewSqlConnFromDB(db)
|
||||
err := conn.Transact(func(session Session) error {
|
||||
c := NewSqlConnFromSession(session)
|
||||
_, err := c.Exec("any")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var val string
|
||||
assert.NoError(t, c.QueryRow(&val, "foo"))
|
||||
assert.Equal(t, "bar", val)
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
|
||||
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
|
||||
sess := NewSessionFromTx(tx)
|
||||
conn := NewSqlConnFromSession(sess)
|
||||
f(conn, mock)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user