diff --git a/core/stores/redis/redislock.go b/core/stores/redis/redislock.go index 53c3aea6..e53c70e2 100644 --- a/core/stores/redis/redislock.go +++ b/core/stores/redis/redislock.go @@ -1,12 +1,14 @@ package redis import ( + "context" "math/rand" "strconv" "sync/atomic" "time" red "github.com/go-redis/redis/v8" + "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stringx" ) @@ -34,6 +36,7 @@ type RedisLock struct { seconds uint32 key string id string + ctx context.Context } func init() { @@ -51,8 +54,9 @@ func NewRedisLock(store *Redis, key string) *RedisLock { // Acquire acquires the lock. func (rl *RedisLock) Acquire() (bool, error) { + rl.fillCtx() seconds := atomic.LoadUint32(&rl.seconds) - resp, err := rl.store.Eval(lockCommand, []string{rl.key}, []string{ + resp, err := rl.store.EvalCtx(rl.ctx, lockCommand, []string{rl.key}, []string{ rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance), }) if err == red.Nil { @@ -75,7 +79,8 @@ func (rl *RedisLock) Acquire() (bool, error) { // Release releases the lock. func (rl *RedisLock) Release() (bool, error) { - resp, err := rl.store.Eval(delCommand, []string{rl.key}, []string{rl.id}) + rl.fillCtx() + resp, err := rl.store.EvalCtx(rl.ctx, delCommand, []string{rl.key}, []string{rl.id}) if err != nil { return false, err } @@ -92,3 +97,14 @@ func (rl *RedisLock) Release() (bool, error) { func (rl *RedisLock) SetExpire(seconds int) { atomic.StoreUint32(&rl.seconds, uint32(seconds)) } + +// WithContext set context. +func (rl *RedisLock) WithContext(ctx context.Context) { + rl.ctx = ctx +} + +func (rl *RedisLock) fillCtx() { + if rl.ctx == nil { + rl.ctx = context.Background() + } +} diff --git a/core/stores/redis/redislock_test.go b/core/stores/redis/redislock_test.go index 55ec9ef7..410b46a7 100644 --- a/core/stores/redis/redislock_test.go +++ b/core/stores/redis/redislock_test.go @@ -1,33 +1,51 @@ package redis import ( + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/stringx" ) func TestRedisLock(t *testing.T) { - runOnRedis(t, func(client *Redis) { - key := stringx.Rand() - firstLock := NewRedisLock(client, key) - firstLock.SetExpire(5) - firstAcquire, err := firstLock.Acquire() - assert.Nil(t, err) - assert.True(t, firstAcquire) + testFn := func(ctx context.Context) func(client *Redis) { + return func(client *Redis) { + key := stringx.Rand() + firstLock := NewRedisLock(client, key) + if ctx != nil { + firstLock.WithContext(ctx) + } + firstLock.SetExpire(5) + firstAcquire, err := firstLock.Acquire() + assert.Nil(t, err) + assert.True(t, firstAcquire) - secondLock := NewRedisLock(client, key) - secondLock.SetExpire(5) - againAcquire, err := secondLock.Acquire() - assert.Nil(t, err) - assert.False(t, againAcquire) + secondLock := NewRedisLock(client, key) + if ctx != nil { + secondLock.WithContext(ctx) + } + secondLock.SetExpire(5) + againAcquire, err := secondLock.Acquire() + assert.Nil(t, err) + assert.False(t, againAcquire) - release, err := firstLock.Release() - assert.Nil(t, err) - assert.True(t, release) + release, err := firstLock.Release() + assert.Nil(t, err) + assert.True(t, release) - endAcquire, err := secondLock.Acquire() - assert.Nil(t, err) - assert.True(t, endAcquire) + endAcquire, err := secondLock.Acquire() + assert.Nil(t, err) + assert.True(t, endAcquire) + } + } + + t.Run("normal", func(t *testing.T) { + runOnRedis(t, testFn(nil)) + }) + + t.Run("withContext", func(t *testing.T) { + runOnRedis(t, testFn(context.Background())) }) }