add more tests
This commit is contained in:
62
rpcx/internal/auth/credential_test.go
Normal file
62
rpcx/internal/auth/credential_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestParseCredential(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
withNil bool
|
||||
withEmptyMd bool
|
||||
app string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
withNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty md",
|
||||
withEmptyMd: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var ctx context.Context
|
||||
if test.withNil {
|
||||
ctx = context.Background()
|
||||
} else if test.withEmptyMd {
|
||||
ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{})
|
||||
} else {
|
||||
md := metadata.New(map[string]string{
|
||||
"app": test.app,
|
||||
"token": test.token,
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
}
|
||||
cred := ParseCredential(ctx)
|
||||
assert.False(t, cred.RequireTransportSecurity())
|
||||
m, err := cred.GetRequestMetadata(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.app, m[appKey])
|
||||
assert.Equal(t, test.token, m[tokenKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,21 +8,41 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||
"github.com/tal-tech/go-zero/rpcx/internal/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestUnaryAuthorizeInterceptor(t *testing.T) {
|
||||
func TestStreamAuthorizeInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
strict bool
|
||||
name string
|
||||
app string
|
||||
token string
|
||||
strict bool
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with token",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
strict: true,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with error token",
|
||||
app: "foo",
|
||||
token: "error",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -33,19 +53,24 @@ func TestUnaryAuthorizeInterceptor(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
store := redis.NewRedis(r.Addr(), redis.NodeType)
|
||||
if len(test.app) > 0 {
|
||||
assert.Nil(t, store.Hset("apps", test.app, test.token))
|
||||
defer store.Hdel("apps", test.app)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
|
||||
assert.Nil(t, err)
|
||||
interceptor := UnaryAuthorizeInterceptor(authenticator)
|
||||
interceptor := StreamAuthorizeInterceptor(authenticator)
|
||||
md := metadata.New(map[string]string{
|
||||
"app": "name",
|
||||
"token": "key",
|
||||
"app": "foo",
|
||||
"token": "bar",
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
if test.strict {
|
||||
stream := mockedStream{ctx: ctx}
|
||||
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
return nil
|
||||
})
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
@@ -53,3 +78,123 @@ func TestUnaryAuthorizeInterceptor(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryAuthorizeInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
app string
|
||||
token string
|
||||
strict bool
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with token",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
strict: true,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with error token",
|
||||
app: "foo",
|
||||
token: "error",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
r := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r.Start())
|
||||
defer r.Close()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
store := redis.NewRedis(r.Addr(), redis.NodeType)
|
||||
if len(test.app) > 0 {
|
||||
assert.Nil(t, store.Hset("apps", test.app, test.token))
|
||||
defer store.Hdel("apps", test.app)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
|
||||
assert.Nil(t, err)
|
||||
interceptor := UnaryAuthorizeInterceptor(authenticator)
|
||||
md := metadata.New(map[string]string{
|
||||
"app": "foo",
|
||||
"token": "bar",
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
if test.strict {
|
||||
_, err = interceptor(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var md metadata.MD
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
md = metadata.New(map[string]string{
|
||||
"app": "",
|
||||
"token": "",
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockedStream struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m mockedStream) SetHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) SendHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) SetTrailer(md metadata.MD) {
|
||||
}
|
||||
|
||||
func (m mockedStream) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m mockedStream) SendMsg(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) RecvMsg(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user