diff --git a/core/fx/retry.go b/core/fx/retry.go index 2e1bda8c..035be0e2 100644 --- a/core/fx/retry.go +++ b/core/fx/retry.go @@ -1,31 +1,87 @@ package fx -import "github.com/zeromicro/go-zero/core/errorx" +import ( + "context" + "errors" + "time" + + "github.com/zeromicro/go-zero/core/errorx" +) const defaultRetryTimes = 3 +var errTimeout = errors.New("retry timeout") + type ( // RetryOption defines the method to customize DoWithRetry. RetryOption func(*retryOptions) retryOptions struct { - times int + times int + interval time.Duration + timeout time.Duration } ) // DoWithRetry runs fn, and retries if failed. Default to retry 3 times. +// Note that if the fn function accesses global variables outside the function and performs modification operations, +// it is best to lock them, otherwise there may be data race issues func DoWithRetry(fn func() error, opts ...RetryOption) error { + return retry(fn, opts...) +} + +// DoWithRetryCtx runs fn, and retries if failed. Default to retry 3 times. +// fn retryCount indicates the current number of retries,starting from 0 +// Note that if the fn function accesses global variables outside the function and performs modification operations, +// it is best to lock them, otherwise there may be data race issues +func DoWithRetryCtx(fn func(ctx context.Context, retryCount int) error, opts ...RetryOption) error { + return retry(fn, opts...) +} + +func retry(fn interface{}, opts ...RetryOption) error { options := newRetryOptions() for _, opt := range opts { opt(options) } + sign := make(chan error, 1) var berr errorx.BatchError + var cancelFunc context.CancelFunc + ctx := context.Background() + if options.timeout > 0 { + ctx, cancelFunc = context.WithTimeout(ctx, options.timeout) + defer cancelFunc() + } + for i := 0; i < options.times; i++ { - if err := fn(); err != nil { - berr.Add(err) - } else { - return nil + go func(retryCount int) { + switch f := fn.(type) { + case func() error: + sign <- f() + case func(ctx context.Context, retryCount int) error: + sign <- f(ctx, retryCount) + } + }(i) + + select { + case err := <-sign: + if err != nil { + berr.Add(err) + } else { + return nil + } + case <-ctx.Done(): + berr.Add(errTimeout) + return berr.Err() + } + + if options.interval > 0 { + select { + case <-ctx.Done(): + berr.Add(errTimeout) + return berr.Err() + case <-time.After(options.interval): + } } } @@ -39,8 +95,22 @@ func WithRetry(times int) RetryOption { } } -func newRetryOptions() *retryOptions { - return &retryOptions{ - times: defaultRetryTimes, +func WithInterval(interval time.Duration) RetryOption { + return func(options *retryOptions) { + options.interval = interval + } +} + +func WithTimeout(timeout time.Duration) RetryOption { + return func(options *retryOptions) { + options.timeout = timeout + } +} + +func newRetryOptions() *retryOptions { + return &retryOptions{ + times: defaultRetryTimes, + interval: 0, + timeout: 0, } } diff --git a/core/fx/retry_test.go b/core/fx/retry_test.go index 520ec914..2686ca17 100644 --- a/core/fx/retry_test.go +++ b/core/fx/retry_test.go @@ -1,8 +1,10 @@ package fx import ( + "context" "errors" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -12,31 +14,103 @@ func TestRetry(t *testing.T) { return errors.New("any") })) - var times int + times1 := 0 assert.Nil(t, DoWithRetry(func() error { - times++ - if times == defaultRetryTimes { + times1++ + if times1 == defaultRetryTimes { return nil } return errors.New("any") })) - times = 0 + times2 := 0 assert.NotNil(t, DoWithRetry(func() error { - times++ - if times == defaultRetryTimes+1 { + times2++ + if times2 == defaultRetryTimes+1 { return nil } return errors.New("any") })) total := 2 * defaultRetryTimes - times = 0 + times3 := 0 assert.Nil(t, DoWithRetry(func() error { - times++ - if times == total { + times3++ + if times3 == total { return nil } return errors.New("any") }, WithRetry(total))) } + +func TestRetryWithTimeout(t *testing.T) { + assert.Nil(t, DoWithRetry(func() error { + return nil + }, WithTimeout(time.Second*10))) + + times1 := 0 + assert.Nil(t, DoWithRetry(func() error { + times1++ + if times1 == 1 { + return errors.New("any ") + } + time.Sleep(time.Second * 3) + return nil + }, WithTimeout(time.Second*5))) + + total := defaultRetryTimes + times2 := 0 + assert.Nil(t, DoWithRetry(func() error { + times2++ + if times2 == total { + return nil + } + time.Sleep(time.Second) + return errors.New("any") + }, WithTimeout(time.Second*(time.Duration(total)+2)))) + + assert.NotNil(t, DoWithRetry(func() error { + return errors.New("any") + }, WithTimeout(time.Second*5))) +} + +func TestRetryWithInterval(t *testing.T) { + times1 := 0 + assert.NotNil(t, DoWithRetry(func() error { + times1++ + if times1 == 1 { + return errors.New("any") + } + time.Sleep(time.Second * 3) + return nil + }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + + times2 := 0 + assert.NotNil(t, DoWithRetry(func() error { + times2++ + if times2 == 2 { + return nil + } + time.Sleep(time.Second * 3) + return errors.New("any ") + }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + +} + +func TestRetryCtx(t *testing.T) { + assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error { + if retryCount == 0 { + return errors.New("any") + } + time.Sleep(time.Second * 3) + return nil + }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) + + assert.NotNil(t, DoWithRetryCtx(func(ctx context.Context, retryCount int) error { + if retryCount == 1 { + return nil + } + time.Sleep(time.Second * 3) + return errors.New("any ") + }, WithTimeout(time.Second*5), WithInterval(time.Second*3))) +}