rename rpcx to zrpc
This commit is contained in:
73
zrpc/internal/auth/auth.go
Normal file
73
zrpc/internal/auth/auth.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/collection"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const defaultExpiration = 5 * time.Minute
|
||||
|
||||
type Authenticator struct {
|
||||
store *redis.Redis
|
||||
key string
|
||||
cache *collection.Cache
|
||||
strict bool
|
||||
}
|
||||
|
||||
func NewAuthenticator(store *redis.Redis, key string, strict bool) (*Authenticator, error) {
|
||||
cache, err := collection.NewCache(defaultExpiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Authenticator{
|
||||
store: store,
|
||||
key: key,
|
||||
cache: cache,
|
||||
strict: strict,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) Authenticate(ctx context.Context) error {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, missingMetadata)
|
||||
}
|
||||
|
||||
apps, tokens := md[appKey], md[tokenKey]
|
||||
if len(apps) == 0 || len(tokens) == 0 {
|
||||
return status.Error(codes.Unauthenticated, missingMetadata)
|
||||
}
|
||||
|
||||
app, token := apps[0], tokens[0]
|
||||
if len(app) == 0 || len(token) == 0 {
|
||||
return status.Error(codes.Unauthenticated, missingMetadata)
|
||||
}
|
||||
|
||||
return a.validate(app, token)
|
||||
}
|
||||
|
||||
func (a *Authenticator) validate(app, token string) error {
|
||||
expect, err := a.cache.Take(app, func() (interface{}, error) {
|
||||
return a.store.Hget(a.key, app)
|
||||
})
|
||||
if err != nil {
|
||||
if a.strict {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if token != expect {
|
||||
return status.Error(codes.Unauthenticated, accessDenied)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
47
zrpc/internal/auth/credential.go
Normal file
47
zrpc/internal/auth/credential.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type Credential struct {
|
||||
App string
|
||||
Token string
|
||||
}
|
||||
|
||||
func (c *Credential) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
appKey: c.App,
|
||||
tokenKey: c.Token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Credential) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func ParseCredential(ctx context.Context) Credential {
|
||||
var credential Credential
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return credential
|
||||
}
|
||||
|
||||
apps, tokens := md[appKey], md[tokenKey]
|
||||
if len(apps) == 0 || len(tokens) == 0 {
|
||||
return credential
|
||||
}
|
||||
|
||||
app, token := apps[0], tokens[0]
|
||||
if len(app) == 0 || len(token) == 0 {
|
||||
return credential
|
||||
}
|
||||
|
||||
credential.App = app
|
||||
credential.Token = token
|
||||
|
||||
return credential
|
||||
}
|
||||
62
zrpc/internal/auth/credential_test.go
Normal file
62
zrpc/internal/auth/credential_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestParseCredential(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
withNil bool
|
||||
withEmptyMd bool
|
||||
app string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
withNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty md",
|
||||
withEmptyMd: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var ctx context.Context
|
||||
if test.withNil {
|
||||
ctx = context.Background()
|
||||
} else if test.withEmptyMd {
|
||||
ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{})
|
||||
} else {
|
||||
md := metadata.New(map[string]string{
|
||||
"app": test.app,
|
||||
"token": test.token,
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
}
|
||||
cred := ParseCredential(ctx)
|
||||
assert.False(t, cred.RequireTransportSecurity())
|
||||
m, err := cred.GetRequestMetadata(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.app, m[appKey])
|
||||
assert.Equal(t, test.token, m[tokenKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
9
zrpc/internal/auth/vars.go
Normal file
9
zrpc/internal/auth/vars.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package auth
|
||||
|
||||
const (
|
||||
appKey = "app"
|
||||
tokenKey = "token"
|
||||
|
||||
accessDenied = "access denied"
|
||||
missingMetadata = "app/token required"
|
||||
)
|
||||
Reference in New Issue
Block a user