fix: UpdateStmt doesn't update the statement correctly in sqlx/bulkinserter.go (#3607)
This commit is contained in:
@@ -30,7 +30,7 @@ const (
|
|||||||
leftSquareBracket = '['
|
leftSquareBracket = '['
|
||||||
rightSquareBracket = ']'
|
rightSquareBracket = ']'
|
||||||
segmentSeparator = ','
|
segmentSeparator = ','
|
||||||
intSize = 32 << (^uint(0) >> 63)
|
intSize = 32 << (^uint(0) >> 63) // 32 or 64
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/executors"
|
"github.com/zeromicro/go-zero/core/executors"
|
||||||
@@ -30,6 +31,7 @@ type (
|
|||||||
executor *executors.PeriodicalExecutor
|
executor *executors.PeriodicalExecutor
|
||||||
inserter *dbInserter
|
inserter *dbInserter
|
||||||
stmt bulkStmt
|
stmt bulkStmt
|
||||||
|
lock sync.RWMutex // guards stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
bulkStmt struct {
|
bulkStmt struct {
|
||||||
@@ -65,6 +67,9 @@ func (bi *BulkInserter) Flush() {
|
|||||||
|
|
||||||
// Insert inserts given args.
|
// Insert inserts given args.
|
||||||
func (bi *BulkInserter) Insert(args ...any) error {
|
func (bi *BulkInserter) Insert(args ...any) error {
|
||||||
|
bi.lock.RLock()
|
||||||
|
defer bi.lock.RUnlock()
|
||||||
|
|
||||||
value, err := format(bi.stmt.valueFormat, args...)
|
value, err := format(bi.stmt.valueFormat, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -95,6 +100,11 @@ func (bi *BulkInserter) UpdateStmt(stmt string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bi.lock.Lock()
|
||||||
|
defer bi.lock.Unlock()
|
||||||
|
|
||||||
|
// with write lock, it doesn't matter what's the order of setting bi.stmt and calling flush.
|
||||||
|
bi.stmt = bkStmt
|
||||||
bi.executor.Flush()
|
bi.executor.Flush()
|
||||||
bi.executor.Sync(func() {
|
bi.executor.Sync(func() {
|
||||||
bi.inserter.stmt = bkStmt
|
bi.inserter.stmt = bkStmt
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
@@ -13,14 +16,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type mockedConn struct {
|
type mockedConn struct {
|
||||||
query string
|
query string
|
||||||
args []any
|
args []any
|
||||||
execErr error
|
execErr error
|
||||||
|
updateCallback func(query string, args []any)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
|
func (c *mockedConn) ExecCtx(_ context.Context, query string, args ...any) (sql.Result, error) {
|
||||||
c.query = query
|
c.query = query
|
||||||
c.args = args
|
c.args = args
|
||||||
|
if c.updateCallback != nil {
|
||||||
|
c.updateCallback(query, args)
|
||||||
|
}
|
||||||
|
|
||||||
return nil, c.execErr
|
return nil, c.execErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,3 +152,50 @@ func TestBulkInserter_Update(t *testing.T) {
|
|||||||
assert.NotNil(t, inserter.UpdateStmt("foo"))
|
assert.NotNil(t, inserter.UpdateStmt("foo"))
|
||||||
assert.NotNil(t, inserter.Insert("foo", "bar"))
|
assert.NotNil(t, inserter.Insert("foo", "bar"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBulkInserter_UpdateStmt(t *testing.T) {
|
||||||
|
var updated int32
|
||||||
|
conn := mockedConn{
|
||||||
|
execErr: errors.New("foo"),
|
||||||
|
updateCallback: func(query string, args []any) {
|
||||||
|
count := atomic.AddInt32(&updated, 1)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, 100, strings.Count(query, "foo"))
|
||||||
|
if count == 1 {
|
||||||
|
assert.Equal(t, 0, strings.Count(query, "bar"))
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 100, strings.Count(query, "bar"))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
inserter, err := NewBulkInserter(&conn, `INSERT INTO classroom_dau(classroom) VALUES(?)`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var wg1 sync.WaitGroup
|
||||||
|
wg1.Add(2)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg1.Done()
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
assert.NoError(t, inserter.Insert("foo"))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg1.Wait()
|
||||||
|
|
||||||
|
assert.NoError(t, inserter.UpdateStmt(`INSERT INTO classroom_dau(classroom, user) VALUES(?, ?)`))
|
||||||
|
|
||||||
|
var wg2 sync.WaitGroup
|
||||||
|
wg2.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg2.Done()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
assert.NoError(t, inserter.Insert("foo", "bar"))
|
||||||
|
}
|
||||||
|
inserter.Flush()
|
||||||
|
}()
|
||||||
|
wg2.Wait()
|
||||||
|
|
||||||
|
assert.Equal(t, int32(2), atomic.LoadInt32(&updated))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user