redis增加tls支持 (#595)

* redis连接增加支持tls选项

* 优化redis tls config 写法

* redis增加tls支持

* 增加redis tls测试用例,但redis tls local server不支持,测试用例全部NotNil

Co-authored-by: liuyi <liuyi@fangyb.com>
Co-authored-by: yi.liu <yi.liu@xshoppy.com>
This commit is contained in:
r00mz
2021-04-07 20:44:16 +08:00
committed by GitHub
parent 05e37ee20f
commit 8cb6490724
5 changed files with 91 additions and 14 deletions

View File

@@ -14,9 +14,10 @@ var (
type ( type (
// A RedisConf is a redis config. // A RedisConf is a redis config.
RedisConf struct { RedisConf struct {
Host string Host string
Type string `json:",default=node,options=node|cluster"` Type string `json:",default=node,options=node|cluster"`
Pass string `json:",optional"` Pass string `json:",optional"`
TLSFlag bool `json:",default=false,options=true|false"`
} }
// A RedisKeyConf is a redis config with key. // A RedisKeyConf is a redis config with key.
@@ -28,6 +29,9 @@ type (
// NewRedis returns a Redis. // NewRedis returns a Redis.
func (rc RedisConf) NewRedis() *Redis { func (rc RedisConf) NewRedis() *Redis {
if rc.TLSFlag {
return NewRedisWithTLS(rc.Host, rc.Type, rc.TLSFlag, rc.Pass)
}
return NewRedis(rc.Host, rc.Type, rc.Pass) return NewRedis(rc.Host, rc.Type, rc.Pass)
} }

View File

@@ -37,10 +37,11 @@ type (
// Redis defines a redis node/cluster. It is thread-safe. // Redis defines a redis node/cluster. It is thread-safe.
Redis struct { Redis struct {
Addr string Addr string
Type string Type string
Pass string Pass string
brk breaker.Breaker brk breaker.Breaker
TLSFlag bool
} }
// RedisNode interface represents a redis node. // RedisNode interface represents a redis node.
@@ -71,16 +72,21 @@ type (
// NewRedis returns a Redis. // NewRedis returns a Redis.
func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis { func NewRedis(redisAddr, redisType string, redisPass ...string) *Redis {
return NewRedisWithTLS(redisAddr, redisType, false, redisPass...)
}
func NewRedisWithTLS(redisAddr, redisType string, tlsFlag bool, redisPass ...string) *Redis {
var pass string var pass string
for _, v := range redisPass { for _, v := range redisPass {
pass = v pass = v
} }
return &Redis{ return &Redis{
Addr: redisAddr, Addr: redisAddr,
Type: redisType, Type: redisType,
Pass: pass, Pass: pass,
brk: breaker.NewBreaker(), brk: breaker.NewBreaker(),
TLSFlag: tlsFlag,
} }
} }
@@ -1704,9 +1710,17 @@ func acceptable(err error) bool {
func getRedis(r *Redis) (RedisNode, error) { func getRedis(r *Redis) (RedisNode, error) {
switch r.Type { switch r.Type {
case ClusterType: case ClusterType:
return getCluster(r.Addr, r.Pass) if r.TLSFlag {
return getClusterWithTLS(r.Addr, r.Pass, r.TLSFlag)
} else {
return getCluster(r.Addr, r.Pass)
}
case NodeType: case NodeType:
return getClient(r.Addr, r.Pass) if r.TLSFlag {
return getClientWithTLS(r.Addr, r.Pass, r.TLSFlag)
} else {
return getClient(r.Addr, r.Pass)
}
default: default:
return nil, fmt.Errorf("redis type '%s' is not supported", r.Type) return nil, fmt.Errorf("redis type '%s' is not supported", r.Type)
} }

View File

@@ -1,6 +1,7 @@
package redis package redis
import ( import (
"crypto/tls"
"errors" "errors"
"io" "io"
"strconv" "strconv"
@@ -26,6 +27,20 @@ func TestRedis_Exists(t *testing.T) {
}) })
} }
func TestRedisTLS_Exists(t *testing.T) {
runOnRedisTLS(t, func(client *Redis) {
_, err := NewRedisWithTLS(client.Addr, "", true).Exists("a")
assert.NotNil(t, err)
ok, err := client.Exists("a")
assert.NotNil(t, err)
assert.False(t, ok)
assert.NotNil(t, client.Set("a", "b"))
ok, err = client.Exists("a")
assert.NotNil(t, err)
assert.False(t, ok)
})
}
func TestRedis_Eval(t *testing.T) { func TestRedis_Eval(t *testing.T) {
runOnRedis(t, func(client *Redis) { runOnRedis(t, func(client *Redis) {
_, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"}) _, err := NewRedis(client.Addr, "").Eval(`redis.call("EXISTS", KEYS[1])`, []string{"notexist"})
@@ -1062,8 +1077,28 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) {
client.Close() client.Close()
} }
}() }()
fn(NewRedis(s.Addr(), NodeType)) fn(NewRedis(s.Addr(), NodeType))
}
func runOnRedisTLS(t *testing.T, fn func(client *Redis)) {
s, err := miniredis.RunTLS(&tls.Config{
Certificates: make([]tls.Certificate, 1),
InsecureSkipVerify: true,
})
assert.Nil(t, err)
defer func() {
client, err := clientManager.GetResource(s.Addr(), func() (io.Closer, error) {
return nil, errors.New("should already exist")
})
if err != nil {
t.Error(err)
}
if client != nil {
client.Close()
}
}()
fn(NewRedisWithTLS(s.Addr(), NodeType, true))
} }
type mockedNode struct { type mockedNode struct {

View File

@@ -1,6 +1,7 @@
package redis package redis
import ( import (
"crypto/tls"
"io" "io"
red "github.com/go-redis/redis" red "github.com/go-redis/redis"
@@ -16,13 +17,24 @@ const (
var clientManager = syncx.NewResourceManager() var clientManager = syncx.NewResourceManager()
func getClient(server, pass string) (*red.Client, error) { func getClient(server, pass string) (*red.Client, error) {
return getClientWithTLS(server, pass, false)
}
func getClientWithTLS(server, pass string, tlsFlag bool) (*red.Client, error) {
val, err := clientManager.GetResource(server, func() (io.Closer, error) { val, err := clientManager.GetResource(server, func() (io.Closer, error) {
var tlsConfig *tls.Config = nil
if tlsFlag {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
store := red.NewClient(&red.Options{ store := red.NewClient(&red.Options{
Addr: server, Addr: server,
Password: pass, Password: pass,
DB: defaultDatabase, DB: defaultDatabase,
MaxRetries: maxRetries, MaxRetries: maxRetries,
MinIdleConns: idleConns, MinIdleConns: idleConns,
TLSConfig: tlsConfig,
}) })
store.WrapProcess(process) store.WrapProcess(process)
return store, nil return store, nil

View File

@@ -1,6 +1,7 @@
package redis package redis
import ( import (
"crypto/tls"
"io" "io"
red "github.com/go-redis/redis" red "github.com/go-redis/redis"
@@ -10,12 +11,23 @@ import (
var clusterManager = syncx.NewResourceManager() var clusterManager = syncx.NewResourceManager()
func getCluster(server, pass string) (*red.ClusterClient, error) { func getCluster(server, pass string) (*red.ClusterClient, error) {
return getClusterWithTLS(server, pass, false)
}
func getClusterWithTLS(server, pass string, tlsFlag bool) (*red.ClusterClient, error) {
val, err := clusterManager.GetResource(server, func() (io.Closer, error) { val, err := clusterManager.GetResource(server, func() (io.Closer, error) {
var tlsConfig *tls.Config = nil
if tlsFlag {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
store := red.NewClusterClient(&red.ClusterOptions{ store := red.NewClusterClient(&red.ClusterOptions{
Addrs: []string{server}, Addrs: []string{server},
Password: pass, Password: pass,
MaxRetries: maxRetries, MaxRetries: maxRetries,
MinIdleConns: idleConns, MinIdleConns: idleConns,
TLSConfig: tlsConfig,
}) })
store.WrapProcess(process) store.WrapProcess(process)