diff --git a/core/stores/cache/cache.go b/core/stores/cache/cache.go index 05346e83..c49cc020 100644 --- a/core/stores/cache/cache.go +++ b/core/stores/cache/cache.go @@ -9,6 +9,7 @@ import ( "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/hash" + "github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/syncx" ) @@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error, } if len(c) == 1 { - return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...) + return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...) } dispatcher := hash.NewConsistentHash() for _, node := range c { - cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...) + cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...) dispatcher.AddWithWeight(cn, node.Weight) } diff --git a/core/stores/cache/cache_test.go b/core/stores/cache/cache_test.go index 9c5b843a..a7b50d19 100644 --- a/core/stores/cache/cache_test.go +++ b/core/stores/cache/cache_test.go @@ -163,12 +163,10 @@ func TestCache_SetDel(t *testing.T) { r1, err := miniredis.Run() assert.NoError(t, err) defer r1.Close() - r1.SetError("mock error") r2, err := miniredis.Run() assert.NoError(t, err) defer r2.Close() - r2.SetError("mock error") conf := ClusterConf{ { @@ -187,6 +185,8 @@ func TestCache_SetDel(t *testing.T) { }, } c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder) + r1.SetError("mock error") + r2.SetError("mock error") assert.NoError(t, c.Del("a", "b", "c")) }) } diff --git a/core/stores/kv/store.go b/core/stores/kv/store.go index fcaf07c1..7831ebe1 100644 --- a/core/stores/kv/store.go +++ b/core/stores/kv/store.go @@ -164,7 +164,7 @@ func NewStore(c KvConf) Store { // because Store and redis.Redis has different methods. dispatcher := hash.NewConsistentHash() for _, node := range c { - cn := node.NewRedis() + cn := redis.MustNewRedis(node.RedisConf) dispatcher.AddWithWeight(cn, node.Weight) } diff --git a/core/stores/redis/conf.go b/core/stores/redis/conf.go index f8ee90c6..711ce817 100644 --- a/core/stores/redis/conf.go +++ b/core/stores/redis/conf.go @@ -9,6 +9,8 @@ var ( ErrEmptyType = errors.New("empty redis type") // ErrEmptyKey is an error that indicates no redis key is set. ErrEmptyKey = errors.New("empty redis key") + // ErrPing is an error that indicates ping failed. + ErrPing = errors.New("ping redis failed") ) type ( @@ -27,7 +29,7 @@ type ( } ) -// NewRedis returns a Redis. +// Deprecated: use MustNewRedis or NewRedis instead. func (rc RedisConf) NewRedis() *Redis { var opts []Option if rc.Type == ClusterType { diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index f6884efe..d36b9463 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log" "strconv" "time" @@ -85,8 +86,46 @@ type ( StringCmd = red.StringCmd ) -// New returns a Redis with given options. +// Deprecated: use MustNewRedis or NewRedis instead. func New(addr string, opts ...Option) *Redis { + return newRedis(addr, opts...) +} + +// MustNewRedis returns a Redis with given options. +func MustNewRedis(conf RedisConf, opts ...Option) *Redis { + rds, err := NewRedis(conf, opts...) + if err != nil { + log.Fatal(err) + } + + return rds +} + +// NewRedis returns a Redis with given options. +func NewRedis(conf RedisConf, opts ...Option) (*Redis, error) { + if err := conf.Validate(); err != nil { + return nil, err + } + + if conf.Type == ClusterType { + opts = append([]Option{Cluster()}, opts...) + } + if len(conf.Pass) > 0 { + opts = append([]Option{WithPass(conf.Pass)}, opts...) + } + if conf.Tls { + opts = append([]Option{WithTLS()}, opts...) + } + + rds := newRedis(conf.Host, opts...) + if !rds.Ping() { + return nil, ErrPing + } + + return rds, nil +} + +func newRedis(addr string, opts ...Option) *Redis { r := &Redis{ Addr: addr, Type: NodeType, diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index 2c7e52e2..d89ec37b 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -16,6 +16,116 @@ import ( "github.com/zeromicro/go-zero/core/stringx" ) +func TestNewRedis(t *testing.T) { + r1, err := miniredis.Run() + assert.NoError(t, err) + defer r1.Close() + + r2, err := miniredis.Run() + assert.NoError(t, err) + defer r2.Close() + r2.SetError("mock") + + tests := []struct { + name string + RedisConf + ok bool + redisErr bool + }{ + { + name: "missing host", + RedisConf: RedisConf{ + Host: "", + Type: NodeType, + Pass: "", + }, + ok: false, + }, + { + name: "missing type", + RedisConf: RedisConf{ + Host: "localhost:6379", + Type: "", + Pass: "", + }, + ok: false, + }, + { + name: "ok", + RedisConf: RedisConf{ + Host: r1.Addr(), + Type: NodeType, + Pass: "", + }, + ok: true, + }, + { + name: "ok", + RedisConf: RedisConf{ + Host: r1.Addr(), + Type: ClusterType, + Pass: "", + }, + ok: true, + }, + { + name: "password", + RedisConf: RedisConf{ + Host: r1.Addr(), + Type: NodeType, + Pass: "pw", + }, + ok: true, + }, + { + name: "tls", + RedisConf: RedisConf{ + Host: r1.Addr(), + Type: NodeType, + Tls: true, + }, + ok: true, + }, + { + name: "node error", + RedisConf: RedisConf{ + Host: r2.Addr(), + Type: NodeType, + Pass: "", + }, + ok: true, + redisErr: true, + }, + { + name: "cluster error", + RedisConf: RedisConf{ + Host: r2.Addr(), + Type: ClusterType, + Pass: "", + }, + ok: true, + redisErr: true, + }, + } + + for _, test := range tests { + t.Run(stringx.RandId(), func(t *testing.T) { + rds, err := NewRedis(test.RedisConf) + if test.ok { + if test.redisErr { + assert.Error(t, err) + assert.Nil(t, rds) + } else { + assert.NoError(t, err) + assert.NotNil(t, rds) + } + } else { + assert.Error(t, err) + } + }) + } +} + func TestRedis_Decr(t *testing.T) { runOnRedis(t, func(client *Redis) { _, err := New(client.Addr, badType()).Decr("a") diff --git a/zrpc/server.go b/zrpc/server.go index 99a637ba..5327b83e 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -7,6 +7,7 @@ import ( "github.com/zeromicro/go-zero/core/load" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stat" + "github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/zrpc/internal" "github.com/zeromicro/go-zero/zrpc/internal/auth" "github.com/zeromicro/go-zero/zrpc/internal/serverinterceptors" @@ -120,7 +121,12 @@ func setupInterceptors(server internal.Server, c RpcServerConf, metrics *stat.Me } if c.Auth { - authenticator, err := auth.NewAuthenticator(c.Redis.NewRedis(), c.Redis.Key, c.StrictControl) + rds, err := redis.NewRedis(c.Redis.RedisConf) + if err != nil { + return err + } + + authenticator, err := auth.NewAuthenticator(rds, c.Redis.Key, c.StrictControl) if err != nil { return err } diff --git a/zrpc/server_test.go b/zrpc/server_test.go index 4cdcb8d8..d35f6c3a 100644 --- a/zrpc/server_test.go +++ b/zrpc/server_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/discov" "github.com/zeromicro/go-zero/core/logx" @@ -16,12 +17,16 @@ import ( ) func TestServer_setupInterceptors(t *testing.T) { + rds, err := miniredis.Run() + assert.NoError(t, err) + defer rds.Close() + server := new(mockedServer) - err := setupInterceptors(server, RpcServerConf{ + conf := RpcServerConf{ Auth: true, Redis: redis.RedisKeyConf{ RedisConf: redis.RedisConf{ - Host: "any", + Host: rds.Addr(), Type: redis.NodeType, }, Key: "foo", @@ -35,10 +40,15 @@ func TestServer_setupInterceptors(t *testing.T) { Prometheus: true, Breaker: true, }, - }, new(stat.Metrics)) + } + err = setupInterceptors(server, conf, new(stat.Metrics)) assert.Nil(t, err) assert.Equal(t, 3, len(server.unaryInterceptors)) assert.Equal(t, 1, len(server.streamInterceptors)) + + rds.SetError("mock error") + err = setupInterceptors(server, conf, new(stat.Metrics)) + assert.Error(t, err) } func TestServer(t *testing.T) {