fix: mysql WithAcceptable bug (#3986)
This commit is contained in:
@@ -13,7 +13,7 @@ const (
|
|||||||
|
|
||||||
// NewMysql returns a mysql connection.
|
// NewMysql returns a mysql connection.
|
||||||
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
|
func NewMysql(datasource string, opts ...SqlOption) SqlConn {
|
||||||
opts = append(opts, withMysqlAcceptable())
|
opts = append([]SqlOption{withMysqlAcceptable()}, opts...)
|
||||||
return NewSqlConn(mysqlDriverName, datasource, opts...)
|
return NewSqlConn(mysqlDriverName, datasource, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ package sqlx
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-sql-driver/mysql"
|
"github.com/go-sql-driver/mysql"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/stat"
|
"github.com/zeromicro/go-zero/core/stat"
|
||||||
@@ -38,7 +38,6 @@ func TestBreakerOnNotHandlingDuplicateEntry(t *testing.T) {
|
|||||||
func TestMysqlAcceptable(t *testing.T) {
|
func TestMysqlAcceptable(t *testing.T) {
|
||||||
conn := NewMysql("nomysql").(*commonSqlConn)
|
conn := NewMysql("nomysql").(*commonSqlConn)
|
||||||
withMysqlAcceptable()(conn)
|
withMysqlAcceptable()(conn)
|
||||||
assert.EqualValues(t, reflect.ValueOf(mysqlAcceptable).Pointer(), reflect.ValueOf(conn.accept).Pointer())
|
|
||||||
assert.True(t, mysqlAcceptable(nil))
|
assert.True(t, mysqlAcceptable(nil))
|
||||||
assert.False(t, mysqlAcceptable(errors.New("any")))
|
assert.False(t, mysqlAcceptable(errors.New("any")))
|
||||||
assert.False(t, mysqlAcceptable(new(mysql.MySQLError)))
|
assert.False(t, mysqlAcceptable(new(mysql.MySQLError)))
|
||||||
|
|||||||
@@ -315,6 +315,13 @@ func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows)
|
|||||||
// acceptable is the func to check if the error can be accepted.
|
// acceptable is the func to check if the error can be accepted.
|
||||||
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
func WithAcceptable(acceptable func(err error) bool) SqlOption {
|
||||||
return func(conn *commonSqlConn) {
|
return func(conn *commonSqlConn) {
|
||||||
conn.accept = acceptable
|
if conn.accept == nil {
|
||||||
|
conn.accept = acceptable
|
||||||
|
} else {
|
||||||
|
pre := conn.accept
|
||||||
|
conn.accept = func(err error) bool {
|
||||||
|
return pre(err) || acceptable(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/breaker"
|
"github.com/zeromicro/go-zero/core/breaker"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"github.com/zeromicro/go-zero/core/stores/dbtest"
|
"github.com/zeromicro/go-zero/core/stores/dbtest"
|
||||||
@@ -264,6 +265,45 @@ func TestBreakerWithScanError(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithAcceptable(t *testing.T) {
|
||||||
|
var (
|
||||||
|
acceptableErr = errors.New("acceptable")
|
||||||
|
acceptableErr2 = errors.New("acceptable2")
|
||||||
|
acceptableErr3 = errors.New("acceptable3")
|
||||||
|
)
|
||||||
|
opts := []SqlOption{
|
||||||
|
WithAcceptable(func(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return errors.Is(err, acceptableErr)
|
||||||
|
}),
|
||||||
|
WithAcceptable(func(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return errors.Is(err, acceptableErr2)
|
||||||
|
}),
|
||||||
|
WithAcceptable(func(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return errors.Is(err, acceptableErr3)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
var conn = &commonSqlConn{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, conn.accept(nil))
|
||||||
|
assert.False(t, conn.accept(assert.AnError))
|
||||||
|
assert.True(t, conn.accept(acceptableErr))
|
||||||
|
assert.True(t, conn.accept(acceptableErr2))
|
||||||
|
assert.True(t, conn.accept(acceptableErr3))
|
||||||
|
}
|
||||||
|
|
||||||
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
func buildConn() (mock sqlmock.Sqlmock, err error) {
|
||||||
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
|
|||||||
Reference in New Issue
Block a user