chore: refactor zrpc timeout (#3671)
This commit is contained in:
@@ -15,21 +15,22 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
|
||||
ServerSpecifiedTimeoutConf struct {
|
||||
// MethodTimeoutConf defines specified timeout for gRPC method.
|
||||
MethodTimeoutConf struct {
|
||||
FullMethod string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
specifiedTimeoutCache map[string]time.Duration
|
||||
methodTimeouts map[string]time.Duration
|
||||
)
|
||||
|
||||
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
|
||||
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
|
||||
cache := cacheSpecifiedTimeout(specifiedTimeouts)
|
||||
func UnaryTimeoutInterceptor(timeout time.Duration,
|
||||
methodTimeouts ...MethodTimeoutConf) grpc.UnaryServerInterceptor {
|
||||
timeouts := buildMethodTimeouts(methodTimeouts)
|
||||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (any, error) {
|
||||
t := getTimeoutByUnaryServerInfo(info, timeout, cache)
|
||||
t := getTimeoutByUnaryServerInfo(info.FullMethod, timeouts, timeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, t)
|
||||
defer cancel()
|
||||
|
||||
@@ -72,27 +73,22 @@ func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerS
|
||||
}
|
||||
}
|
||||
|
||||
func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
|
||||
cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
|
||||
for _, st := range specifiedTimeouts {
|
||||
func buildMethodTimeouts(timeouts []MethodTimeoutConf) methodTimeouts {
|
||||
mt := make(methodTimeouts, len(timeouts))
|
||||
for _, st := range timeouts {
|
||||
if st.FullMethod != "" {
|
||||
cache[st.FullMethod] = st.Timeout
|
||||
mt[st.FullMethod] = st.Timeout
|
||||
}
|
||||
}
|
||||
|
||||
return cache
|
||||
return mt
|
||||
}
|
||||
|
||||
func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
|
||||
if ts, ok := info.Server.(TimeoutStrategy); ok {
|
||||
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
|
||||
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
|
||||
func getTimeoutByUnaryServerInfo(method string, timeouts methodTimeouts,
|
||||
defaultTimeout time.Duration) time.Duration {
|
||||
if v, ok := timeouts[method]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
type TimeoutStrategy interface {
|
||||
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
|
||||
}
|
||||
|
||||
@@ -103,13 +103,6 @@ type tempServer struct {
|
||||
func (s *tempServer) run(duration time.Duration) {
|
||||
time.Sleep(duration)
|
||||
}
|
||||
func (s *tempServer) GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration {
|
||||
if fullMethod == "/" {
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
return s.timeout
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
|
||||
type args struct {
|
||||
@@ -136,17 +129,6 @@ func TestUnaryTimeoutInterceptor_TimeoutStrategy(t *testing.T) {
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "do not timeout with timeout strategy",
|
||||
args: args{
|
||||
interceptorTimeout: time.Second,
|
||||
contextTimeout: time.Second * 5,
|
||||
serverTimeout: time.Second * 3,
|
||||
runTime: time.Second * 2,
|
||||
fullMethod: "/2s",
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "timeout with interceptor timeout",
|
||||
args: args{
|
||||
@@ -235,9 +217,9 @@ func TestUnaryTimeoutInterceptor_SpecifiedTimeout(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var specifiedTimeouts []ServerSpecifiedTimeoutConf
|
||||
var specifiedTimeouts []MethodTimeoutConf
|
||||
if tt.args.methodTimeout > 0 {
|
||||
specifiedTimeouts = []ServerSpecifiedTimeoutConf{
|
||||
specifiedTimeouts = []MethodTimeoutConf{
|
||||
{
|
||||
FullMethod: tt.args.method,
|
||||
Timeout: tt.args.methodTimeout,
|
||||
|
||||
Reference in New Issue
Block a user