chore: add more tests (#3009)

This commit is contained in:
Kevin Wan
2023-03-10 20:48:10 +08:00
committed by GitHub
parent 3a493cd6a6
commit c8a17a97be
2 changed files with 93 additions and 45 deletions

View File

@@ -37,52 +37,51 @@ func GetMethods(source grpcurl.DescriptorSource) ([]Method, error) {
for _, method := range svcMethods { for _, method := range svcMethods {
rpcPath := fmt.Sprintf("%s/%s", svc, method.GetName()) rpcPath := fmt.Sprintf("%s/%s", svc, method.GetName())
ext := proto.GetExtension(method.GetMethodOptions(), annotations.E_Http) ext := proto.GetExtension(method.GetMethodOptions(), annotations.E_Http)
if ext == nil { switch rule := ext.(type) {
methods = append(methods, Method{ case *annotations.HttpRule:
RpcPath: rpcPath, if rule == nil {
}) methods = append(methods, Method{
continue RpcPath: rpcPath,
} })
continue
}
httpExt, ok := ext.(*annotations.HttpRule) switch httpRule := rule.GetPattern().(type) {
if !ok { case *annotations.HttpRule_Get:
methods = append(methods, Method{ methods = append(methods, Method{
RpcPath: rpcPath, HttpMethod: http.MethodGet,
}) HttpPath: adjustHttpPath(httpRule.Get),
continue RpcPath: rpcPath,
} })
case *annotations.HttpRule_Post:
switch rule := httpExt.GetPattern().(type) { methods = append(methods, Method{
case *annotations.HttpRule_Get: HttpMethod: http.MethodPost,
methods = append(methods, Method{ HttpPath: adjustHttpPath(httpRule.Post),
HttpMethod: http.MethodGet, RpcPath: rpcPath,
HttpPath: adjustHttpPath(rule.Get), })
RpcPath: rpcPath, case *annotations.HttpRule_Put:
}) methods = append(methods, Method{
case *annotations.HttpRule_Post: HttpMethod: http.MethodPut,
methods = append(methods, Method{ HttpPath: adjustHttpPath(httpRule.Put),
HttpMethod: http.MethodPost, RpcPath: rpcPath,
HttpPath: adjustHttpPath(rule.Post), })
RpcPath: rpcPath, case *annotations.HttpRule_Delete:
}) methods = append(methods, Method{
case *annotations.HttpRule_Put: HttpMethod: http.MethodDelete,
methods = append(methods, Method{ HttpPath: adjustHttpPath(httpRule.Delete),
HttpMethod: http.MethodPut, RpcPath: rpcPath,
HttpPath: adjustHttpPath(rule.Put), })
RpcPath: rpcPath, case *annotations.HttpRule_Patch:
}) methods = append(methods, Method{
case *annotations.HttpRule_Delete: HttpMethod: http.MethodPatch,
methods = append(methods, Method{ HttpPath: adjustHttpPath(httpRule.Patch),
HttpMethod: http.MethodDelete, RpcPath: rpcPath,
HttpPath: adjustHttpPath(rule.Delete), })
RpcPath: rpcPath, default:
}) methods = append(methods, Method{
case *annotations.HttpRule_Patch: RpcPath: rpcPath,
methods = append(methods, Method{ })
HttpMethod: http.MethodPatch, }
HttpPath: adjustHttpPath(rule.Patch),
RpcPath: rpcPath,
})
default: default:
methods = append(methods, Method{ methods = append(methods, Method{
RpcPath: rpcPath, RpcPath: rpcPath,

View File

@@ -2,11 +2,13 @@ package internal
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"net/http" "net/http"
"os" "os"
"testing" "testing"
"github.com/fullstorydev/grpcurl" "github.com/fullstorydev/grpcurl"
"github.com/jhump/protoreflect/desc"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/hash" "github.com/zeromicro/go-zero/core/hash"
) )
@@ -75,3 +77,50 @@ func TestGetMethodsWithAnnotations(t *testing.T) {
}, },
}, methods) }, methods)
} }
func TestGetMethodsBadCases(t *testing.T) {
t.Run("no services", func(t *testing.T) {
source := &mockDescriptorSource{
servicesErr: errors.New("no services"),
}
_, err := GetMethods(source)
assert.NotNil(t, err)
})
t.Run("no symbol in services", func(t *testing.T) {
source := &mockDescriptorSource{
services: []string{"hello.Hello"},
symbolErr: errors.New("no symbol"),
}
_, err := GetMethods(source)
assert.NotNil(t, err)
})
t.Run("no symbol in services", func(t *testing.T) {
source := &mockDescriptorSource{
services: []string{"hello.Hello"},
symbolErr: errors.New("no symbol"),
}
_, err := GetMethods(source)
assert.NotNil(t, err)
})
}
type mockDescriptorSource struct {
symbolDesc desc.Descriptor
symbolErr error
services []string
servicesErr error
}
func (m *mockDescriptorSource) AllExtensionsForType(_ string) ([]*desc.FieldDescriptor, error) {
return nil, nil
}
func (m *mockDescriptorSource) FindSymbol(_ string) (desc.Descriptor, error) {
return m.symbolDesc, m.symbolErr
}
func (m *mockDescriptorSource) ListServices() ([]string, error) {
return m.services, m.servicesErr
}