rename rpcx to zrpc
This commit is contained in:
73
zrpc/internal/auth/auth.go
Normal file
73
zrpc/internal/auth/auth.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/collection"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const defaultExpiration = 5 * time.Minute
|
||||
|
||||
type Authenticator struct {
|
||||
store *redis.Redis
|
||||
key string
|
||||
cache *collection.Cache
|
||||
strict bool
|
||||
}
|
||||
|
||||
func NewAuthenticator(store *redis.Redis, key string, strict bool) (*Authenticator, error) {
|
||||
cache, err := collection.NewCache(defaultExpiration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Authenticator{
|
||||
store: store,
|
||||
key: key,
|
||||
cache: cache,
|
||||
strict: strict,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) Authenticate(ctx context.Context) error {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.Unauthenticated, missingMetadata)
|
||||
}
|
||||
|
||||
apps, tokens := md[appKey], md[tokenKey]
|
||||
if len(apps) == 0 || len(tokens) == 0 {
|
||||
return status.Error(codes.Unauthenticated, missingMetadata)
|
||||
}
|
||||
|
||||
app, token := apps[0], tokens[0]
|
||||
if len(app) == 0 || len(token) == 0 {
|
||||
return status.Error(codes.Unauthenticated, missingMetadata)
|
||||
}
|
||||
|
||||
return a.validate(app, token)
|
||||
}
|
||||
|
||||
func (a *Authenticator) validate(app, token string) error {
|
||||
expect, err := a.cache.Take(app, func() (interface{}, error) {
|
||||
return a.store.Hget(a.key, app)
|
||||
})
|
||||
if err != nil {
|
||||
if a.strict {
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if token != expect {
|
||||
return status.Error(codes.Unauthenticated, accessDenied)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
47
zrpc/internal/auth/credential.go
Normal file
47
zrpc/internal/auth/credential.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type Credential struct {
|
||||
App string
|
||||
Token string
|
||||
}
|
||||
|
||||
func (c *Credential) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
|
||||
return map[string]string{
|
||||
appKey: c.App,
|
||||
tokenKey: c.Token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Credential) RequireTransportSecurity() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func ParseCredential(ctx context.Context) Credential {
|
||||
var credential Credential
|
||||
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return credential
|
||||
}
|
||||
|
||||
apps, tokens := md[appKey], md[tokenKey]
|
||||
if len(apps) == 0 || len(tokens) == 0 {
|
||||
return credential
|
||||
}
|
||||
|
||||
app, token := apps[0], tokens[0]
|
||||
if len(app) == 0 || len(token) == 0 {
|
||||
return credential
|
||||
}
|
||||
|
||||
credential.App = app
|
||||
credential.Token = token
|
||||
|
||||
return credential
|
||||
}
|
||||
62
zrpc/internal/auth/credential_test.go
Normal file
62
zrpc/internal/auth/credential_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestParseCredential(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
withNil bool
|
||||
withEmptyMd bool
|
||||
app string
|
||||
token string
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
withNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty md",
|
||||
withEmptyMd: true,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var ctx context.Context
|
||||
if test.withNil {
|
||||
ctx = context.Background()
|
||||
} else if test.withEmptyMd {
|
||||
ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{})
|
||||
} else {
|
||||
md := metadata.New(map[string]string{
|
||||
"app": test.app,
|
||||
"token": test.token,
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
}
|
||||
cred := ParseCredential(ctx)
|
||||
assert.False(t, cred.RequireTransportSecurity())
|
||||
m, err := cred.GetRequestMetadata(context.Background())
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.app, m[appKey])
|
||||
assert.Equal(t, test.token, m[tokenKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
9
zrpc/internal/auth/vars.go
Normal file
9
zrpc/internal/auth/vars.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package auth
|
||||
|
||||
const (
|
||||
appKey = "app"
|
||||
tokenKey = "token"
|
||||
|
||||
accessDenied = "access denied"
|
||||
missingMetadata = "app/token required"
|
||||
)
|
||||
202
zrpc/internal/balancer/p2c/p2c.go
Normal file
202
zrpc/internal/balancer/p2c/p2c.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package p2c
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/syncx"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/codes"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/base"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "p2c_ewma"
|
||||
decayTime = int64(time.Second * 10) // default value from finagle
|
||||
forcePick = int64(time.Second)
|
||||
initSuccess = 1000
|
||||
throttleSuccess = initSuccess / 2
|
||||
penalty = int64(math.MaxInt32)
|
||||
pickTimes = 3
|
||||
logInterval = time.Minute
|
||||
)
|
||||
|
||||
func init() {
|
||||
balancer.Register(newBuilder())
|
||||
}
|
||||
|
||||
type p2cPickerBuilder struct {
|
||||
}
|
||||
|
||||
func newBuilder() balancer.Builder {
|
||||
return base.NewBalancerBuilder(Name, new(p2cPickerBuilder))
|
||||
}
|
||||
|
||||
func (b *p2cPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
|
||||
if len(readySCs) == 0 {
|
||||
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
|
||||
}
|
||||
|
||||
var conns []*subConn
|
||||
for addr, conn := range readySCs {
|
||||
conns = append(conns, &subConn{
|
||||
addr: addr,
|
||||
conn: conn,
|
||||
success: initSuccess,
|
||||
})
|
||||
}
|
||||
|
||||
return &p2cPicker{
|
||||
conns: conns,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
stamp: syncx.NewAtomicDuration(),
|
||||
}
|
||||
}
|
||||
|
||||
type p2cPicker struct {
|
||||
conns []*subConn
|
||||
r *rand.Rand
|
||||
stamp *syncx.AtomicDuration
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (p *p2cPicker) Pick(ctx context.Context, info balancer.PickInfo) (
|
||||
conn balancer.SubConn, done func(balancer.DoneInfo), err error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
var chosen *subConn
|
||||
switch len(p.conns) {
|
||||
case 0:
|
||||
return nil, nil, balancer.ErrNoSubConnAvailable
|
||||
case 1:
|
||||
chosen = p.choose(p.conns[0], nil)
|
||||
case 2:
|
||||
chosen = p.choose(p.conns[0], p.conns[1])
|
||||
default:
|
||||
var node1, node2 *subConn
|
||||
for i := 0; i < pickTimes; i++ {
|
||||
a := p.r.Intn(len(p.conns))
|
||||
b := p.r.Intn(len(p.conns) - 1)
|
||||
if b >= a {
|
||||
b++
|
||||
}
|
||||
node1 = p.conns[a]
|
||||
node2 = p.conns[b]
|
||||
if node1.healthy() && node2.healthy() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
chosen = p.choose(node1, node2)
|
||||
}
|
||||
|
||||
atomic.AddInt64(&chosen.inflight, 1)
|
||||
atomic.AddInt64(&chosen.requests, 1)
|
||||
return chosen.conn, p.buildDoneFunc(chosen), nil
|
||||
}
|
||||
|
||||
func (p *p2cPicker) buildDoneFunc(c *subConn) func(info balancer.DoneInfo) {
|
||||
start := int64(timex.Now())
|
||||
return func(info balancer.DoneInfo) {
|
||||
atomic.AddInt64(&c.inflight, -1)
|
||||
now := timex.Now()
|
||||
last := atomic.SwapInt64(&c.last, int64(now))
|
||||
td := int64(now) - last
|
||||
if td < 0 {
|
||||
td = 0
|
||||
}
|
||||
w := math.Exp(float64(-td) / float64(decayTime))
|
||||
lag := int64(now) - start
|
||||
if lag < 0 {
|
||||
lag = 0
|
||||
}
|
||||
olag := atomic.LoadUint64(&c.lag)
|
||||
if olag == 0 {
|
||||
w = 0
|
||||
}
|
||||
atomic.StoreUint64(&c.lag, uint64(float64(olag)*w+float64(lag)*(1-w)))
|
||||
success := initSuccess
|
||||
if info.Err != nil && !codes.Acceptable(info.Err) {
|
||||
success = 0
|
||||
}
|
||||
osucc := atomic.LoadUint64(&c.success)
|
||||
atomic.StoreUint64(&c.success, uint64(float64(osucc)*w+float64(success)*(1-w)))
|
||||
|
||||
stamp := p.stamp.Load()
|
||||
if now-stamp >= logInterval {
|
||||
if p.stamp.CompareAndSwap(stamp, now) {
|
||||
p.logStats()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *p2cPicker) choose(c1, c2 *subConn) *subConn {
|
||||
start := int64(timex.Now())
|
||||
if c2 == nil {
|
||||
atomic.StoreInt64(&c1.pick, start)
|
||||
return c1
|
||||
}
|
||||
|
||||
if c1.load() > c2.load() {
|
||||
c1, c2 = c2, c1
|
||||
}
|
||||
|
||||
pick := atomic.LoadInt64(&c2.pick)
|
||||
if start-pick > forcePick && atomic.CompareAndSwapInt64(&c2.pick, pick, start) {
|
||||
return c2
|
||||
} else {
|
||||
atomic.StoreInt64(&c1.pick, start)
|
||||
return c1
|
||||
}
|
||||
}
|
||||
|
||||
func (p *p2cPicker) logStats() {
|
||||
var stats []string
|
||||
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
for _, conn := range p.conns {
|
||||
stats = append(stats, fmt.Sprintf("conn: %s, load: %d, reqs: %d",
|
||||
conn.addr.Addr, conn.load(), atomic.SwapInt64(&conn.requests, 0)))
|
||||
}
|
||||
|
||||
logx.Statf("p2c - %s", strings.Join(stats, "; "))
|
||||
}
|
||||
|
||||
type subConn struct {
|
||||
addr resolver.Address
|
||||
conn balancer.SubConn
|
||||
lag uint64
|
||||
inflight int64
|
||||
success uint64
|
||||
requests int64
|
||||
last int64
|
||||
pick int64
|
||||
}
|
||||
|
||||
func (c *subConn) healthy() bool {
|
||||
return atomic.LoadUint64(&c.success) > throttleSuccess
|
||||
}
|
||||
|
||||
func (c *subConn) load() int64 {
|
||||
// plus one to avoid multiply zero
|
||||
lag := int64(math.Sqrt(float64(atomic.LoadUint64(&c.lag) + 1)))
|
||||
load := lag * (atomic.LoadInt64(&c.inflight) + 1)
|
||||
if load == 0 {
|
||||
return penalty
|
||||
} else {
|
||||
return load
|
||||
}
|
||||
}
|
||||
113
zrpc/internal/balancer/p2c/p2c_test.go
Normal file
113
zrpc/internal/balancer/p2c/p2c_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package p2c
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/mathx"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestP2cPicker_PickNil(t *testing.T) {
|
||||
builder := new(p2cPickerBuilder)
|
||||
picker := builder.Build(nil)
|
||||
_, _, err := picker.Pick(context.Background(), balancer.PickInfo{
|
||||
FullMethodName: "/",
|
||||
Ctx: context.Background(),
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestP2cPicker_Pick(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
candidates int
|
||||
threshold float64
|
||||
}{
|
||||
{
|
||||
name: "single",
|
||||
candidates: 1,
|
||||
threshold: 0.9,
|
||||
},
|
||||
{
|
||||
name: "two",
|
||||
candidates: 2,
|
||||
threshold: 0.5,
|
||||
},
|
||||
{
|
||||
name: "multiple",
|
||||
candidates: 100,
|
||||
threshold: 0.95,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const total = 10000
|
||||
builder := new(p2cPickerBuilder)
|
||||
ready := make(map[resolver.Address]balancer.SubConn)
|
||||
for i := 0; i < test.candidates; i++ {
|
||||
ready[resolver.Address{
|
||||
Addr: strconv.Itoa(i),
|
||||
}] = new(mockClientConn)
|
||||
}
|
||||
|
||||
picker := builder.Build(ready)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(total)
|
||||
for i := 0; i < total; i++ {
|
||||
_, done, err := picker.Pick(context.Background(), balancer.PickInfo{
|
||||
FullMethodName: "/",
|
||||
Ctx: context.Background(),
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
if i%100 == 0 {
|
||||
err = status.Error(codes.DeadlineExceeded, "deadline")
|
||||
}
|
||||
go func() {
|
||||
runtime.Gosched()
|
||||
done(balancer.DoneInfo{
|
||||
Err: err,
|
||||
})
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
dist := make(map[interface{}]int)
|
||||
conns := picker.(*p2cPicker).conns
|
||||
for _, conn := range conns {
|
||||
dist[conn.addr.Addr] = int(conn.requests)
|
||||
}
|
||||
|
||||
entropy := mathx.CalcEntropy(dist)
|
||||
assert.True(t, entropy > test.threshold, fmt.Sprintf("entropy is %f, less than %f",
|
||||
entropy, test.threshold))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockClientConn struct {
|
||||
}
|
||||
|
||||
func (m mockClientConn) UpdateAddresses(addresses []resolver.Address) {
|
||||
}
|
||||
|
||||
func (m mockClientConn) Connect() {
|
||||
}
|
||||
83
zrpc/internal/chainclientinterceptors.go
Normal file
83
zrpc/internal/chainclientinterceptors.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
|
||||
return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...))
|
||||
}
|
||||
|
||||
func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
|
||||
return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...))
|
||||
}
|
||||
|
||||
func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||
method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
var chainStreamer grpc.Streamer
|
||||
var current int
|
||||
|
||||
chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn,
|
||||
curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
if current == last {
|
||||
return streamer(curCtx, curDesc, curCc, curMethod, curOpts...)
|
||||
}
|
||||
|
||||
current++
|
||||
clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...)
|
||||
current--
|
||||
|
||||
return clientStream, err
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
var chainInvoker grpc.UnaryInvoker
|
||||
var current int
|
||||
|
||||
chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{},
|
||||
curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error {
|
||||
if current == last {
|
||||
return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...)
|
||||
}
|
||||
|
||||
current++
|
||||
err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...)
|
||||
current--
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...)
|
||||
}
|
||||
}
|
||||
}
|
||||
123
zrpc/internal/chainclientinterceptors_test.go
Normal file
123
zrpc/internal/chainclientinterceptors_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestWithStreamClientInterceptors(t *testing.T) {
|
||||
opts := WithStreamClientInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors(t *testing.T) {
|
||||
opts := WithUnaryClientInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamClientInterceptors()
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 1)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
|
||||
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
|
||||
grpc.ClientStream, error) {
|
||||
vals = append(vals, 1)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 2)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamClientInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
|
||||
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
|
||||
grpc.ClientStream, error) {
|
||||
vals = append(vals, 1)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 2)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
|
||||
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
|
||||
opts ...grpc.CallOption) (grpc.ClientStream, error) {
|
||||
vals = append(vals, 3)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryClientInterceptors()
|
||||
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 1)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
|
||||
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 1)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
})
|
||||
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 2)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestWithUnaryClientInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
|
||||
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 1)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 2)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
})
|
||||
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
vals = append(vals, 3)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
81
zrpc/internal/chainserverinterceptors.go
Normal file
81
zrpc/internal/chainserverinterceptors.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
|
||||
return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
|
||||
}
|
||||
|
||||
func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
|
||||
return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
|
||||
}
|
||||
|
||||
func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
return handler(srv, stream)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
var chainHandler grpc.StreamHandler
|
||||
var current int
|
||||
|
||||
chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
|
||||
if current == last {
|
||||
return handler(curSrv, curStream)
|
||||
}
|
||||
|
||||
current++
|
||||
err := interceptors[current](curSrv, curStream, info, chainHandler)
|
||||
current--
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return interceptors[0](srv, stream, info, chainHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
|
||||
switch len(interceptors) {
|
||||
case 0:
|
||||
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
interface{}, error) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
case 1:
|
||||
return interceptors[0]
|
||||
default:
|
||||
last := len(interceptors) - 1
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
interface{}, error) {
|
||||
var chainHandler grpc.UnaryHandler
|
||||
var current int
|
||||
|
||||
chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
|
||||
if current == last {
|
||||
return handler(curCtx, curReq)
|
||||
}
|
||||
|
||||
current++
|
||||
resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
|
||||
current--
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
return interceptors[0](ctx, req, info, chainHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
111
zrpc/internal/chainserverinterceptors_test.go
Normal file
111
zrpc/internal/chainserverinterceptors_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestWithStreamServerInterceptors(t *testing.T) {
|
||||
opts := WithStreamServerInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestWithUnaryServerInterceptors(t *testing.T) {
|
||||
opts := WithUnaryServerInterceptors()
|
||||
assert.NotNil(t, opts)
|
||||
}
|
||||
|
||||
func TestChainStreamServerInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamServerInterceptors()
|
||||
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
vals = append(vals, 1)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamServerInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 1)
|
||||
return handler(srv, ss)
|
||||
})
|
||||
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
vals = append(vals, 2)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestChainStreamServerInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 1)
|
||||
return handler(srv, ss)
|
||||
}, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 2)
|
||||
return handler(srv, ss)
|
||||
})
|
||||
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
vals = append(vals, 3)
|
||||
return nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
|
||||
func TestChainUnaryServerInterceptors_zero(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryServerInterceptors()
|
||||
_, err := interceptors(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
vals = append(vals, 1)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestChainUnaryServerInterceptors_one(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
|
||||
info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
vals = append(vals, 1)
|
||||
return handler(ctx, req)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
vals = append(vals, 2)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2}, vals)
|
||||
}
|
||||
|
||||
func TestChainUnaryServerInterceptors_more(t *testing.T) {
|
||||
var vals []int
|
||||
interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
|
||||
info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
vals = append(vals, 1)
|
||||
return handler(ctx, req)
|
||||
}, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
vals = append(vals, 2)
|
||||
return handler(ctx, req)
|
||||
})
|
||||
_, err := interceptors(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
vals = append(vals, 3)
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
|
||||
}
|
||||
90
zrpc/internal/client.go
Normal file
90
zrpc/internal/client.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/balancer/p2c"
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/clientinterceptors"
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/resolver"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const dialTimeout = time.Second * 3
|
||||
|
||||
func init() {
|
||||
resolver.RegisterResolver()
|
||||
}
|
||||
|
||||
type (
|
||||
ClientOptions struct {
|
||||
Timeout time.Duration
|
||||
DialOptions []grpc.DialOption
|
||||
}
|
||||
|
||||
ClientOption func(options *ClientOptions)
|
||||
|
||||
client struct {
|
||||
conn *grpc.ClientConn
|
||||
}
|
||||
)
|
||||
|
||||
func NewClient(target string, opts ...ClientOption) (*client, error) {
|
||||
opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name)))
|
||||
conn, err := dial(target, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &client{conn: conn}, nil
|
||||
}
|
||||
|
||||
func (c *client) Conn() *grpc.ClientConn {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func WithDialOption(opt grpc.DialOption) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.DialOptions = append(options.DialOptions, opt)
|
||||
}
|
||||
}
|
||||
|
||||
func WithTimeout(timeout time.Duration) ClientOption {
|
||||
return func(options *ClientOptions) {
|
||||
options.Timeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func buildDialOptions(opts ...ClientOption) []grpc.DialOption {
|
||||
var clientOptions ClientOptions
|
||||
for _, opt := range opts {
|
||||
opt(&clientOptions)
|
||||
}
|
||||
|
||||
options := []grpc.DialOption{
|
||||
grpc.WithInsecure(),
|
||||
grpc.WithBlock(),
|
||||
WithUnaryClientInterceptors(
|
||||
clientinterceptors.BreakerInterceptor,
|
||||
clientinterceptors.DurationInterceptor,
|
||||
clientinterceptors.PromMetricInterceptor,
|
||||
clientinterceptors.TimeoutInterceptor(clientOptions.Timeout),
|
||||
clientinterceptors.TracingInterceptor,
|
||||
),
|
||||
}
|
||||
|
||||
return append(options, clientOptions.DialOptions...)
|
||||
}
|
||||
|
||||
func dial(server string, opts ...ClientOption) (*grpc.ClientConn, error) {
|
||||
options := buildDialOptions(opts...)
|
||||
timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout)
|
||||
defer cancel()
|
||||
conn, err := grpc.DialContext(timeCtx, server, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rpc dial: %s, error: %s", server, err.Error())
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
30
zrpc/internal/client_test.go
Normal file
30
zrpc/internal/client_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestWithDialOption(t *testing.T) {
|
||||
var options ClientOptions
|
||||
agent := grpc.WithUserAgent("chrome")
|
||||
opt := WithDialOption(agent)
|
||||
opt(&options)
|
||||
assert.Contains(t, options.DialOptions, agent)
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
var options ClientOptions
|
||||
opt := WithTimeout(time.Second)
|
||||
opt(&options)
|
||||
assert.Equal(t, time.Second, options.Timeout)
|
||||
}
|
||||
|
||||
func TestBuildDialOptions(t *testing.T) {
|
||||
agent := grpc.WithUserAgent("chrome")
|
||||
opts := buildDialOptions(WithDialOption(agent))
|
||||
assert.Contains(t, opts, agent)
|
||||
}
|
||||
18
zrpc/internal/clientinterceptors/breakerinterceptor.go
Normal file
18
zrpc/internal/clientinterceptors/breakerinterceptor.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/codes"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func BreakerInterceptor(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
breakerName := path.Join(cc.Target(), method)
|
||||
return breaker.DoWithAcceptable(breakerName, func() error {
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}, codes.Acceptable)
|
||||
}
|
||||
81
zrpc/internal/clientinterceptors/breakerinterceptor_test.go
Normal file
81
zrpc/internal/clientinterceptors/breakerinterceptor_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
rcodes "github.com/tal-tech/go-zero/zrpc/internal/codes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func init() {
|
||||
stat.SetReporter(nil)
|
||||
}
|
||||
|
||||
type mockError struct {
|
||||
st *status.Status
|
||||
}
|
||||
|
||||
func (m mockError) GRPCStatus() *status.Status {
|
||||
return m.st
|
||||
}
|
||||
|
||||
func (m mockError) Error() string {
|
||||
return "mocked error"
|
||||
}
|
||||
|
||||
func TestBreakerInterceptorNotFound(t *testing.T) {
|
||||
err := mockError{st: status.New(codes.NotFound, "any")}
|
||||
for i := 0; i < 1000; i++ {
|
||||
assert.Equal(t, err, breaker.DoWithAcceptable("call", func() error {
|
||||
return err
|
||||
}, rcodes.Acceptable))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBreakerInterceptorDeadlineExceeded(t *testing.T) {
|
||||
err := mockError{st: status.New(codes.DeadlineExceeded, "any")}
|
||||
errs := make(map[error]int)
|
||||
for i := 0; i < 1000; i++ {
|
||||
e := breaker.DoWithAcceptable("call", func() error {
|
||||
return err
|
||||
}, rcodes.Acceptable)
|
||||
errs[e]++
|
||||
}
|
||||
assert.Equal(t, 2, len(errs))
|
||||
assert.True(t, errs[err] > 0)
|
||||
assert.True(t, errs[breaker.ErrServiceUnavailable] > 0)
|
||||
}
|
||||
|
||||
func TestBreakerInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("mock"),
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
err := BreakerInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return test.err
|
||||
})
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
30
zrpc/internal/clientinterceptors/durationinterceptor.go
Normal file
30
zrpc/internal/clientinterceptors/durationinterceptor.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func DurationInterceptor(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
serverName := path.Join(cc.Target(), method)
|
||||
start := timex.Now()
|
||||
err := invoker(ctx, method, req, reply, cc, opts...)
|
||||
if err != nil {
|
||||
logx.WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", serverName, req, err.Error())
|
||||
} else {
|
||||
elapsed := timex.Since(start)
|
||||
if elapsed > slowThreshold {
|
||||
logx.WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", serverName, req, reply)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
37
zrpc/internal/clientinterceptors/durationinterceptor_test.go
Normal file
37
zrpc/internal/clientinterceptors/durationinterceptor_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestDurationInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("mock"),
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return test.err
|
||||
})
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
42
zrpc/internal/clientinterceptors/prommetricinterceptor.go
Normal file
42
zrpc/internal/clientinterceptors/prommetricinterceptor.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/metric"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const clientNamespace = "rpc_client"
|
||||
|
||||
var (
|
||||
metricClientReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
|
||||
Namespace: clientNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "duration_ms",
|
||||
Help: "rpc client requests duration(ms).",
|
||||
Labels: []string{"method"},
|
||||
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
|
||||
})
|
||||
|
||||
metricClientReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
|
||||
Namespace: clientNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "code_total",
|
||||
Help: "rpc client requests code count.",
|
||||
Labels: []string{"method", "code"},
|
||||
})
|
||||
)
|
||||
|
||||
func PromMetricInterceptor(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
startTime := timex.Now()
|
||||
err := invoker(ctx, method, req, reply, cc, opts...)
|
||||
metricClientReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), method)
|
||||
metricClientReqCodeTotal.Inc(method, strconv.Itoa(int(status.Code(err))))
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestPromMetricInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "with error",
|
||||
err: errors.New("mock"),
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cc := new(grpc.ClientConn)
|
||||
err := PromMetricInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return test.err
|
||||
})
|
||||
assert.Equal(t, test.err, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
24
zrpc/internal/clientinterceptors/timeoutinterceptor.go
Normal file
24
zrpc/internal/clientinterceptors/timeoutinterceptor.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/contextx"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second * 2
|
||||
|
||||
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
|
||||
if timeout <= 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
||||
defer cancel()
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
50
zrpc/internal/clientinterceptors/timeoutinterceptor_test.go
Normal file
50
zrpc/internal/clientinterceptors/timeoutinterceptor_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestTimeoutInterceptor(t *testing.T) {
|
||||
timeouts := []time.Duration{0, time.Millisecond * 10}
|
||||
for _, timeout := range timeouts {
|
||||
t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
|
||||
interceptor := TimeoutInterceptor(timeout)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := interceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
return nil
|
||||
},
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutInterceptor_timeout(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := TimeoutInterceptor(timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := interceptor(ctx, "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
tm, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
24
zrpc/internal/clientinterceptors/tracinginterceptor.go
Normal file
24
zrpc/internal/clientinterceptors/tracinginterceptor.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TracingInterceptor(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
ctx, span := trace.StartClientSpan(ctx, cc.Target(), method)
|
||||
defer span.Finish()
|
||||
|
||||
var pairs []string
|
||||
span.Visit(func(key, val string) bool {
|
||||
pairs = append(pairs, key, val)
|
||||
return true
|
||||
})
|
||||
ctx = metadata.AppendToOutgoingContext(ctx, pairs...)
|
||||
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
53
zrpc/internal/clientinterceptors/tracinginterceptor_test.go
Normal file
53
zrpc/internal/clientinterceptors/tracinginterceptor_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package clientinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestTracingInterceptor(t *testing.T) {
|
||||
var run int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := TracingInterceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
|
||||
}
|
||||
|
||||
func TestTracingInterceptor_GrpcFormat(t *testing.T) {
|
||||
var run int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
md := metadata.New(map[string]string{
|
||||
"foo": "bar",
|
||||
})
|
||||
carrier, err := trace.Inject(trace.GrpcFormat, md)
|
||||
assert.Nil(t, err)
|
||||
ctx, _ := trace.StartServerSpan(context.Background(), carrier, "user", "/foo")
|
||||
cc := new(grpc.ClientConn)
|
||||
err = TracingInterceptor(ctx, "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
|
||||
}
|
||||
15
zrpc/internal/codes/accept.go
Normal file
15
zrpc/internal/codes/accept.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package codes
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func Acceptable(err error) bool {
|
||||
switch status.Code(err) {
|
||||
case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
34
zrpc/internal/codes/accept_test.go
Normal file
34
zrpc/internal/codes/accept_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package codes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestAccept(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
accept bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
accept: true,
|
||||
},
|
||||
{
|
||||
name: "deadline error",
|
||||
err: status.Error(codes.DeadlineExceeded, "deadline"),
|
||||
accept: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.accept, Acceptable(test.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
159
zrpc/internal/mock/deposit.pb.go
Normal file
159
zrpc/internal/mock/deposit.pb.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Code generated by protoc-gen-go.
|
||||
// source: deposit.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/*
|
||||
Package mock is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
deposit.proto
|
||||
|
||||
It has these top-level messages:
|
||||
DepositRequest
|
||||
DepositResponse
|
||||
*/
|
||||
package mock
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
|
||||
import (
|
||||
context "golang.org/x/net/context"
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type DepositRequest struct {
|
||||
Amount float32 `protobuf:"fixed32,1,opt,name=amount" json:"amount,omitempty"`
|
||||
}
|
||||
|
||||
func (m *DepositRequest) Reset() { *m = DepositRequest{} }
|
||||
func (m *DepositRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*DepositRequest) ProtoMessage() {}
|
||||
func (*DepositRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func (m *DepositRequest) GetAmount() float32 {
|
||||
if m != nil {
|
||||
return m.Amount
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type DepositResponse struct {
|
||||
Ok bool `protobuf:"varint,1,opt,name=ok" json:"ok,omitempty"`
|
||||
}
|
||||
|
||||
func (m *DepositResponse) Reset() { *m = DepositResponse{} }
|
||||
func (m *DepositResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*DepositResponse) ProtoMessage() {}
|
||||
func (*DepositResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||
|
||||
func (m *DepositResponse) GetOk() bool {
|
||||
if m != nil {
|
||||
return m.Ok
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*DepositRequest)(nil), "mock.DepositRequest")
|
||||
proto.RegisterType((*DepositResponse)(nil), "mock.DepositResponse")
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// Client API for DepositService service
|
||||
|
||||
type DepositServiceClient interface {
|
||||
Deposit(ctx context.Context, in *DepositRequest, opts ...grpc.CallOption) (*DepositResponse, error)
|
||||
}
|
||||
|
||||
type depositServiceClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewDepositServiceClient(cc *grpc.ClientConn) DepositServiceClient {
|
||||
return &depositServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *depositServiceClient) Deposit(ctx context.Context, in *DepositRequest, opts ...grpc.CallOption) (*DepositResponse, error) {
|
||||
out := new(DepositResponse)
|
||||
err := grpc.Invoke(ctx, "/mock.DepositService/Deposit", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Server API for DepositService service
|
||||
|
||||
type DepositServiceServer interface {
|
||||
Deposit(context.Context, *DepositRequest) (*DepositResponse, error)
|
||||
}
|
||||
|
||||
func RegisterDepositServiceServer(s *grpc.Server, srv DepositServiceServer) {
|
||||
s.RegisterService(&_DepositService_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _DepositService_Deposit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(DepositRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DepositServiceServer).Deposit(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/mock.DepositService/Deposit",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DepositServiceServer).Deposit(ctx, req.(*DepositRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _DepositService_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "mock.DepositService",
|
||||
HandlerType: (*DepositServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Deposit",
|
||||
Handler: _DepositService_Deposit_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "deposit.proto",
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("deposit.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 139 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4d, 0x49, 0x2d, 0xc8,
|
||||
0x2f, 0xce, 0x2c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0xc9, 0xcd, 0x4f, 0xce, 0x56,
|
||||
0xd2, 0xe0, 0xe2, 0x73, 0x81, 0x08, 0x07, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x89, 0x71,
|
||||
0xb1, 0x25, 0xe6, 0xe6, 0x97, 0xe6, 0x95, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x30, 0x05, 0x41, 0x79,
|
||||
0x4a, 0x8a, 0x5c, 0xfc, 0x70, 0x95, 0xc5, 0x05, 0xf9, 0x79, 0xc5, 0xa9, 0x42, 0x7c, 0x5c, 0x4c,
|
||||
0xf9, 0xd9, 0x60, 0x65, 0x1c, 0x41, 0x4c, 0xf9, 0xd9, 0x46, 0x1e, 0x70, 0xc3, 0x82, 0x53, 0x8b,
|
||||
0xca, 0x32, 0x93, 0x53, 0x85, 0xcc, 0xb8, 0xd8, 0xa1, 0x22, 0x42, 0x22, 0x7a, 0x20, 0x0b, 0xf5,
|
||||
0x50, 0x6d, 0x93, 0x12, 0x45, 0x13, 0x85, 0x98, 0x9c, 0xc4, 0x06, 0x76, 0xa3, 0x31, 0x20, 0x00,
|
||||
0x00, 0xff, 0xff, 0x62, 0x37, 0xf2, 0x36, 0xb4, 0x00, 0x00, 0x00,
|
||||
}
|
||||
15
zrpc/internal/mock/deposit.proto
Normal file
15
zrpc/internal/mock/deposit.proto
Normal file
@@ -0,0 +1,15 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package mock;
|
||||
|
||||
message DepositRequest {
|
||||
float amount = 1;
|
||||
}
|
||||
|
||||
message DepositResponse {
|
||||
bool ok = 1;
|
||||
}
|
||||
|
||||
service DepositService {
|
||||
rpc Deposit(DepositRequest) returns (DepositResponse);
|
||||
}
|
||||
19
zrpc/internal/mock/depositserver.go
Normal file
19
zrpc/internal/mock/depositserver.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type DepositServer struct {
|
||||
}
|
||||
|
||||
func (*DepositServer) Deposit(ctx context.Context, req *DepositRequest) (*DepositResponse, error) {
|
||||
if req.GetAmount() < 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "cannot deposit %v", req.GetAmount())
|
||||
}
|
||||
|
||||
return &DepositResponse{Ok: true}, nil
|
||||
}
|
||||
32
zrpc/internal/resolver/directbuilder.go
Normal file
32
zrpc/internal/resolver/directbuilder.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
type directBuilder struct{}
|
||||
|
||||
func (d *directBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
|
||||
resolver.Resolver, error) {
|
||||
var addrs []resolver.Address
|
||||
endpoints := strings.FieldsFunc(target.Endpoint, func(r rune) bool {
|
||||
return r == EndpointSepChar
|
||||
})
|
||||
|
||||
for _, val := range subset(endpoints, subsetSize) {
|
||||
addrs = append(addrs, resolver.Address{
|
||||
Addr: val,
|
||||
})
|
||||
}
|
||||
cc.UpdateState(resolver.State{
|
||||
Addresses: addrs,
|
||||
})
|
||||
|
||||
return &nopResolver{cc: cc}, nil
|
||||
}
|
||||
|
||||
func (d *directBuilder) Scheme() string {
|
||||
return DirectScheme
|
||||
}
|
||||
52
zrpc/internal/resolver/directbuilder_test.go
Normal file
52
zrpc/internal/resolver/directbuilder_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/lang"
|
||||
"github.com/tal-tech/go-zero/core/mathx"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
func TestDirectBuilder_Build(t *testing.T) {
|
||||
tests := []int{
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
subsetSize / 2,
|
||||
subsetSize,
|
||||
subsetSize * 2,
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(strconv.Itoa(test), func(t *testing.T) {
|
||||
var servers []string
|
||||
for i := 0; i < test; i++ {
|
||||
servers = append(servers, fmt.Sprintf("localhost:%d", i))
|
||||
}
|
||||
var b directBuilder
|
||||
cc := new(mockedClientConn)
|
||||
_, err := b.Build(resolver.Target{
|
||||
Scheme: DirectScheme,
|
||||
Endpoint: strings.Join(servers, ","),
|
||||
}, cc, resolver.BuildOptions{})
|
||||
assert.Nil(t, err)
|
||||
size := mathx.MinInt(test, subsetSize)
|
||||
assert.Equal(t, size, len(cc.state.Addresses))
|
||||
m := make(map[string]lang.PlaceholderType)
|
||||
for _, each := range cc.state.Addresses {
|
||||
m[each.Addr] = lang.Placeholder
|
||||
}
|
||||
assert.Equal(t, size, len(m))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectBuilder_Scheme(t *testing.T) {
|
||||
var b directBuilder
|
||||
assert.Equal(t, DirectScheme, b.Scheme())
|
||||
}
|
||||
41
zrpc/internal/resolver/discovbuilder.go
Normal file
41
zrpc/internal/resolver/discovbuilder.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/discov"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
type discovBuilder struct{}
|
||||
|
||||
func (d *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
|
||||
resolver.Resolver, error) {
|
||||
hosts := strings.FieldsFunc(target.Authority, func(r rune) bool {
|
||||
return r == EndpointSepChar
|
||||
})
|
||||
sub, err := discov.NewSubscriber(hosts, target.Endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
update := func() {
|
||||
var addrs []resolver.Address
|
||||
for _, val := range subset(sub.Values(), subsetSize) {
|
||||
addrs = append(addrs, resolver.Address{
|
||||
Addr: val,
|
||||
})
|
||||
}
|
||||
cc.UpdateState(resolver.State{
|
||||
Addresses: addrs,
|
||||
})
|
||||
}
|
||||
sub.AddListener(update)
|
||||
update()
|
||||
|
||||
return &nopResolver{cc: cc}, nil
|
||||
}
|
||||
|
||||
func (d *discovBuilder) Scheme() string {
|
||||
return DiscovScheme
|
||||
}
|
||||
35
zrpc/internal/resolver/resolver.go
Normal file
35
zrpc/internal/resolver/resolver.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
const (
|
||||
DirectScheme = "direct"
|
||||
DiscovScheme = "discov"
|
||||
EndpointSepChar = ','
|
||||
subsetSize = 32
|
||||
)
|
||||
|
||||
var (
|
||||
EndpointSep = fmt.Sprintf("%c", EndpointSepChar)
|
||||
dirBuilder directBuilder
|
||||
disBuilder discovBuilder
|
||||
)
|
||||
|
||||
func RegisterResolver() {
|
||||
resolver.Register(&dirBuilder)
|
||||
resolver.Register(&disBuilder)
|
||||
}
|
||||
|
||||
type nopResolver struct {
|
||||
cc resolver.ClientConn
|
||||
}
|
||||
|
||||
func (r *nopResolver) Close() {
|
||||
}
|
||||
|
||||
func (r *nopResolver) ResolveNow(options resolver.ResolveNowOptions) {
|
||||
}
|
||||
36
zrpc/internal/resolver/resolver_test.go
Normal file
36
zrpc/internal/resolver/resolver_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/serviceconfig"
|
||||
)
|
||||
|
||||
func TestNopResolver(t *testing.T) {
|
||||
// make sure ResolveNow & Close don't panic
|
||||
var r nopResolver
|
||||
r.ResolveNow(resolver.ResolveNowOptions{})
|
||||
r.Close()
|
||||
}
|
||||
|
||||
type mockedClientConn struct {
|
||||
state resolver.State
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) UpdateState(state resolver.State) {
|
||||
m.state = state
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) ReportError(err error) {
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) NewAddress(addresses []resolver.Address) {
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) NewServiceConfig(serviceConfig string) {
|
||||
}
|
||||
|
||||
func (m *mockedClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult {
|
||||
return nil
|
||||
}
|
||||
14
zrpc/internal/resolver/subset.go
Normal file
14
zrpc/internal/resolver/subset.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package resolver
|
||||
|
||||
import "math/rand"
|
||||
|
||||
func subset(set []string, sub int) []string {
|
||||
rand.Shuffle(len(set), func(i, j int) {
|
||||
set[i], set[j] = set[j], set[i]
|
||||
})
|
||||
if len(set) <= sub {
|
||||
return set
|
||||
} else {
|
||||
return set[:sub]
|
||||
}
|
||||
}
|
||||
53
zrpc/internal/resolver/subset_test.go
Normal file
53
zrpc/internal/resolver/subset_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/mathx"
|
||||
)
|
||||
|
||||
func TestSubset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
set int
|
||||
sub int
|
||||
}{
|
||||
{
|
||||
name: "more vals to subset",
|
||||
set: 100,
|
||||
sub: 36,
|
||||
},
|
||||
{
|
||||
name: "less vals to subset",
|
||||
set: 100,
|
||||
sub: 200,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var vals []string
|
||||
for i := 0; i < test.set; i++ {
|
||||
vals = append(vals, strconv.Itoa(i))
|
||||
}
|
||||
|
||||
m := make(map[interface{}]int)
|
||||
for i := 0; i < 1000; i++ {
|
||||
set := subset(append([]string(nil), vals...), test.sub)
|
||||
if test.sub < test.set {
|
||||
assert.Equal(t, test.sub, len(set))
|
||||
} else {
|
||||
assert.Equal(t, test.set, len(set))
|
||||
}
|
||||
|
||||
for _, val := range set {
|
||||
m[val]++
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, mathx.CalcEntropy(m) > 0.95)
|
||||
})
|
||||
}
|
||||
}
|
||||
73
zrpc/internal/rpclogger.go
Normal file
73
zrpc/internal/rpclogger.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
)
|
||||
|
||||
// because grpclog.errorLog is not exported, we need to define our own.
|
||||
const errorLevel = 2
|
||||
|
||||
var once sync.Once
|
||||
|
||||
type Logger struct{}
|
||||
|
||||
func InitLogger() {
|
||||
once.Do(func() {
|
||||
grpclog.SetLoggerV2(new(Logger))
|
||||
})
|
||||
}
|
||||
|
||||
func (l *Logger) Error(args ...interface{}) {
|
||||
logx.Error(args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
logx.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Errorln(args ...interface{}) {
|
||||
logx.Error(args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(args ...interface{}) {
|
||||
logx.Error(args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Fatalf(format string, args ...interface{}) {
|
||||
logx.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Fatalln(args ...interface{}) {
|
||||
logx.Error(args...)
|
||||
}
|
||||
|
||||
func (l *Logger) Info(args ...interface{}) {
|
||||
// ignore builtin grpc info
|
||||
}
|
||||
|
||||
func (l *Logger) Infoln(args ...interface{}) {
|
||||
// ignore builtin grpc info
|
||||
}
|
||||
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
// ignore builtin grpc info
|
||||
}
|
||||
|
||||
func (l *Logger) V(v int) bool {
|
||||
return v >= errorLevel
|
||||
}
|
||||
|
||||
func (l *Logger) Warning(args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
|
||||
func (l *Logger) Warningln(args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
|
||||
func (l *Logger) Warningf(format string, args ...interface{}) {
|
||||
// ignore builtin grpc warning
|
||||
}
|
||||
29
zrpc/internal/rpcpubserver.go
Normal file
29
zrpc/internal/rpcpubserver.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package internal
|
||||
|
||||
import "github.com/tal-tech/go-zero/core/discov"
|
||||
|
||||
func NewRpcPubServer(etcdEndpoints []string, etcdKey, listenOn string, opts ...ServerOption) (Server, error) {
|
||||
registerEtcd := func() error {
|
||||
pubClient := discov.NewPublisher(etcdEndpoints, etcdKey, listenOn)
|
||||
return pubClient.KeepAlive()
|
||||
}
|
||||
server := keepAliveServer{
|
||||
registerEtcd: registerEtcd,
|
||||
Server: NewRpcServer(listenOn, opts...),
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type keepAliveServer struct {
|
||||
registerEtcd func() error
|
||||
Server
|
||||
}
|
||||
|
||||
func (ags keepAliveServer) Start(fn RegisterFn) error {
|
||||
if err := ags.registerEtcd(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ags.Server.Start(fn)
|
||||
}
|
||||
81
zrpc/internal/rpcserver.go
Normal file
81
zrpc/internal/rpcserver.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/proc"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/serverinterceptors"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type (
|
||||
ServerOption func(options *rpcServerOptions)
|
||||
|
||||
rpcServerOptions struct {
|
||||
metrics *stat.Metrics
|
||||
}
|
||||
|
||||
rpcServer struct {
|
||||
*baseRpcServer
|
||||
}
|
||||
)
|
||||
|
||||
func init() {
|
||||
InitLogger()
|
||||
}
|
||||
|
||||
func NewRpcServer(address string, opts ...ServerOption) Server {
|
||||
var options rpcServerOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
if options.metrics == nil {
|
||||
options.metrics = stat.NewMetrics(address)
|
||||
}
|
||||
|
||||
return &rpcServer{
|
||||
baseRpcServer: newBaseRpcServer(address, options.metrics),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *rpcServer) SetName(name string) {
|
||||
s.baseRpcServer.SetName(name)
|
||||
}
|
||||
|
||||
func (s *rpcServer) Start(register RegisterFn) error {
|
||||
lis, err := net.Listen("tcp", s.address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unaryInterceptors := []grpc.UnaryServerInterceptor{
|
||||
serverinterceptors.UnaryCrashInterceptor(),
|
||||
serverinterceptors.UnaryStatInterceptor(s.metrics),
|
||||
serverinterceptors.UnaryPromMetricInterceptor(),
|
||||
}
|
||||
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
|
||||
streamInterceptors := []grpc.StreamServerInterceptor{
|
||||
serverinterceptors.StreamCrashInterceptor,
|
||||
}
|
||||
streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
|
||||
options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...),
|
||||
WithStreamServerInterceptors(streamInterceptors...))
|
||||
server := grpc.NewServer(options...)
|
||||
register(server)
|
||||
// we need to make sure all others are wrapped up
|
||||
// so we do graceful stop at shutdown phase instead of wrap up phase
|
||||
shutdownCalled := proc.AddShutdownListener(func() {
|
||||
server.GracefulStop()
|
||||
})
|
||||
err = server.Serve(lis)
|
||||
shutdownCalled()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func WithMetrics(metrics *stat.Metrics) ServerOption {
|
||||
return func(options *rpcServerOptions) {
|
||||
options.metrics = metrics
|
||||
}
|
||||
}
|
||||
16
zrpc/internal/rpcserver_test.go
Normal file
16
zrpc/internal/rpcserver_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
)
|
||||
|
||||
func TestWithMetrics(t *testing.T) {
|
||||
metrics := stat.NewMetrics("foo")
|
||||
opt := WithMetrics(metrics)
|
||||
var options rpcServerOptions
|
||||
opt(&options)
|
||||
assert.Equal(t, metrics, options.metrics)
|
||||
}
|
||||
49
zrpc/internal/server.go
Normal file
49
zrpc/internal/server.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type (
|
||||
RegisterFn func(*grpc.Server)
|
||||
|
||||
Server interface {
|
||||
AddOptions(options ...grpc.ServerOption)
|
||||
AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor)
|
||||
AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor)
|
||||
SetName(string)
|
||||
Start(register RegisterFn) error
|
||||
}
|
||||
|
||||
baseRpcServer struct {
|
||||
address string
|
||||
metrics *stat.Metrics
|
||||
options []grpc.ServerOption
|
||||
streamInterceptors []grpc.StreamServerInterceptor
|
||||
unaryInterceptors []grpc.UnaryServerInterceptor
|
||||
}
|
||||
)
|
||||
|
||||
func newBaseRpcServer(address string, metrics *stat.Metrics) *baseRpcServer {
|
||||
return &baseRpcServer{
|
||||
address: address,
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *baseRpcServer) AddOptions(options ...grpc.ServerOption) {
|
||||
s.options = append(s.options, options...)
|
||||
}
|
||||
|
||||
func (s *baseRpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) {
|
||||
s.streamInterceptors = append(s.streamInterceptors, interceptors...)
|
||||
}
|
||||
|
||||
func (s *baseRpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) {
|
||||
s.unaryInterceptors = append(s.unaryInterceptors, interceptors...)
|
||||
}
|
||||
|
||||
func (s *baseRpcServer) SetName(name string) {
|
||||
s.metrics.SetName(name)
|
||||
}
|
||||
53
zrpc/internal/server_test.go
Normal file
53
zrpc/internal/server_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestBaseRpcServer_AddOptions(t *testing.T) {
|
||||
metrics := stat.NewMetrics("foo")
|
||||
server := newBaseRpcServer("foo", metrics)
|
||||
server.SetName("bar")
|
||||
var opt grpc.EmptyServerOption
|
||||
server.AddOptions(opt)
|
||||
assert.Contains(t, server.options, opt)
|
||||
}
|
||||
|
||||
func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
|
||||
metrics := stat.NewMetrics("foo")
|
||||
server := newBaseRpcServer("foo", metrics)
|
||||
server.SetName("bar")
|
||||
var vals []int
|
||||
f := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
vals = append(vals, 1)
|
||||
return nil
|
||||
}
|
||||
server.AddStreamInterceptors(f)
|
||||
for _, each := range server.streamInterceptors {
|
||||
assert.Nil(t, each(nil, nil, nil, nil))
|
||||
}
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
|
||||
func TestBaseRpcServer_AddUnaryInterceptors(t *testing.T) {
|
||||
metrics := stat.NewMetrics("foo")
|
||||
server := newBaseRpcServer("foo", metrics)
|
||||
server.SetName("bar")
|
||||
var vals []int
|
||||
f := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
resp interface{}, err error) {
|
||||
vals = append(vals, 1)
|
||||
return nil, nil
|
||||
}
|
||||
server.AddUnaryInterceptors(f)
|
||||
for _, each := range server.unaryInterceptors {
|
||||
_, err := each(context.Background(), nil, nil, nil)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
assert.ElementsMatch(t, []int{1}, vals)
|
||||
}
|
||||
30
zrpc/internal/serverinterceptors/authinterceptor.go
Normal file
30
zrpc/internal/serverinterceptors/authinterceptor.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/auth"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) error {
|
||||
if err := authenticator.Authenticate(stream.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return handler(srv, stream)
|
||||
}
|
||||
}
|
||||
|
||||
func UnaryAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if err := authenticator.Authenticate(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
200
zrpc/internal/serverinterceptors/authinterceptor_test.go
Normal file
200
zrpc/internal/serverinterceptors/authinterceptor_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stores/redis"
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/auth"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestStreamAuthorizeInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
app string
|
||||
token string
|
||||
strict bool
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with token",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
strict: true,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with error token",
|
||||
app: "foo",
|
||||
token: "error",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
r := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r.Start())
|
||||
defer r.Close()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
store := redis.NewRedis(r.Addr(), redis.NodeType)
|
||||
if len(test.app) > 0 {
|
||||
assert.Nil(t, store.Hset("apps", test.app, test.token))
|
||||
defer store.Hdel("apps", test.app)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
|
||||
assert.Nil(t, err)
|
||||
interceptor := StreamAuthorizeInterceptor(authenticator)
|
||||
md := metadata.New(map[string]string{
|
||||
"app": "foo",
|
||||
"token": "bar",
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
stream := mockedStream{ctx: ctx}
|
||||
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
|
||||
return nil
|
||||
})
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnaryAuthorizeInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
app string
|
||||
token string
|
||||
strict bool
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "strict=false",
|
||||
strict: false,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with token",
|
||||
app: "foo",
|
||||
token: "bar",
|
||||
strict: true,
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "strict=true,with error token",
|
||||
app: "foo",
|
||||
token: "error",
|
||||
strict: true,
|
||||
hasError: true,
|
||||
},
|
||||
}
|
||||
|
||||
r := miniredis.NewMiniRedis()
|
||||
assert.Nil(t, r.Start())
|
||||
defer r.Close()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
store := redis.NewRedis(r.Addr(), redis.NodeType)
|
||||
if len(test.app) > 0 {
|
||||
assert.Nil(t, store.Hset("apps", test.app, test.token))
|
||||
defer store.Hdel("apps", test.app)
|
||||
}
|
||||
|
||||
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
|
||||
assert.Nil(t, err)
|
||||
interceptor := UnaryAuthorizeInterceptor(authenticator)
|
||||
md := metadata.New(map[string]string{
|
||||
"app": "foo",
|
||||
"token": "bar",
|
||||
})
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
} else {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
if test.strict {
|
||||
_, err = interceptor(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
var md metadata.MD
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
md = metadata.New(map[string]string{
|
||||
"app": "",
|
||||
"token": "",
|
||||
})
|
||||
ctx = metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err = interceptor(ctx, nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockedStream struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m mockedStream) SetHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) SendHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) SetTrailer(md metadata.MD) {
|
||||
}
|
||||
|
||||
func (m mockedStream) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m mockedStream) SendMsg(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockedStream) RecvMsg(v interface{}) error {
|
||||
return nil
|
||||
}
|
||||
42
zrpc/internal/serverinterceptors/crashinterceptor.go
Normal file
42
zrpc/internal/serverinterceptors/crashinterceptor.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func StreamCrashInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler) (err error) {
|
||||
defer handleCrash(func(r interface{}) {
|
||||
err = toPanicError(r)
|
||||
})
|
||||
|
||||
return handler(srv, stream)
|
||||
}
|
||||
|
||||
func UnaryCrashInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
defer handleCrash(func(r interface{}) {
|
||||
err = toPanicError(r)
|
||||
})
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func handleCrash(handler func(interface{})) {
|
||||
if r := recover(); r != nil {
|
||||
handler(r)
|
||||
}
|
||||
}
|
||||
|
||||
func toPanicError(r interface{}) error {
|
||||
logx.Errorf("%+v %s", r, debug.Stack())
|
||||
return status.Errorf(codes.Internal, "panic: %v", r)
|
||||
}
|
||||
31
zrpc/internal/serverinterceptors/crashinterceptor_test.go
Normal file
31
zrpc/internal/serverinterceptors/crashinterceptor_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func init() {
|
||||
logx.Disable()
|
||||
}
|
||||
|
||||
func TestStreamCrashInterceptor(t *testing.T) {
|
||||
err := StreamCrashInterceptor(nil, nil, nil, func(
|
||||
srv interface{}, stream grpc.ServerStream) error {
|
||||
panic("mock panic")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryCrashInterceptor(t *testing.T) {
|
||||
interceptor := UnaryCrashInterceptor()
|
||||
_, err := interceptor(context.Background(), nil, nil,
|
||||
func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
panic("mock panic")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
44
zrpc/internal/serverinterceptors/prommetricinterceptor.go
Normal file
44
zrpc/internal/serverinterceptors/prommetricinterceptor.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/metric"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const serverNamespace = "rpc_server"
|
||||
|
||||
var (
|
||||
metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
|
||||
Namespace: serverNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "duration_ms",
|
||||
Help: "rpc server requests duration(ms).",
|
||||
Labels: []string{"method"},
|
||||
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
|
||||
})
|
||||
|
||||
metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
|
||||
Namespace: serverNamespace,
|
||||
Subsystem: "requests",
|
||||
Name: "code_total",
|
||||
Help: "rpc server requests code count.",
|
||||
Labels: []string{"method", "code"},
|
||||
})
|
||||
)
|
||||
|
||||
func UnaryPromMetricInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
|
||||
interface{}, error) {
|
||||
startTime := timex.Now()
|
||||
resp, err := handler(ctx, req)
|
||||
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod)
|
||||
metricServerReqCodeTotal.Inc(info.FullMethod, strconv.Itoa(int(status.Code(err))))
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnaryPromMetricInterceptor(t *testing.T) {
|
||||
interceptor := UnaryPromMetricInterceptor()
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
52
zrpc/internal/serverinterceptors/sheddinginterceptor.go
Normal file
52
zrpc/internal/serverinterceptors/sheddinginterceptor.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/load"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const serviceType = "rpc"
|
||||
|
||||
var (
|
||||
sheddingStat *load.SheddingStat
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.UnaryServerInterceptor {
|
||||
ensureSheddingStat()
|
||||
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (val interface{}, err error) {
|
||||
sheddingStat.IncrementTotal()
|
||||
var promise load.Promise
|
||||
promise, err = shedder.Allow()
|
||||
if err != nil {
|
||||
metrics.AddDrop()
|
||||
sheddingStat.IncrementDrop()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err == context.DeadlineExceeded {
|
||||
promise.Fail()
|
||||
} else {
|
||||
sheddingStat.IncrementPass()
|
||||
promise.Pass()
|
||||
}
|
||||
}()
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureSheddingStat() {
|
||||
lock.Lock()
|
||||
if sheddingStat == nil {
|
||||
sheddingStat = load.NewSheddingStat(serviceType)
|
||||
}
|
||||
lock.Unlock()
|
||||
}
|
||||
77
zrpc/internal/serverinterceptors/sheddinginterceptor_test.go
Normal file
77
zrpc/internal/serverinterceptors/sheddinginterceptor_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/load"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnarySheddingInterceptor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allow bool
|
||||
handleErr error
|
||||
expect error
|
||||
}{
|
||||
{
|
||||
name: "allow",
|
||||
allow: true,
|
||||
handleErr: nil,
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
name: "allow",
|
||||
allow: true,
|
||||
handleErr: context.DeadlineExceeded,
|
||||
expect: context.DeadlineExceeded,
|
||||
},
|
||||
{
|
||||
name: "reject",
|
||||
allow: false,
|
||||
handleErr: nil,
|
||||
expect: load.ErrServiceOverloaded,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
shedder := mockedShedder{allow: test.allow}
|
||||
metrics := stat.NewMetrics("mock")
|
||||
interceptor := UnarySheddingInterceptor(shedder, metrics)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, test.handleErr
|
||||
})
|
||||
assert.Equal(t, test.expect, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockedShedder struct {
|
||||
allow bool
|
||||
}
|
||||
|
||||
func (m mockedShedder) Allow() (load.Promise, error) {
|
||||
if m.allow {
|
||||
return mockedPromise{}, nil
|
||||
} else {
|
||||
return nil, load.ErrServiceOverloaded
|
||||
}
|
||||
}
|
||||
|
||||
type mockedPromise struct {
|
||||
}
|
||||
|
||||
func (m mockedPromise) Pass() {
|
||||
}
|
||||
|
||||
func (m mockedPromise) Fail() {
|
||||
}
|
||||
51
zrpc/internal/serverinterceptors/statinterceptor.go
Normal file
51
zrpc/internal/serverinterceptors/statinterceptor.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/logx"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"github.com/tal-tech/go-zero/core/timex"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const serverSlowThreshold = time.Millisecond * 500
|
||||
|
||||
func UnaryStatInterceptor(metrics *stat.Metrics) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
defer handleCrash(func(r interface{}) {
|
||||
err = toPanicError(r)
|
||||
})
|
||||
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
metrics.Add(stat.Task{
|
||||
Duration: duration,
|
||||
})
|
||||
logDuration(ctx, info.FullMethod, req, duration)
|
||||
}()
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func logDuration(ctx context.Context, method string, req interface{}, duration time.Duration) {
|
||||
var addr string
|
||||
client, ok := peer.FromContext(ctx)
|
||||
if ok {
|
||||
addr = client.Addr.String()
|
||||
}
|
||||
content, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
logx.Errorf("%s - %s", addr, err.Error())
|
||||
} else if duration > serverSlowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", addr, method, string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content))
|
||||
}
|
||||
}
|
||||
32
zrpc/internal/serverinterceptors/statinterceptor_test.go
Normal file
32
zrpc/internal/serverinterceptors/statinterceptor_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/stat"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnaryStatInterceptor(t *testing.T) {
|
||||
metrics := stat.NewMetrics("mock")
|
||||
interceptor := UnaryStatInterceptor(metrics)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryStatInterceptor_crash(t *testing.T) {
|
||||
metrics := stat.NewMetrics("mock")
|
||||
interceptor := UnaryStatInterceptor(metrics)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
panic("error")
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
}
|
||||
18
zrpc/internal/serverinterceptors/timeoutinterceptor.go
Normal file
18
zrpc/internal/serverinterceptors/timeoutinterceptor.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/contextx"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
||||
defer cancel()
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
41
zrpc/internal/serverinterceptors/timeoutinterceptor_test.go
Normal file
41
zrpc/internal/serverinterceptors/timeoutinterceptor_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestUnaryTimeoutInterceptor(t *testing.T) {
|
||||
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := UnaryTimeoutInterceptor(timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
tm, ok := ctx.Deadline()
|
||||
assert.True(t, ok)
|
||||
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
28
zrpc/internal/serverinterceptors/tracinginterceptor.go
Normal file
28
zrpc/internal/serverinterceptors/tracinginterceptor.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func UnaryTracingInterceptor(serviceName string) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
carrier, err := trace.Extract(trace.GrpcFormat, md)
|
||||
if err != nil {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
ctx, span := trace.StartServerSpan(ctx, carrier, serviceName, info.FullMethod)
|
||||
defer span.Finish()
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
48
zrpc/internal/serverinterceptors/tracinginterceptor_test.go
Normal file
48
zrpc/internal/serverinterceptors/tracinginterceptor_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/core/trace/tracespec"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func TestUnaryTracingInterceptor(t *testing.T) {
|
||||
interceptor := UnaryTracingInterceptor("foo")
|
||||
var run int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&run, 1)
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
|
||||
}
|
||||
|
||||
func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
|
||||
interceptor := UnaryTracingInterceptor("foo")
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
var md metadata.MD
|
||||
ctx := metadata.NewIncomingContext(context.Background(), md)
|
||||
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).TraceId()) > 0)
|
||||
assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).SpanId()) > 0)
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
18
zrpc/internal/target.go
Normal file
18
zrpc/internal/target.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/zrpc/internal/resolver"
|
||||
)
|
||||
|
||||
func BuildDirectTarget(endpoints []string) string {
|
||||
return fmt.Sprintf("%s:///%s", resolver.DirectScheme,
|
||||
strings.Join(endpoints, resolver.EndpointSep))
|
||||
}
|
||||
|
||||
func BuildDiscovTarget(endpoints []string, key string) string {
|
||||
return fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme,
|
||||
strings.Join(endpoints, resolver.EndpointSep), key)
|
||||
}
|
||||
17
zrpc/internal/target_test.go
Normal file
17
zrpc/internal/target_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildDirectTarget(t *testing.T) {
|
||||
target := BuildDirectTarget([]string{"localhost:123", "localhost:456"})
|
||||
assert.Equal(t, "direct:///localhost:123,localhost:456", target)
|
||||
}
|
||||
|
||||
func TestBuildDiscovTarget(t *testing.T) {
|
||||
target := BuildDiscovTarget([]string{"localhost:123", "localhost:456"}, "foo")
|
||||
assert.Equal(t, "discov://localhost:123,localhost:456/foo", target)
|
||||
}
|
||||
Reference in New Issue
Block a user