diff --git a/gateway/internal/descriptorsource.go b/gateway/internal/descriptorsource.go index 8e1895bc..9b3abbd3 100644 --- a/gateway/internal/descriptorsource.go +++ b/gateway/internal/descriptorsource.go @@ -37,52 +37,51 @@ func GetMethods(source grpcurl.DescriptorSource) ([]Method, error) { for _, method := range svcMethods { rpcPath := fmt.Sprintf("%s/%s", svc, method.GetName()) ext := proto.GetExtension(method.GetMethodOptions(), annotations.E_Http) - if ext == nil { - methods = append(methods, Method{ - RpcPath: rpcPath, - }) - continue - } + switch rule := ext.(type) { + case *annotations.HttpRule: + if rule == nil { + methods = append(methods, Method{ + RpcPath: rpcPath, + }) + continue + } - httpExt, ok := ext.(*annotations.HttpRule) - if !ok { - methods = append(methods, Method{ - RpcPath: rpcPath, - }) - continue - } - - switch rule := httpExt.GetPattern().(type) { - case *annotations.HttpRule_Get: - methods = append(methods, Method{ - HttpMethod: http.MethodGet, - HttpPath: adjustHttpPath(rule.Get), - RpcPath: rpcPath, - }) - case *annotations.HttpRule_Post: - methods = append(methods, Method{ - HttpMethod: http.MethodPost, - HttpPath: adjustHttpPath(rule.Post), - RpcPath: rpcPath, - }) - case *annotations.HttpRule_Put: - methods = append(methods, Method{ - HttpMethod: http.MethodPut, - HttpPath: adjustHttpPath(rule.Put), - RpcPath: rpcPath, - }) - case *annotations.HttpRule_Delete: - methods = append(methods, Method{ - HttpMethod: http.MethodDelete, - HttpPath: adjustHttpPath(rule.Delete), - RpcPath: rpcPath, - }) - case *annotations.HttpRule_Patch: - methods = append(methods, Method{ - HttpMethod: http.MethodPatch, - HttpPath: adjustHttpPath(rule.Patch), - RpcPath: rpcPath, - }) + switch httpRule := rule.GetPattern().(type) { + case *annotations.HttpRule_Get: + methods = append(methods, Method{ + HttpMethod: http.MethodGet, + HttpPath: adjustHttpPath(httpRule.Get), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Post: + methods = append(methods, Method{ + HttpMethod: http.MethodPost, + HttpPath: adjustHttpPath(httpRule.Post), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Put: + methods = append(methods, Method{ + HttpMethod: http.MethodPut, + HttpPath: adjustHttpPath(httpRule.Put), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Delete: + methods = append(methods, Method{ + HttpMethod: http.MethodDelete, + HttpPath: adjustHttpPath(httpRule.Delete), + RpcPath: rpcPath, + }) + case *annotations.HttpRule_Patch: + methods = append(methods, Method{ + HttpMethod: http.MethodPatch, + HttpPath: adjustHttpPath(httpRule.Patch), + RpcPath: rpcPath, + }) + default: + methods = append(methods, Method{ + RpcPath: rpcPath, + }) + } default: methods = append(methods, Method{ RpcPath: rpcPath, diff --git a/gateway/internal/descriptorsource_test.go b/gateway/internal/descriptorsource_test.go index d0957132..2b84452e 100644 --- a/gateway/internal/descriptorsource_test.go +++ b/gateway/internal/descriptorsource_test.go @@ -2,11 +2,13 @@ package internal import ( "encoding/base64" + "errors" "net/http" "os" "testing" "github.com/fullstorydev/grpcurl" + "github.com/jhump/protoreflect/desc" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/hash" ) @@ -75,3 +77,50 @@ func TestGetMethodsWithAnnotations(t *testing.T) { }, }, methods) } + +func TestGetMethodsBadCases(t *testing.T) { + t.Run("no services", func(t *testing.T) { + source := &mockDescriptorSource{ + servicesErr: errors.New("no services"), + } + _, err := GetMethods(source) + assert.NotNil(t, err) + }) + + t.Run("no symbol in services", func(t *testing.T) { + source := &mockDescriptorSource{ + services: []string{"hello.Hello"}, + symbolErr: errors.New("no symbol"), + } + _, err := GetMethods(source) + assert.NotNil(t, err) + }) + + t.Run("no symbol in services", func(t *testing.T) { + source := &mockDescriptorSource{ + services: []string{"hello.Hello"}, + symbolErr: errors.New("no symbol"), + } + _, err := GetMethods(source) + assert.NotNil(t, err) + }) +} + +type mockDescriptorSource struct { + symbolDesc desc.Descriptor + symbolErr error + services []string + servicesErr error +} + +func (m *mockDescriptorSource) AllExtensionsForType(_ string) ([]*desc.FieldDescriptor, error) { + return nil, nil +} + +func (m *mockDescriptorSource) FindSymbol(_ string) (desc.Descriptor, error) { + return m.symbolDesc, m.symbolErr +} + +func (m *mockDescriptorSource) ListServices() ([]string, error) { + return m.services, m.servicesErr +}