feat: support the specified timeout of rpc methods (#2742)

Co-authored-by: hanzijian <hanzijian@52tt.com>
Co-authored-by: Kevin Wan <wanjunfeng@gmail.com>
This commit is contained in:
vankillua
2023-10-25 21:01:57 +08:00
committed by GitHub
parent 2a335c7608
commit 842c4d81cc
10 changed files with 378 additions and 29 deletions

View File

@@ -14,11 +14,23 @@ import (
"google.golang.org/grpc/status"
)
type (
// ServerSpecifiedTimeoutConf defines specified timeout for gRPC method.
ServerSpecifiedTimeoutConf struct {
FullMethod string
Timeout time.Duration
}
specifiedTimeoutCache map[string]time.Duration
)
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
func UnaryTimeoutInterceptor(timeout time.Duration, specifiedTimeouts ...ServerSpecifiedTimeoutConf) grpc.UnaryServerInterceptor {
cache := cacheSpecifiedTimeout(specifiedTimeouts)
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (any, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
t := getTimeoutByUnaryServerInfo(info, timeout, cache)
ctx, cancel := context.WithTimeout(ctx, t)
defer cancel()
var resp any
@@ -59,3 +71,28 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor
}
}
}
func cacheSpecifiedTimeout(specifiedTimeouts []ServerSpecifiedTimeoutConf) specifiedTimeoutCache {
cache := make(specifiedTimeoutCache, len(specifiedTimeouts))
for _, st := range specifiedTimeouts {
if st.FullMethod != "" {
cache[st.FullMethod] = st.Timeout
}
}
return cache
}
func getTimeoutByUnaryServerInfo(info *grpc.UnaryServerInfo, defaultTimeout time.Duration, specifiedTimeout specifiedTimeoutCache) time.Duration {
if ts, ok := info.Server.(TimeoutStrategy); ok {
return ts.GetTimeoutByFullMethod(info.FullMethod, defaultTimeout)
} else if v, ok := specifiedTimeout[info.FullMethod]; ok {
return v
}
return defaultTimeout
}
type TimeoutStrategy interface {
GetTimeoutByFullMethod(fullMethod string, defaultTimeout time.Duration) time.Duration
}