feat: support using session to execute statements in transaction (#3252)

This commit is contained in:
Kevin Wan
2023-05-17 22:15:24 +08:00
committed by GitHub
parent f0bdfb928f
commit bff5b81ad9
11 changed files with 526 additions and 126 deletions

View File

@@ -6,7 +6,10 @@ import (
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/breaker"
"github.com/zeromicro/go-zero/internal/dbtest"
)
const (
@@ -23,51 +26,51 @@ func (mt *mockTx) Commit() error {
return nil
}
func (mt *mockTx) Exec(q string, args ...any) (sql.Result, error) {
func (mt *mockTx) Exec(_ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
func (mt *mockTx) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
return nil, nil
}
func (mt *mockTx) Prepare(query string) (StmtSession, error) {
func (mt *mockTx) Prepare(_ string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) PrepareCtx(ctx context.Context, query string) (StmtSession, error) {
func (mt *mockTx) PrepareCtx(_ context.Context, _ string) (StmtSession, error) {
return nil, nil
}
func (mt *mockTx) QueryRow(v any, q string, args ...any) error {
func (mt *mockTx) QueryRow(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowPartial(v any, q string, args ...any) error {
func (mt *mockTx) QueryRowPartial(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRows(v any, q string, args ...any) error {
func (mt *mockTx) QueryRows(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsPartial(v any, q string, args ...any) error {
func (mt *mockTx) QueryRowsPartial(_ any, _ string, _ ...any) error {
return nil
}
func (mt *mockTx) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
func (mt *mockTx) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
return nil
}
@@ -101,3 +104,209 @@ func TestTransactRollback(t *testing.T) {
assert.Equal(t, mockRollback, mock.status)
assert.NotNil(t, err)
}
func TestTxExceptions(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
assert.NoError(t, conn.Transact(func(session Session) error {
return nil
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := &commonSqlConn{
connProv: func() (*sql.DB, error) {
return nil, errors.New("foo")
},
beginTx: begin,
onError: func(ctx context.Context, err error) {},
brk: breaker.NewBreaker(),
}
assert.Error(t, conn.Transact(func(session Session) error {
return nil
}))
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
_, err := conn.RawDB()
assert.Equal(t, errNoRawDBFromTx, err)
assert.Equal(t, errCantNestTx, conn.Transact(nil))
assert.Equal(t, errCantNestTx, conn.TransactCtx(context.Background(), nil))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
return errors.New("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback().WillReturnError(errors.New("foo"))
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic("foo")
}))
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
assert.Error(t, conn.Transact(func(session Session) error {
panic(errors.New("foo"))
}))
})
}
func TestTxSession(t *testing.T) {
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
res, err := conn.Exec("any")
assert.NoError(t, err)
last, err := res.LastInsertId()
assert.NoError(t, err)
assert.Equal(t, int64(2), last)
affected, err := res.RowsAffected()
assert.NoError(t, err)
assert.Equal(t, int64(3), affected)
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
_, err = conn.Exec("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
mock.ExpectPrepare("any")
stmt, err := conn.Prepare("any")
assert.NoError(t, err)
assert.NotNil(t, stmt)
mock.ExpectPrepare("any").WillReturnError(errors.New("foo"))
_, err = conn.Prepare("any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRow(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRow(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo")
mock.ExpectQuery("any").WillReturnRows(rows)
var val string
err := conn.QueryRowPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, "foo", val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRows(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRows(&val, "any")
assert.Equal(t, "foo", err.Error())
})
runTxTest(t, func(conn SqlConn, mock sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"col"}).AddRow("foo").AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val []string
err := conn.QueryRowsPartial(&val, "any")
assert.NoError(t, err)
assert.Equal(t, []string{"foo", "bar"}, val)
mock.ExpectQuery("any").WillReturnError(errors.New("foo"))
err = conn.QueryRowsPartial(&val, "any")
assert.Equal(t, "foo", err.Error())
})
}
func TestTxRollback(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
return c.QueryRow(&val, "foo")
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
mock.ExpectRollback()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
if _, err := c.Exec("any"); err != nil {
return err
}
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
return nil
})
assert.Error(t, err)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
mock.ExpectBegin()
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
mock.ExpectQuery("foo").WillReturnRows(sqlmock.NewRows([]string{"col"}).AddRow("bar"))
mock.ExpectCommit()
conn := NewSqlConnFromDB(db)
err := conn.Transact(func(session Session) error {
c := NewSqlConnFromSession(session)
_, err := c.Exec("any")
assert.NoError(t, err)
var val string
assert.NoError(t, c.QueryRow(&val, "foo"))
assert.Equal(t, "bar", val)
return nil
})
assert.NoError(t, err)
})
}
func runTxTest(t *testing.T, f func(conn SqlConn, mock sqlmock.Sqlmock)) {
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
sess := NewSessionFromTx(tx)
conn := NewSqlConnFromSession(sess)
f(conn, mock)
})
}