fix: format error should not trigger circuit breaker in sqlx (#3437)

This commit is contained in:
Kevin Wan
2023-07-23 20:40:03 +08:00
committed by GitHub
parent 05db706c62
commit ff04356704
6 changed files with 77 additions and 12 deletions

View File

@@ -291,12 +291,19 @@ func (db *commonSqlConn) TransactCtx(ctx context.Context, fn func(context.Contex
}
func (db *commonSqlConn) acceptable(err error) bool {
ok := err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled
if db.accept == nil {
return ok
if err == nil || err == sql.ErrNoRows || err == sql.ErrTxDone || err == context.Canceled {
return true
}
return ok || db.accept(err)
if _, ok := err.(acceptableError); ok {
return true
}
if db.accept == nil {
return false
}
return db.accept(err)
}
func (db *commonSqlConn) queryRows(ctx context.Context, scanner func(*sql.Rows) error,

View File

@@ -236,6 +236,33 @@ func TestStatement(t *testing.T) {
})
}
func TestBreakerWithFormatError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
var val string
if !assert.NotEqual(t, breaker.ErrServiceUnavailable,
conn.QueryRow(&val, "any ?, ?", "foo")) {
break
}
}
})
}
func TestBreakerWithScanError(t *testing.T) {
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
conn := NewSqlConnFromDB(db, withMysqlAcceptable())
for i := 0; i < 1000; i++ {
rows := sqlmock.NewRows([]string{"foo"}).AddRow("bar")
mock.ExpectQuery("any").WillReturnRows(rows)
var val int
if !assert.NotEqual(t, breaker.ErrServiceUnavailable, conn.QueryRow(&val, "any")) {
break
}
}
})
}
func buildConn() (mock sqlmock.Sqlmock, err error) {
_, err = connManager.GetResource(mockedDatasource, func() (io.Closer, error) {
var db *sql.DB

View File

@@ -51,7 +51,13 @@ func escape(input string) string {
return b.String()
}
func format(query string, args ...any) (string, error) {
func format(query string, args ...any) (val string, err error) {
defer func() {
if err != nil {
err = newAcceptableError(err)
}
}()
numArgs := len(args)
if numArgs == 0 {
return query, nil
@@ -66,7 +72,8 @@ func format(query string, args ...any) (string, error) {
switch ch {
case '?':
if argIndex >= numArgs {
return "", fmt.Errorf("%d ? in sql, but less arguments provided", argIndex)
return "", fmt.Errorf("%d ? in sql, but only %d arguments provided",
argIndex+1, numArgs)
}
writeValue(&b, args[argIndex])
@@ -165,3 +172,17 @@ func writeValue(buf *strings.Builder, arg any) {
buf.WriteString(mapping.Repr(v))
}
}
type acceptableError struct {
err error
}
func newAcceptableError(err error) error {
return acceptableError{
err: err,
}
}
func (e acceptableError) Error() string {
return e.err.Error()
}