chore: refactor zrpc timeout (#3671)

This commit is contained in:
Kevin Wan
2023-10-26 08:55:26 +08:00
committed by GitHub
parent 842c4d81cc
commit 922efbfc2d
10 changed files with 63 additions and 87 deletions

View File

@@ -7,11 +7,17 @@ import (
"google.golang.org/grpc"
)
// TimeoutCallOption is a call option that controls timeout.
type TimeoutCallOption struct {
grpc.EmptyCallOption
timeout time.Duration
}
// TimeoutInterceptor is an interceptor that controls timeout.
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
t := getTimeoutByCallOptions(opts, timeout)
t := getTimeoutFromCallOptions(opts, timeout)
if t <= 0 {
return invoker(ctx, method, req, reply, cc, opts...)
}
@@ -23,24 +29,19 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
}
}
func getTimeoutByCallOptions(callOptions []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
for _, callOption := range callOptions {
if o, ok := callOption.(TimeoutCallOption); ok {
// WithCallTimeout returns a call option that controls method call timeout.
func WithCallTimeout(timeout time.Duration) grpc.CallOption {
return TimeoutCallOption{
timeout: timeout,
}
}
func getTimeoutFromCallOptions(opts []grpc.CallOption, defaultTimeout time.Duration) time.Duration {
for _, opt := range opts {
if o, ok := opt.(TimeoutCallOption); ok {
return o.timeout
}
}
return defaultTimeout
}
type TimeoutCallOption struct {
grpc.EmptyCallOption
timeout time.Duration
}
func WithTimeoutCallOption(timeout time.Duration) grpc.CallOption {
return TimeoutCallOption{
timeout: timeout,
}
}

View File

@@ -114,7 +114,7 @@ func TestTimeoutInterceptor_TimeoutCallOption(t *testing.T) {
cc := new(grpc.ClientConn)
var co []grpc.CallOption
if tt.args.callOptionTimeout > 0 {
co = append(co, WithTimeoutCallOption(tt.args.callOptionTimeout))
co = append(co, WithCallTimeout(tt.args.callOptionTimeout))
}
err := interceptor(context.Background(), "/foo", nil, nil, cc,

View File

@@ -25,5 +25,6 @@ type (
Breaker bool `json:",default=true"`
}
ServerSpecifiedTimeoutConf = serverinterceptors.ServerSpecifiedTimeoutConf
// MethodTimeoutConf defines specified timeout for gRPC methods.
MethodTimeoutConf = serverinterceptors.MethodTimeoutConf
)

View File

@@ -15,21 +15,22 @@ import (
)
type (
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf struct {
// MethodTimeoutConf defines specified timeout for gRPC method.
MethodTimeoutConf struct {
FullMethod string
Timeout time.Duration
}
specifiedTimeoutCache map[string]time.Duration
methodTimeouts map[string]time.Duration
)
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
cache := cacheSpecifiedTimeout(specifiedTimeouts)
func UnaryTimeoutInterceptor(timeout time.Duration,
methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
timeouts := buildMethodTimeouts(methodTimeouts)
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (any, error) {
t := getTimeoutByUnaryServerInfo(info, timeout, cache)
t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
ctx, cancel := context.WithTimeout(ctx, t)
defer cancel()
@@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS
}
}
func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
for _, st := range specifiedTimeouts {
func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
mt := make(methodTimeouts, len(timeouts))
for _, st := range timeouts {
if st.FullMethod != "" {
cache[st.FullMethod] = st.Timeout
mt[st.FullMethod] = st.Timeout
}
}
return cache
return mt
}
func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
if ts, ok := info.Server.(TimeoutStrategy); ok {
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
defaultTimeout time.Duration) time.Duration {
if v, ok := timeouts[method]; ok {
return v
}
return defaultTimeout
}
type TimeoutStrategy interface {
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
}

View File

@@ -103,13 +103,6 @@ type tempServer struct {
func (s *tempServer) run(duration time.Duration) {
time.Sleep(duration)
}
func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
if fullMethod == "/" {
return defaultTimeout
}
return s.timeout
}
func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
type args struct {
@@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
},
wantErr: nil,
},
{
name: "do not timeout with timeout strategy",
args: args{
interceptorTimeout: time.Second,
contextTimeout: time.Second * 5,
serverTimeout: time.Second * 3,
runTime: time.Second * 2,
fullMethod: "/2s",
},
wantErr: nil,
},
{
name: "timeout with interceptor timeout",
args: args{
@@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var specifiedTimeouts []ServerSpecifiedTimeoutConf
var specifiedTimeouts []MethodTimeoutConf
if tt.args.methodTimeout > 0 {
specifiedTimeouts = []ServerSpecifiedTimeoutConf{
specifiedTimeouts = []MethodTimeoutConf{
{
FullMethod: tt.args.method,
Timeout: tt.args.methodTimeout,