update: expand the retry method to support timeout and interval control (#3283)

This commit is contained in:
Xiaoju Jiang
2023-05-28 10:17:50 +08:00
committed by GitHub
parent 32f78668db
commit 8d48e34eed
2 changed files with 162 additions and 18 deletions

View File

@@ -1,31 +1,87 @@
package fx package fx
import "github.com/zeromicro/go-zero/core/errorx" import (
"context"
"errors"
"time"
"github.com/zeromicro/go-zero/core/errorx"
)
const defaultRetryTimes = 3 const defaultRetryTimes = 3
var errTimeout = errors.New("retry timeout")
type ( type (
// RetryOption defines the method to customize DoWithRetry. // RetryOption defines the method to customize DoWithRetry.
RetryOption func(*retryOptions) RetryOption func(*retryOptions)
retryOptions struct { retryOptions struct {
times int times int
interval time.Duration
timeout time.Duration
} }
) )
// DoWithRetry runs fn, and retries if failed. Default to retry 3 times. // DoWithRetry runs fn, and retries if failed. Default to retry 3 times.
// Note that if the fn function accesses global variables outside the function and performs modification operations,
// it is best to lock them, otherwise there may be data race issues
func DoWithRetry(fn func() error, opts ...RetryOption) error { func DoWithRetry(fn func() error, opts ...RetryOption) error {
return retry(fn, opts...)
}
// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times.
// fn retryCount indicates the current number of retries,starting from 0
// Note that if the fn function accesses global variables outside the function and performs modification operations,
// it is best to lock them, otherwise there may be data race issues
func DoWithRetryCtx(fn func(ctx context.Context, retryCount int) error, opts ...RetryOption) error {
return retry(fn, opts...)
}
func retry(fn interface{}, opts ...RetryOption) error {
options := newRetryOptions() options := newRetryOptions()
for _, opt := range opts { for _, opt := range opts {
opt(options) opt(options)
} }
sign := make(chan error, 1)
var berr errorx.BatchError var berr errorx.BatchError
var cancelFunc context.CancelFunc
ctx := context.Background()
if options.timeout > 0 {
ctx, cancelFunc = context.WithTimeout(ctx, options.timeout)
defer cancelFunc()
}
for i := 0; i < options.times; i++ { for i := 0; i < options.times; i++ {
if err := fn(); err != nil { go func(retryCount int) {
berr.Add(err) switch f := fn.(type) {
} else { case func() error:
return nil sign <- f()
case func(ctx context.Context, retryCount int) error:
sign <- f(ctx, retryCount)
}
}(i)
select {
case err := <-sign:
if err != nil {
berr.Add(err)
} else {
return nil
}
case <-ctx.Done():
berr.Add(errTimeout)
return berr.Err()
}
if options.interval > 0 {
select {
case <-ctx.Done():
berr.Add(errTimeout)
return berr.Err()
case <-time.After(options.interval):
}
} }
} }
@@ -39,8 +95,22 @@ func WithRetry(times int) RetryOption {
} }
} }
func newRetryOptions() *retryOptions { func WithInterval(interval time.Duration) RetryOption {
return &retryOptions{ return func(options *retryOptions) {
times: defaultRetryTimes, options.interval = interval
}
}
func WithTimeout(timeout time.Duration) RetryOption {
return func(options *retryOptions) {
options.timeout = timeout
}
}
func newRetryOptions() *retryOptions {
return &retryOptions{
times: defaultRetryTimes,
interval: 0,
timeout: 0,
} }
} }

View File

@@ -1,8 +1,10 @@
package fx package fx
import ( import (
"context"
"errors" "errors"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -12,31 +14,103 @@ func TestRetry(t *testing.T) {
return errors.New("any") return errors.New("any")
})) }))
var times int times1 := 0
assert.Nil(t, DoWithRetry(func() error { assert.Nil(t, DoWithRetry(func() error {
times++ times1++
if times == defaultRetryTimes { if times1 == defaultRetryTimes {
return nil return nil
} }
return errors.New("any") return errors.New("any")
})) }))
times = 0 times2 := 0
assert.NotNil(t, DoWithRetry(func() error { assert.NotNil(t, DoWithRetry(func() error {
times++ times2++
if times == defaultRetryTimes+1 { if times2 == defaultRetryTimes+1 {
return nil return nil
} }
return errors.New("any") return errors.New("any")
})) }))
total := 2 * defaultRetryTimes total := 2 * defaultRetryTimes
times = 0 times3 := 0
assert.Nil(t, DoWithRetry(func() error { assert.Nil(t, DoWithRetry(func() error {
times++ times3++
if times == total { if times3 == total {
return nil return nil
} }
return errors.New("any") return errors.New("any")
}, WithRetry(total))) }, WithRetry(total)))
} }
func TestRetryWithTimeout(t *testing.T) {
assert.Nil(t, DoWithRetry(func() error {
return nil
}, WithTimeout(time.Second*10)))
times1 := 0
assert.Nil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any ")
}
time.Sleep(time.Second * 3)
return nil
}, WithTimeout(time.Second*5)))
total := defaultRetryTimes
times2 := 0
assert.Nil(t, DoWithRetry(func() error {
times2++
if times2 == total {
return nil
}
time.Sleep(time.Second)
return errors.New("any")
}, WithTimeout(time.Second*(time.Duration(total)+2))))
assert.NotNil(t, DoWithRetry(func() error {
return errors.New("any")
}, WithTimeout(time.Second*5)))
}
func TestRetryWithInterval(t *testing.T) {
times1 := 0
assert.NotNil(t, DoWithRetry(func() error {
times1++
if times1 == 1 {
return errors.New("any")
}
time.Sleep(time.Second * 3)
return nil
}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
times2 := 0
assert.NotNil(t, DoWithRetry(func() error {
times2++
if times2 == 2 {
return nil
}
time.Sleep(time.Second * 3)
return errors.New("any ")
}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
}
func TestRetryCtx(t *testing.T) {
assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error {
if retryCount == 0 {
return errors.New("any")
}
time.Sleep(time.Second * 3)
return nil
}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error {
if retryCount == 1 {
return nil
}
time.Sleep(time.Second * 3)
return errors.New("any ")
}, WithTimeout(time.Second*5), WithInterval(time.Second*3)))
}