feat: support using session to execute statements in transaction (#3252)
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/fx"
|
||||
@@ -24,6 +25,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
||||
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/internal/dbtest"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -39,7 +42,7 @@ func TestCachedConn_GetCache(t *testing.T) {
|
||||
var value string
|
||||
err := c.GetCache("any", &value)
|
||||
assert.Equal(t, ErrNotFound, err)
|
||||
r.Set("any", `"value"`)
|
||||
_ = r.Set("any", `"value"`)
|
||||
err = c.GetCache("any", &value)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "value", value)
|
||||
@@ -368,6 +371,24 @@ func TestStatFromMemory(t *testing.T) {
|
||||
assert.Equal(t, uint64(9), atomic.LoadUint64(&stats.Hit))
|
||||
}
|
||||
|
||||
func TestCachedConn_DelCache(t *testing.T) {
|
||||
r := redistest.CreateRedis(t)
|
||||
|
||||
const (
|
||||
key = "user"
|
||||
value = "any"
|
||||
)
|
||||
assert.NoError(t, r.Set(key, value))
|
||||
|
||||
c := NewNodeConn(&trackedConn{}, r, cache.WithExpiry(time.Second*30))
|
||||
err := c.DelCache(key)
|
||||
assert.Nil(t, err)
|
||||
|
||||
val, err := r.Get(key)
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, val)
|
||||
}
|
||||
|
||||
func TestCachedConnQueryRow(t *testing.T) {
|
||||
r := redistest.CreateRedis(t)
|
||||
|
||||
@@ -543,6 +564,125 @@ func TestNewConnWithCache(t *testing.T) {
|
||||
assert.True(t, conn.execValue)
|
||||
}
|
||||
|
||||
func TestCachedConn_WithSession(t *testing.T) {
|
||||
dbtest.RunTxTest(t, func(tx *sql.Tx, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
|
||||
r := redistest.CreateRedis(t)
|
||||
conn := CachedConn{
|
||||
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||
}
|
||||
conn = conn.WithSession(sqlx.NewSessionFromTx(tx))
|
||||
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||
return conn.Exec("any")
|
||||
}, "foo")
|
||||
assert.NoError(t, err)
|
||||
last, err := res.LastInsertId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), last)
|
||||
affected, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), affected)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectCommit()
|
||||
|
||||
r := redistest.CreateRedis(t)
|
||||
conn := CachedConn{
|
||||
db: sqlx.NewSqlConnFromDB(db),
|
||||
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||
}
|
||||
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
|
||||
conn = conn.WithSession(session)
|
||||
res, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||
return conn.Exec("any")
|
||||
}, "foo")
|
||||
assert.NoError(t, err)
|
||||
last, err := res.LastInsertId()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(2), last)
|
||||
affected, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), affected)
|
||||
return nil
|
||||
}))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("any").WillReturnError(errors.New("foo"))
|
||||
mock.ExpectRollback()
|
||||
|
||||
r := redistest.CreateRedis(t)
|
||||
conn := CachedConn{
|
||||
db: sqlx.NewSqlConnFromDB(db),
|
||||
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||
}
|
||||
assert.Error(t, conn.Transact(func(session sqlx.Session) error {
|
||||
conn = conn.WithSession(session)
|
||||
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||
return conn.Exec("any")
|
||||
}, "bar")
|
||||
return err
|
||||
}))
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
|
||||
mock.ExpectCommit()
|
||||
|
||||
r := redistest.CreateRedis(t)
|
||||
conn := CachedConn{
|
||||
db: sqlx.NewSqlConnFromDB(db),
|
||||
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||
}
|
||||
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
|
||||
var val string
|
||||
conn = conn.WithSession(session)
|
||||
err := conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
return conn.QueryRow(v, "any")
|
||||
})
|
||||
assert.Equal(t, "2", val)
|
||||
return err
|
||||
}))
|
||||
val, err := r.Get("foo")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `"2"`, val)
|
||||
})
|
||||
|
||||
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery("any").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(2))
|
||||
mock.ExpectExec("any").WillReturnResult(sqlmock.NewResult(2, 3))
|
||||
mock.ExpectCommit()
|
||||
|
||||
r := redistest.CreateRedis(t)
|
||||
conn := CachedConn{
|
||||
db: sqlx.NewSqlConnFromDB(db),
|
||||
cache: cache.NewNode(r, syncx.NewSingleFlight(), stats, sql.ErrNoRows),
|
||||
}
|
||||
assert.NoError(t, conn.Transact(func(session sqlx.Session) error {
|
||||
var val string
|
||||
conn = conn.WithSession(session)
|
||||
assert.NoError(t, conn.QueryRow(&val, "foo", func(conn sqlx.SqlConn, v interface{}) error {
|
||||
return conn.QueryRow(v, "any")
|
||||
}))
|
||||
assert.Equal(t, "2", val)
|
||||
_, err := conn.Exec(func(conn sqlx.SqlConn) (sql.Result, error) {
|
||||
return conn.Exec("any")
|
||||
}, "foo")
|
||||
return err
|
||||
}))
|
||||
val, err := r.Get("foo")
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, val)
|
||||
})
|
||||
}
|
||||
|
||||
func resetStats() {
|
||||
atomic.StoreUint64(&stats.Total, 0)
|
||||
atomic.StoreUint64(&stats.Hit, 0)
|
||||
@@ -554,35 +694,35 @@ type dummySqlConn struct {
|
||||
queryRow func(any, string, ...any) error
|
||||
}
|
||||
|
||||
func (d dummySqlConn) ExecCtx(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
func (d dummySqlConn) ExecCtx(_ context.Context, _ string, _ ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) PrepareCtx(ctx context.Context, query string) (sqlx.StmtSession, error) {
|
||||
func (d dummySqlConn) PrepareCtx(_ context.Context, _ string) (sqlx.StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (d dummySqlConn) QueryRowPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowsCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (d dummySqlConn) QueryRowsCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowsPartialCtx(ctx context.Context, v any, query string, args ...any) error {
|
||||
func (d dummySqlConn) QueryRowsPartialCtx(_ context.Context, _ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) TransactCtx(ctx context.Context, fn func(context.Context, sqlx.Session) error) error {
|
||||
func (d dummySqlConn) TransactCtx(_ context.Context, _ func(context.Context, sqlx.Session) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Exec(query string, args ...any) (sql.Result, error) {
|
||||
func (d dummySqlConn) Exec(_ string, _ ...any) (sql.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) Prepare(query string) (sqlx.StmtSession, error) {
|
||||
func (d dummySqlConn) Prepare(_ string) (sqlx.StmtSession, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -597,15 +737,15 @@ func (d dummySqlConn) QueryRowCtx(_ context.Context, v any, query string, args .
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowPartial(v any, query string, args ...any) error {
|
||||
func (d dummySqlConn) QueryRowPartial(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRows(v any, query string, args ...any) error {
|
||||
func (d dummySqlConn) QueryRows(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dummySqlConn) QueryRowsPartial(v any, query string, args ...any) error {
|
||||
func (d dummySqlConn) QueryRowsPartial(_ any, _ string, _ ...any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user