From ed1c937998c466987aae9086791836a7727ab388 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sat, 11 Jun 2022 23:07:26 +0800 Subject: [PATCH] feat: convert grpc errors to http status codes (#1997) * feat: convert grpc errors to http status codes * chore: circuit break include unimplemented grpc error * chore: add reference link in comments --- rest/httpx/responses.go | 6 + rest/httpx/responses_test.go | 12 ++ rest/internal/errcode/grpc.go | 55 ++++++++ rest/internal/errcode/grpc_test.go | 123 ++++++++++++++++++ zrpc/internal/codes/accept.go | 2 +- .../serverinterceptors/timeoutinterceptor.go | 1 - 6 files changed, 197 insertions(+), 2 deletions(-) create mode 100644 rest/internal/errcode/grpc.go create mode 100644 rest/internal/errcode/grpc_test.go diff --git a/rest/httpx/responses.go b/rest/httpx/responses.go index 355de21a..e3b2c919 100644 --- a/rest/httpx/responses.go +++ b/rest/httpx/responses.go @@ -6,6 +6,7 @@ import ( "sync" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/rest/internal/errcode" "github.com/zeromicro/go-zero/rest/internal/header" ) @@ -23,9 +24,14 @@ func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, if handler == nil { if len(fns) > 0 { fns[0](w, err) + } else if errcode.IsGrpcError(err) { + // don't unwrap error and get status.Message(), + // it hides the rpc error headers. + http.Error(w, err.Error(), errcode.CodeFromGrpcError(err)) } else { http.Error(w, err.Error(), http.StatusBadRequest) } + return } diff --git a/rest/httpx/responses_test.go b/rest/httpx/responses_test.go index 2be5a7f6..dcd08a8a 100644 --- a/rest/httpx/responses_test.go +++ b/rest/httpx/responses_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type message struct { @@ -95,6 +97,16 @@ func TestError(t *testing.T) { } } +func TestErrorWithGrpcError(t *testing.T) { + w := tracedResponseWriter{ + headers: make(map[string][]string), + } + Error(&w, status.Error(codes.Unavailable, "foo")) + assert.Equal(t, http.StatusServiceUnavailable, w.code) + assert.True(t, w.hasBody) + assert.True(t, strings.Contains(w.builder.String(), "foo")) +} + func TestErrorWithHandler(t *testing.T) { w := tracedResponseWriter{ headers: make(map[string][]string), diff --git a/rest/internal/errcode/grpc.go b/rest/internal/errcode/grpc.go new file mode 100644 index 00000000..625345d2 --- /dev/null +++ b/rest/internal/errcode/grpc.go @@ -0,0 +1,55 @@ +package errcode + +import ( + "net/http" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// CodeFromGrpcError converts the gRPC error to an HTTP status code. +// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto +func CodeFromGrpcError(err error) int { + code := status.Code(err) + switch code { + case codes.OK: + return http.StatusOK + case codes.InvalidArgument, codes.FailedPrecondition, codes.OutOfRange: + return http.StatusBadRequest + case codes.Unauthenticated: + return http.StatusUnauthorized + case codes.PermissionDenied: + return http.StatusForbidden + case codes.NotFound: + return http.StatusNotFound + case codes.Canceled: + return http.StatusRequestTimeout + case codes.AlreadyExists, codes.Aborted: + return http.StatusConflict + case codes.ResourceExhausted: + return http.StatusTooManyRequests + case codes.Internal, codes.DataLoss, codes.Unknown: + return http.StatusInternalServerError + case codes.Unimplemented: + return http.StatusNotImplemented + case codes.Unavailable: + return http.StatusServiceUnavailable + case codes.DeadlineExceeded: + return http.StatusGatewayTimeout + } + + return http.StatusInternalServerError +} + +// IsGrpcError checks if the error is a gRPC error. +func IsGrpcError(err error) bool { + if err == nil { + return false + } + + _, ok := err.(interface { + GRPCStatus() *status.Status + }) + + return ok +} diff --git a/rest/internal/errcode/grpc_test.go b/rest/internal/errcode/grpc_test.go new file mode 100644 index 00000000..7dd9adc7 --- /dev/null +++ b/rest/internal/errcode/grpc_test.go @@ -0,0 +1,123 @@ +package errcode + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestCodeFromGrpcError(t *testing.T) { + tests := []struct { + name string + code codes.Code + want int + }{ + { + name: "OK", + code: codes.OK, + want: http.StatusOK, + }, + { + name: "Invalid argument", + code: codes.InvalidArgument, + want: http.StatusBadRequest, + }, + { + name: "Failed precondition", + code: codes.FailedPrecondition, + want: http.StatusBadRequest, + }, + { + name: "Out of range", + code: codes.OutOfRange, + want: http.StatusBadRequest, + }, + { + name: "Unauthorized", + code: codes.Unauthenticated, + want: http.StatusUnauthorized, + }, + { + name: "Permission denied", + code: codes.PermissionDenied, + want: http.StatusForbidden, + }, + { + name: "Not found", + code: codes.NotFound, + want: http.StatusNotFound, + }, + { + name: "Canceled", + code: codes.Canceled, + want: http.StatusRequestTimeout, + }, + { + name: "Already exists", + code: codes.AlreadyExists, + want: http.StatusConflict, + }, + { + name: "Aborted", + code: codes.Aborted, + want: http.StatusConflict, + }, + { + name: "Resource exhausted", + code: codes.ResourceExhausted, + want: http.StatusTooManyRequests, + }, + { + name: "Internal", + code: codes.Internal, + want: http.StatusInternalServerError, + }, + { + name: "Data loss", + code: codes.DataLoss, + want: http.StatusInternalServerError, + }, + { + name: "Unknown", + code: codes.Unknown, + want: http.StatusInternalServerError, + }, + { + name: "Unimplemented", + code: codes.Unimplemented, + want: http.StatusNotImplemented, + }, + { + name: "Unavailable", + code: codes.Unavailable, + want: http.StatusServiceUnavailable, + }, + { + name: "Deadline exceeded", + code: codes.DeadlineExceeded, + want: http.StatusGatewayTimeout, + }, + { + name: "Beyond defined error", + code: codes.Code(^uint32(0)), + want: http.StatusInternalServerError, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.want, CodeFromGrpcError(status.Error(test.code, "foo"))) + }) + } +} + +func TestIsGrpcError(t *testing.T) { + assert.True(t, IsGrpcError(status.Error(codes.Unknown, "foo"))) + assert.False(t, IsGrpcError(errors.New("foo"))) + assert.False(t, IsGrpcError(nil)) +} diff --git a/zrpc/internal/codes/accept.go b/zrpc/internal/codes/accept.go index 0ecb1275..8ecf292a 100644 --- a/zrpc/internal/codes/accept.go +++ b/zrpc/internal/codes/accept.go @@ -8,7 +8,7 @@ import ( // Acceptable checks if given error is acceptable. func Acceptable(err error) bool { switch status.Code(err) { - case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss: + case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss, codes.Unimplemented: return false default: return true diff --git a/zrpc/internal/serverinterceptors/timeoutinterceptor.go b/zrpc/internal/serverinterceptors/timeoutinterceptor.go index 4d168673..68937201 100644 --- a/zrpc/internal/serverinterceptors/timeoutinterceptor.go +++ b/zrpc/internal/serverinterceptors/timeoutinterceptor.go @@ -49,7 +49,6 @@ func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor return resp, err case <-ctx.Done(): err := ctx.Err() - if err == context.Canceled { err = status.Error(codes.Canceled, err.Error()) } else if err == context.DeadlineExceeded {