From 3c894a3fb7d75340aa703ed0d8a2c45276a79b64 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Tue, 2 Nov 2021 20:42:22 +0800 Subject: [PATCH] feat: simplify the grpc tls authentication (#1199) --- zrpc/client.go | 6 ++-- zrpc/internal/client.go | 54 ++++++------------------------------ zrpc/internal/client_test.go | 7 +++++ 3 files changed, 17 insertions(+), 50 deletions(-) diff --git a/zrpc/client.go b/zrpc/client.go index 77b444cf..262c5523 100644 --- a/zrpc/client.go +++ b/zrpc/client.go @@ -20,12 +20,10 @@ var ( WithTimeout = internal.WithTimeout // WithRetry is an alias of internal.WithRetry. WithRetry = internal.WithRetry + // WithTransportCredentials return a func to make the gRPC calls secured with given credentials. + WithTransportCredentials = internal.WithTransportCredentials // WithUnaryClientInterceptor is an alias of internal.WithUnaryClientInterceptor. WithUnaryClientInterceptor = internal.WithUnaryClientInterceptor - // WithTlsClientFromUnilateral is an alias of internal.WithTlsClientFromUnilateral - WithTlsClientFromUnilateral = internal.WithTlsClientFromUnilateral - // WithTlsClientFromMutual is an alias of internal.WithTlsClientFromMutual - WithTlsClientFromMutual = internal.WithTlsClientFromMutual ) type ( diff --git a/zrpc/internal/client.go b/zrpc/internal/client.go index 2366fb73..68b0a837 100644 --- a/zrpc/internal/client.go +++ b/zrpc/internal/client.go @@ -2,12 +2,8 @@ package internal import ( "context" - "crypto/tls" - "crypto/x509" "errors" "fmt" - "io/ioutil" - "log" "strings" "time" @@ -147,51 +143,17 @@ func WithRetry() ClientOption { } } +// WithTransportCredentials return a func to make the gRPC calls secured with given credentials. +func WithTransportCredentials(creds credentials.TransportCredentials) ClientOption { + return func(options *ClientOptions) { + options.Secure = true + options.DialOptions = append(options.DialOptions, grpc.WithTransportCredentials(creds)) + } +} + // WithUnaryClientInterceptor returns a func to customize a ClientOptions with given interceptor. func WithUnaryClientInterceptor(interceptor grpc.UnaryClientInterceptor) ClientOption { return func(options *ClientOptions) { options.DialOptions = append(options.DialOptions, WithUnaryClientInterceptors(interceptor)) } } - -// WithTlsClientFromUnilateral return a func to customize a ClientOptions Verify with Unilateralism authentication. -func WithTlsClientFromUnilateral(crt, domainName string) ClientOption { - return func(options *ClientOptions) { - c, err := credentials.NewClientTLSFromFile(crt, domainName) - if err != nil { - log.Fatalf("credentials.NewClientTLSFromFile err: %v", err) - } - - options.Secure = true - options.DialOptions = append(options.DialOptions, grpc.WithTransportCredentials(c)) - } -} - -// WithTlsClientFromMutual return a func to customize a ClientOptions Verify with mutual authentication. -func WithTlsClientFromMutual(crtFile, keyFile, caFile string) ClientOption { - return func(options *ClientOptions) { - cert, err := tls.LoadX509KeyPair(crtFile, keyFile) - if err != nil { - log.Fatalf("tls.LoadX509KeyPair err: %v", err) - } - - certPool := x509.NewCertPool() - ca, err := ioutil.ReadFile(caFile) - if err != nil { - log.Fatalf("credentials: failed to ReadFile CA certificates err: %v", err) - } - - if !certPool.AppendCertsFromPEM(ca) { - log.Fatalf("credentials: failed to append certificates err: %v", err) - } - - config := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: certPool, - } - - options.Secure = true - options.DialOptions = append(options.DialOptions, - grpc.WithTransportCredentials(credentials.NewTLS(config))) - } -} diff --git a/zrpc/internal/client_test.go b/zrpc/internal/client_test.go index 643f3076..cab758f9 100644 --- a/zrpc/internal/client_test.go +++ b/zrpc/internal/client_test.go @@ -38,6 +38,13 @@ func TestWithNonBlock(t *testing.T) { assert.True(t, options.NonBlock) } +func TestWithTransportCredentials(t *testing.T) { + var options ClientOptions + opt := WithTransportCredentials(nil) + opt(&options) + assert.Equal(t, 1, len(options.DialOptions)) +} + func TestWithUnaryClientInterceptor(t *testing.T) { var options ClientOptions opt := WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply interface{},