diff --git a/rpcx/internal/serverinterceptors/timeoutinterceptor_test.go b/rpcx/internal/serverinterceptors/timeoutinterceptor_test.go index 84626185..0a442907 100644 --- a/rpcx/internal/serverinterceptors/timeoutinterceptor_test.go +++ b/rpcx/internal/serverinterceptors/timeoutinterceptor_test.go @@ -2,6 +2,7 @@ package serverinterceptors import ( "context" + "sync" "testing" "time" @@ -13,9 +14,28 @@ func TestUnaryTimeoutInterceptor(t *testing.T) { interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10) _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ FullMethod: "/", - }, func( - ctx context.Context, req interface{}) (interface{}, error) { + }, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) assert.Nil(t, err) } + +func TestUnaryTimeoutInterceptor_timeout(t *testing.T) { + const timeout = time.Millisecond * 10 + interceptor := UnaryTimeoutInterceptor(timeout) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + var wg sync.WaitGroup + wg.Add(1) + _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/", + }, func(ctx context.Context, req interface{}) (interface{}, error) { + defer wg.Done() + tm, ok := ctx.Deadline() + assert.True(t, ok) + assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond))) + return nil, nil + }) + wg.Wait() + assert.Nil(t, err) +} diff --git a/rpcx/internal/serverinterceptors/tracinginterceptor_test.go b/rpcx/internal/serverinterceptors/tracinginterceptor_test.go new file mode 100644 index 00000000..86fbac36 --- /dev/null +++ b/rpcx/internal/serverinterceptors/tracinginterceptor_test.go @@ -0,0 +1,48 @@ +package serverinterceptors + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tal-tech/go-zero/core/trace/tracespec" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +func TestUnaryTracingInterceptor(t *testing.T) { + interceptor := UnaryTracingInterceptor("foo") + var run int32 + var wg sync.WaitGroup + wg.Add(1) + _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{ + FullMethod: "/", + }, func(ctx context.Context, req interface{}) (interface{}, error) { + defer wg.Done() + atomic.AddInt32(&run, 1) + return nil, nil + }) + wg.Wait() + assert.Nil(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) +} + +func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) { + interceptor := UnaryTracingInterceptor("foo") + var wg sync.WaitGroup + wg.Add(1) + var md metadata.MD + ctx := metadata.NewIncomingContext(context.Background(), md) + _, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{ + FullMethod: "/", + }, func(ctx context.Context, req interface{}) (interface{}, error) { + defer wg.Done() + assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).TraceId()) > 0) + assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).SpanId()) > 0) + return nil, nil + }) + wg.Wait() + assert.Nil(t, err) +}