rename rpcx to zrpc

This commit is contained in:
kevin
2020-09-18 11:41:52 +08:00
parent 26e16107ce
commit 0b1ee79d3a
110 changed files with 154 additions and 154 deletions

View File

@@ -0,0 +1,73 @@
package auth
import (
"context"
"time"
"github.com/tal-tech/go-zero/core/collection"
"github.com/tal-tech/go-zero/core/stores/redis"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
const defaultExpiration = 5 * time.Minute
type Authenticator struct {
store *redis.Redis
key string
cache *collection.Cache
strict bool
}
func NewAuthenticator(store *redis.Redis, key string, strict bool) (*Authenticator, error) {
cache, err := collection.NewCache(defaultExpiration)
if err != nil {
return nil, err
}
return &Authenticator{
store: store,
key: key,
cache: cache,
strict: strict,
}, nil
}
func (a *Authenticator) Authenticate(ctx context.Context) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Error(codes.Unauthenticated, missingMetadata)
}
apps, tokens := md[appKey], md[tokenKey]
if len(apps) == 0 || len(tokens) == 0 {
return status.Error(codes.Unauthenticated, missingMetadata)
}
app, token := apps[0], tokens[0]
if len(app) == 0 || len(token) == 0 {
return status.Error(codes.Unauthenticated, missingMetadata)
}
return a.validate(app, token)
}
func (a *Authenticator) validate(app, token string) error {
expect, err := a.cache.Take(app, func() (interface{}, error) {
return a.store.Hget(a.key, app)
})
if err != nil {
if a.strict {
return status.Error(codes.Internal, err.Error())
} else {
return nil
}
}
if token != expect {
return status.Error(codes.Unauthenticated, accessDenied)
}
return nil
}

View File

@@ -0,0 +1,47 @@
package auth
import (
"context"
"google.golang.org/grpc/metadata"
)
type Credential struct {
App string
Token string
}
func (c *Credential) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
return map[string]string{
appKey: c.App,
tokenKey: c.Token,
}, nil
}
func (c *Credential) RequireTransportSecurity() bool {
return false
}
func ParseCredential(ctx context.Context) Credential {
var credential Credential
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return credential
}
apps, tokens := md[appKey], md[tokenKey]
if len(apps) == 0 || len(tokens) == 0 {
return credential
}
app, token := apps[0], tokens[0]
if len(app) == 0 || len(token) == 0 {
return credential
}
credential.App = app
credential.Token = token
return credential
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)
func TestParseCredential(t *testing.T) {
tests := []struct {
name string
withNil bool
withEmptyMd bool
app string
token string
}{
{
name: "nil",
withNil: true,
},
{
name: "empty md",
withEmptyMd: true,
},
{
name: "empty",
},
{
name: "valid",
app: "foo",
token: "bar",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
var ctx context.Context
if test.withNil {
ctx = context.Background()
} else if test.withEmptyMd {
ctx = metadata.NewIncomingContext(context.Background(), metadata.MD{})
} else {
md := metadata.New(map[string]string{
"app": test.app,
"token": test.token,
})
ctx = metadata.NewIncomingContext(context.Background(), md)
}
cred := ParseCredential(ctx)
assert.False(t, cred.RequireTransportSecurity())
m, err := cred.GetRequestMetadata(context.Background())
assert.Nil(t, err)
assert.Equal(t, test.app, m[appKey])
assert.Equal(t, test.token, m[tokenKey])
})
}
}

View File

@@ -0,0 +1,9 @@
package auth
const (
appKey = "app"
tokenKey = "token"
accessDenied = "access denied"
missingMetadata = "app/token required"
)

View File

@@ -0,0 +1,202 @@
package p2c
import (
"context"
"fmt"
"math"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/syncx"
"github.com/tal-tech/go-zero/core/timex"
"github.com/tal-tech/go-zero/zrpc/internal/codes"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/resolver"
)
const (
Name = "p2c_ewma"
decayTime = int64(time.Second * 10) // default value from finagle
forcePick = int64(time.Second)
initSuccess = 1000
throttleSuccess = initSuccess / 2
penalty = int64(math.MaxInt32)
pickTimes = 3
logInterval = time.Minute
)
func init() {
balancer.Register(newBuilder())
}
type p2cPickerBuilder struct {
}
func newBuilder() balancer.Builder {
return base.NewBalancerBuilder(Name, new(p2cPickerBuilder))
}
func (b *p2cPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
if len(readySCs) == 0 {
return base.NewErrPicker(balancer.ErrNoSubConnAvailable)
}
var conns []*subConn
for addr, conn := range readySCs {
conns = append(conns, &subConn{
addr: addr,
conn: conn,
success: initSuccess,
})
}
return &p2cPicker{
conns: conns,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
stamp: syncx.NewAtomicDuration(),
}
}
type p2cPicker struct {
conns []*subConn
r *rand.Rand
stamp *syncx.AtomicDuration
lock sync.Mutex
}
func (p *p2cPicker) Pick(ctx context.Context, info balancer.PickInfo) (
conn balancer.SubConn, done func(balancer.DoneInfo), err error) {
p.lock.Lock()
defer p.lock.Unlock()
var chosen *subConn
switch len(p.conns) {
case 0:
return nil, nil, balancer.ErrNoSubConnAvailable
case 1:
chosen = p.choose(p.conns[0], nil)
case 2:
chosen = p.choose(p.conns[0], p.conns[1])
default:
var node1, node2 *subConn
for i := 0; i < pickTimes; i++ {
a := p.r.Intn(len(p.conns))
b := p.r.Intn(len(p.conns) - 1)
if b >= a {
b++
}
node1 = p.conns[a]
node2 = p.conns[b]
if node1.healthy() && node2.healthy() {
break
}
}
chosen = p.choose(node1, node2)
}
atomic.AddInt64(&chosen.inflight, 1)
atomic.AddInt64(&chosen.requests, 1)
return chosen.conn, p.buildDoneFunc(chosen), nil
}
func (p *p2cPicker) buildDoneFunc(c *subConn) func(info balancer.DoneInfo) {
start := int64(timex.Now())
return func(info balancer.DoneInfo) {
atomic.AddInt64(&c.inflight, -1)
now := timex.Now()
last := atomic.SwapInt64(&c.last, int64(now))
td := int64(now) - last
if td < 0 {
td = 0
}
w := math.Exp(float64(-td) / float64(decayTime))
lag := int64(now) - start
if lag < 0 {
lag = 0
}
olag := atomic.LoadUint64(&c.lag)
if olag == 0 {
w = 0
}
atomic.StoreUint64(&c.lag, uint64(float64(olag)*w+float64(lag)*(1-w)))
success := initSuccess
if info.Err != nil && !codes.Acceptable(info.Err) {
success = 0
}
osucc := atomic.LoadUint64(&c.success)
atomic.StoreUint64(&c.success, uint64(float64(osucc)*w+float64(success)*(1-w)))
stamp := p.stamp.Load()
if now-stamp >= logInterval {
if p.stamp.CompareAndSwap(stamp, now) {
p.logStats()
}
}
}
}
func (p *p2cPicker) choose(c1, c2 *subConn) *subConn {
start := int64(timex.Now())
if c2 == nil {
atomic.StoreInt64(&c1.pick, start)
return c1
}
if c1.load() > c2.load() {
c1, c2 = c2, c1
}
pick := atomic.LoadInt64(&c2.pick)
if start-pick > forcePick && atomic.CompareAndSwapInt64(&c2.pick, pick, start) {
return c2
} else {
atomic.StoreInt64(&c1.pick, start)
return c1
}
}
func (p *p2cPicker) logStats() {
var stats []string
p.lock.Lock()
defer p.lock.Unlock()
for _, conn := range p.conns {
stats = append(stats, fmt.Sprintf("conn: %s, load: %d, reqs: %d",
conn.addr.Addr, conn.load(), atomic.SwapInt64(&conn.requests, 0)))
}
logx.Statf("p2c - %s", strings.Join(stats, "; "))
}
type subConn struct {
addr resolver.Address
conn balancer.SubConn
lag uint64
inflight int64
success uint64
requests int64
last int64
pick int64
}
func (c *subConn) healthy() bool {
return atomic.LoadUint64(&c.success) > throttleSuccess
}
func (c *subConn) load() int64 {
// plus one to avoid multiply zero
lag := int64(math.Sqrt(float64(atomic.LoadUint64(&c.lag) + 1)))
load := lag * (atomic.LoadInt64(&c.inflight) + 1)
if load == 0 {
return penalty
} else {
return load
}
}

View File

@@ -0,0 +1,113 @@
package p2c
import (
"context"
"fmt"
"runtime"
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/mathx"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
func init() {
logx.Disable()
}
func TestP2cPicker_PickNil(t *testing.T) {
builder := new(p2cPickerBuilder)
picker := builder.Build(nil)
_, _, err := picker.Pick(context.Background(), balancer.PickInfo{
FullMethodName: "/",
Ctx: context.Background(),
})
assert.NotNil(t, err)
}
func TestP2cPicker_Pick(t *testing.T) {
tests := []struct {
name string
candidates int
threshold float64
}{
{
name: "single",
candidates: 1,
threshold: 0.9,
},
{
name: "two",
candidates: 2,
threshold: 0.5,
},
{
name: "multiple",
candidates: 100,
threshold: 0.95,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
const total = 10000
builder := new(p2cPickerBuilder)
ready := make(map[resolver.Address]balancer.SubConn)
for i := 0; i < test.candidates; i++ {
ready[resolver.Address{
Addr: strconv.Itoa(i),
}] = new(mockClientConn)
}
picker := builder.Build(ready)
var wg sync.WaitGroup
wg.Add(total)
for i := 0; i < total; i++ {
_, done, err := picker.Pick(context.Background(), balancer.PickInfo{
FullMethodName: "/",
Ctx: context.Background(),
})
assert.Nil(t, err)
if i%100 == 0 {
err = status.Error(codes.DeadlineExceeded, "deadline")
}
go func() {
runtime.Gosched()
done(balancer.DoneInfo{
Err: err,
})
wg.Done()
}()
}
wg.Wait()
dist := make(map[interface{}]int)
conns := picker.(*p2cPicker).conns
for _, conn := range conns {
dist[conn.addr.Addr] = int(conn.requests)
}
entropy := mathx.CalcEntropy(dist)
assert.True(t, entropy > test.threshold, fmt.Sprintf("entropy is %f, less than %f",
entropy, test.threshold))
})
}
}
type mockClientConn struct {
}
func (m mockClientConn) UpdateAddresses(addresses []resolver.Address) {
}
func (m mockClientConn) Connect() {
}

View File

@@ -0,0 +1,83 @@
package internal
import (
"context"
"google.golang.org/grpc"
)
func WithStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.DialOption {
return grpc.WithStreamInterceptor(chainStreamClientInterceptors(interceptors...))
}
func WithUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.DialOption {
return grpc.WithUnaryInterceptor(chainUnaryClientInterceptors(interceptors...))
}
func chainStreamClientInterceptors(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor {
switch len(interceptors) {
case 0:
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return streamer(ctx, desc, cc, method, opts...)
}
case 1:
return interceptors[0]
default:
last := len(interceptors) - 1
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
var chainStreamer grpc.Streamer
var current int
chainStreamer = func(curCtx context.Context, curDesc *grpc.StreamDesc, curCc *grpc.ClientConn,
curMethod string, curOpts ...grpc.CallOption) (grpc.ClientStream, error) {
if current == last {
return streamer(curCtx, curDesc, curCc, curMethod, curOpts...)
}
current++
clientStream, err := interceptors[current](curCtx, curDesc, curCc, curMethod, chainStreamer, curOpts...)
current--
return clientStream, err
}
return interceptors[0](ctx, desc, cc, method, chainStreamer, opts...)
}
}
}
func chainUnaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
switch len(interceptors) {
case 0:
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return invoker(ctx, method, req, reply, cc, opts...)
}
case 1:
return interceptors[0]
default:
last := len(interceptors) - 1
return func(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
var chainInvoker grpc.UnaryInvoker
var current int
chainInvoker = func(curCtx context.Context, curMethod string, curReq, curReply interface{},
curCc *grpc.ClientConn, curOpts ...grpc.CallOption) error {
if current == last {
return invoker(curCtx, curMethod, curReq, curReply, curCc, curOpts...)
}
current++
err := interceptors[current](curCtx, curMethod, curReq, curReply, curCc, chainInvoker, curOpts...)
current--
return err
}
return interceptors[0](ctx, method, req, reply, cc, chainInvoker, opts...)
}
}
}

View File

@@ -0,0 +1,123 @@
package internal
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestWithStreamClientInterceptors(t *testing.T) {
opts := WithStreamClientInterceptors()
assert.NotNil(t, opts)
}
func TestWithUnaryClientInterceptors(t *testing.T) {
opts := WithUnaryClientInterceptors()
assert.NotNil(t, opts)
}
func TestChainStreamClientInterceptors_zero(t *testing.T) {
var vals []int
interceptors := chainStreamClientInterceptors()
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
vals = append(vals, 1)
return nil, nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1}, vals)
}
func TestChainStreamClientInterceptors_one(t *testing.T) {
var vals []int
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
grpc.ClientStream, error) {
vals = append(vals, 1)
return streamer(ctx, desc, cc, method, opts...)
})
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
vals = append(vals, 2)
return nil, nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2}, vals)
}
func TestChainStreamClientInterceptors_more(t *testing.T) {
var vals []int
interceptors := chainStreamClientInterceptors(func(ctx context.Context, desc *grpc.StreamDesc,
cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (
grpc.ClientStream, error) {
vals = append(vals, 1)
return streamer(ctx, desc, cc, method, opts...)
}, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
vals = append(vals, 2)
return streamer(ctx, desc, cc, method, opts...)
})
_, err := interceptors(context.Background(), nil, new(grpc.ClientConn), "/foo",
func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
opts ...grpc.CallOption) (grpc.ClientStream, error) {
vals = append(vals, 3)
return nil, nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
}
func TestWithUnaryClientInterceptors_zero(t *testing.T) {
var vals []int
interceptors := chainUnaryClientInterceptors()
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
vals = append(vals, 1)
return nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1}, vals)
}
func TestWithUnaryClientInterceptors_one(t *testing.T) {
var vals []int
interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
vals = append(vals, 1)
return invoker(ctx, method, req, reply, cc, opts...)
})
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
vals = append(vals, 2)
return nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2}, vals)
}
func TestWithUnaryClientInterceptors_more(t *testing.T) {
var vals []int
interceptors := chainUnaryClientInterceptors(func(ctx context.Context, method string, req,
reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
vals = append(vals, 1)
return invoker(ctx, method, req, reply, cc, opts...)
}, func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
vals = append(vals, 2)
return invoker(ctx, method, req, reply, cc, opts...)
})
err := interceptors(context.Background(), "/foo", nil, nil, new(grpc.ClientConn),
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
vals = append(vals, 3)
return nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
}

View File

@@ -0,0 +1,81 @@
package internal
import (
"context"
"google.golang.org/grpc"
)
func WithStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption {
return grpc.StreamInterceptor(chainStreamServerInterceptors(interceptors...))
}
func WithUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption {
return grpc.UnaryInterceptor(chainUnaryServerInterceptors(interceptors...))
}
func chainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
switch len(interceptors) {
case 0:
return func(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
return handler(srv, stream)
}
case 1:
return interceptors[0]
default:
last := len(interceptors) - 1
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
var chainHandler grpc.StreamHandler
var current int
chainHandler = func(curSrv interface{}, curStream grpc.ServerStream) error {
if current == last {
return handler(curSrv, curStream)
}
current++
err := interceptors[current](curSrv, curStream, info, chainHandler)
current--
return err
}
return interceptors[0](srv, stream, info, chainHandler)
}
}
}
func chainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
switch len(interceptors) {
case 0:
return func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
interface{}, error) {
return handler(ctx, req)
}
case 1:
return interceptors[0]
default:
last := len(interceptors) - 1
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
interface{}, error) {
var chainHandler grpc.UnaryHandler
var current int
chainHandler = func(curCtx context.Context, curReq interface{}) (interface{}, error) {
if current == last {
return handler(curCtx, curReq)
}
current++
resp, err := interceptors[current](curCtx, curReq, info, chainHandler)
current--
return resp, err
}
return interceptors[0](ctx, req, info, chainHandler)
}
}
}

View File

@@ -0,0 +1,111 @@
package internal
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestWithStreamServerInterceptors(t *testing.T) {
opts := WithStreamServerInterceptors()
assert.NotNil(t, opts)
}
func TestWithUnaryServerInterceptors(t *testing.T) {
opts := WithUnaryServerInterceptors()
assert.NotNil(t, opts)
}
func TestChainStreamServerInterceptors_zero(t *testing.T) {
var vals []int
interceptors := chainStreamServerInterceptors()
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
vals = append(vals, 1)
return nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1}, vals)
}
func TestChainStreamServerInterceptors_one(t *testing.T) {
var vals []int
interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
vals = append(vals, 1)
return handler(srv, ss)
})
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
vals = append(vals, 2)
return nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2}, vals)
}
func TestChainStreamServerInterceptors_more(t *testing.T) {
var vals []int
interceptors := chainStreamServerInterceptors(func(srv interface{}, ss grpc.ServerStream,
info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
vals = append(vals, 1)
return handler(srv, ss)
}, func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
vals = append(vals, 2)
return handler(srv, ss)
})
err := interceptors(nil, nil, nil, func(srv interface{}, stream grpc.ServerStream) error {
vals = append(vals, 3)
return nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
}
func TestChainUnaryServerInterceptors_zero(t *testing.T) {
var vals []int
interceptors := chainUnaryServerInterceptors()
_, err := interceptors(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
vals = append(vals, 1)
return nil, nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1}, vals)
}
func TestChainUnaryServerInterceptors_one(t *testing.T) {
var vals []int
interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
vals = append(vals, 1)
return handler(ctx, req)
})
_, err := interceptors(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
vals = append(vals, 2)
return nil, nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2}, vals)
}
func TestChainUnaryServerInterceptors_more(t *testing.T) {
var vals []int
interceptors := chainUnaryServerInterceptors(func(ctx context.Context, req interface{},
info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
vals = append(vals, 1)
return handler(ctx, req)
}, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {
vals = append(vals, 2)
return handler(ctx, req)
})
_, err := interceptors(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
vals = append(vals, 3)
return nil, nil
})
assert.Nil(t, err)
assert.ElementsMatch(t, []int{1, 2, 3}, vals)
}

90
zrpc/internal/client.go Normal file
View File

@@ -0,0 +1,90 @@
package internal
import (
"context"
"fmt"
"time"
"github.com/tal-tech/go-zero/zrpc/internal/balancer/p2c"
"github.com/tal-tech/go-zero/zrpc/internal/clientinterceptors"
"github.com/tal-tech/go-zero/zrpc/internal/resolver"
"google.golang.org/grpc"
)
const dialTimeout = time.Second * 3
func init() {
resolver.RegisterResolver()
}
type (
ClientOptions struct {
Timeout time.Duration
DialOptions []grpc.DialOption
}
ClientOption func(options *ClientOptions)
client struct {
conn *grpc.ClientConn
}
)
func NewClient(target string, opts ...ClientOption) (*client, error) {
opts = append(opts, WithDialOption(grpc.WithBalancerName(p2c.Name)))
conn, err := dial(target, opts...)
if err != nil {
return nil, err
}
return &client{conn: conn}, nil
}
func (c *client) Conn() *grpc.ClientConn {
return c.conn
}
func WithDialOption(opt grpc.DialOption) ClientOption {
return func(options *ClientOptions) {
options.DialOptions = append(options.DialOptions, opt)
}
}
func WithTimeout(timeout time.Duration) ClientOption {
return func(options *ClientOptions) {
options.Timeout = timeout
}
}
func buildDialOptions(opts ...ClientOption) []grpc.DialOption {
var clientOptions ClientOptions
for _, opt := range opts {
opt(&clientOptions)
}
options := []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithBlock(),
WithUnaryClientInterceptors(
clientinterceptors.BreakerInterceptor,
clientinterceptors.DurationInterceptor,
clientinterceptors.PromMetricInterceptor,
clientinterceptors.TimeoutInterceptor(clientOptions.Timeout),
clientinterceptors.TracingInterceptor,
),
}
return append(options, clientOptions.DialOptions...)
}
func dial(server string, opts ...ClientOption) (*grpc.ClientConn, error) {
options := buildDialOptions(opts...)
timeCtx, cancel := context.WithTimeout(context.Background(), dialTimeout)
defer cancel()
conn, err := grpc.DialContext(timeCtx, server, options...)
if err != nil {
return nil, fmt.Errorf("rpc dial: %s, error: %s", server, err.Error())
}
return conn, nil
}

View File

@@ -0,0 +1,30 @@
package internal
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestWithDialOption(t *testing.T) {
var options ClientOptions
agent := grpc.WithUserAgent("chrome")
opt := WithDialOption(agent)
opt(&options)
assert.Contains(t, options.DialOptions, agent)
}
func TestWithTimeout(t *testing.T) {
var options ClientOptions
opt := WithTimeout(time.Second)
opt(&options)
assert.Equal(t, time.Second, options.Timeout)
}
func TestBuildDialOptions(t *testing.T) {
agent := grpc.WithUserAgent("chrome")
opts := buildDialOptions(WithDialOption(agent))
assert.Contains(t, opts, agent)
}

View File

@@ -0,0 +1,18 @@
package clientinterceptors
import (
"context"
"path"
"github.com/tal-tech/go-zero/core/breaker"
"github.com/tal-tech/go-zero/zrpc/internal/codes"
"google.golang.org/grpc"
)
func BreakerInterceptor(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
breakerName := path.Join(cc.Target(), method)
return breaker.DoWithAcceptable(breakerName, func() error {
return invoker(ctx, method, req, reply, cc, opts...)
}, codes.Acceptable)
}

View File

@@ -0,0 +1,81 @@
package clientinterceptors
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/breaker"
"github.com/tal-tech/go-zero/core/stat"
rcodes "github.com/tal-tech/go-zero/zrpc/internal/codes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func init() {
stat.SetReporter(nil)
}
type mockError struct {
st *status.Status
}
func (m mockError) GRPCStatus() *status.Status {
return m.st
}
func (m mockError) Error() string {
return "mocked error"
}
func TestBreakerInterceptorNotFound(t *testing.T) {
err := mockError{st: status.New(codes.NotFound, "any")}
for i := 0; i < 1000; i++ {
assert.Equal(t, err, breaker.DoWithAcceptable("call", func() error {
return err
}, rcodes.Acceptable))
}
}
func TestBreakerInterceptorDeadlineExceeded(t *testing.T) {
err := mockError{st: status.New(codes.DeadlineExceeded, "any")}
errs := make(map[error]int)
for i := 0; i < 1000; i++ {
e := breaker.DoWithAcceptable("call", func() error {
return err
}, rcodes.Acceptable)
errs[e]++
}
assert.Equal(t, 2, len(errs))
assert.True(t, errs[err] > 0)
assert.True(t, errs[breaker.ErrServiceUnavailable] > 0)
}
func TestBreakerInterceptor(t *testing.T) {
tests := []struct {
name string
err error
}{
{
name: "nil",
err: nil,
},
{
name: "with error",
err: errors.New("mock"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cc := new(grpc.ClientConn)
err := BreakerInterceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
return test.err
})
assert.Equal(t, test.err, err)
})
}
}

View File

@@ -0,0 +1,30 @@
package clientinterceptors
import (
"context"
"path"
"time"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/timex"
"google.golang.org/grpc"
)
const slowThreshold = time.Millisecond * 500
func DurationInterceptor(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
serverName := path.Join(cc.Target(), method)
start := timex.Now()
err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
logx.WithDuration(timex.Since(start)).Infof("fail - %s - %v - %s", serverName, req, err.Error())
} else {
elapsed := timex.Since(start)
if elapsed > slowThreshold {
logx.WithDuration(elapsed).Slowf("[RPC] ok - slowcall - %s - %v - %v", serverName, req, reply)
}
}
return err
}

View File

@@ -0,0 +1,37 @@
package clientinterceptors
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestDurationInterceptor(t *testing.T) {
tests := []struct {
name string
err error
}{
{
name: "nil",
err: nil,
},
{
name: "with error",
err: errors.New("mock"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cc := new(grpc.ClientConn)
err := DurationInterceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
return test.err
})
assert.Equal(t, test.err, err)
})
}
}

View File

@@ -0,0 +1,42 @@
package clientinterceptors
import (
"context"
"strconv"
"time"
"github.com/tal-tech/go-zero/core/metric"
"github.com/tal-tech/go-zero/core/timex"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
)
const clientNamespace = "rpc_client"
var (
metricClientReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
Namespace: clientNamespace,
Subsystem: "requests",
Name: "duration_ms",
Help: "rpc client requests duration(ms).",
Labels: []string{"method"},
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
})
metricClientReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
Namespace: clientNamespace,
Subsystem: "requests",
Name: "code_total",
Help: "rpc client requests code count.",
Labels: []string{"method", "code"},
})
)
func PromMetricInterceptor(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
startTime := timex.Now()
err := invoker(ctx, method, req, reply, cc, opts...)
metricClientReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), method)
metricClientReqCodeTotal.Inc(method, strconv.Itoa(int(status.Code(err))))
return err
}

View File

@@ -0,0 +1,37 @@
package clientinterceptors
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestPromMetricInterceptor(t *testing.T) {
tests := []struct {
name string
err error
}{
{
name: "nil",
err: nil,
},
{
name: "with error",
err: errors.New("mock"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cc := new(grpc.ClientConn)
err := PromMetricInterceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
return test.err
})
assert.Equal(t, test.err, err)
})
}
}

View File

@@ -0,0 +1,24 @@
package clientinterceptors
import (
"context"
"time"
"github.com/tal-tech/go-zero/core/contextx"
"google.golang.org/grpc"
)
const defaultTimeout = time.Second * 2
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
if timeout <= 0 {
timeout = defaultTimeout
}
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
defer cancel()
return invoker(ctx, method, req, reply, cc, opts...)
}
}

View File

@@ -0,0 +1,50 @@
package clientinterceptors
import (
"context"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestTimeoutInterceptor(t *testing.T) {
timeouts := []time.Duration{0, time.Millisecond * 10}
for _, timeout := range timeouts {
t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
interceptor := TimeoutInterceptor(timeout)
cc := new(grpc.ClientConn)
err := interceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
return nil
},
)
assert.Nil(t, err)
})
}
}
func TestTimeoutInterceptor_timeout(t *testing.T) {
const timeout = time.Millisecond * 10
interceptor := TimeoutInterceptor(timeout)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var wg sync.WaitGroup
wg.Add(1)
cc := new(grpc.ClientConn)
err := interceptor(ctx, "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
defer wg.Done()
tm, ok := ctx.Deadline()
assert.True(t, ok)
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
return nil
})
wg.Wait()
assert.Nil(t, err)
}

View File

@@ -0,0 +1,24 @@
package clientinterceptors
import (
"context"
"github.com/tal-tech/go-zero/core/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
func TracingInterceptor(ctx context.Context, method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx, span := trace.StartClientSpan(ctx, cc.Target(), method)
defer span.Finish()
var pairs []string
span.Visit(func(key, val string) bool {
pairs = append(pairs, key, val)
return true
})
ctx = metadata.AppendToOutgoingContext(ctx, pairs...)
return invoker(ctx, method, req, reply, cc, opts...)
}

View File

@@ -0,0 +1,53 @@
package clientinterceptors
import (
"context"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
func TestTracingInterceptor(t *testing.T) {
var run int32
var wg sync.WaitGroup
wg.Add(1)
cc := new(grpc.ClientConn)
err := TracingInterceptor(context.Background(), "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
defer wg.Done()
atomic.AddInt32(&run, 1)
return nil
})
wg.Wait()
assert.Nil(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
}
func TestTracingInterceptor_GrpcFormat(t *testing.T) {
var run int32
var wg sync.WaitGroup
wg.Add(1)
md := metadata.New(map[string]string{
"foo": "bar",
})
carrier, err := trace.Inject(trace.GrpcFormat, md)
assert.Nil(t, err)
ctx, _ := trace.StartServerSpan(context.Background(), carrier, "user", "/foo")
cc := new(grpc.ClientConn)
err = TracingInterceptor(ctx, "/foo", nil, nil, cc,
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
opts ...grpc.CallOption) error {
defer wg.Done()
atomic.AddInt32(&run, 1)
return nil
})
wg.Wait()
assert.Nil(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
}

View File

@@ -0,0 +1,15 @@
package codes
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func Acceptable(err error) bool {
switch status.Code(err) {
case codes.DeadlineExceeded, codes.Internal, codes.Unavailable, codes.DataLoss:
return false
default:
return true
}
}

View File

@@ -0,0 +1,34 @@
package codes
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestAccept(t *testing.T) {
tests := []struct {
name string
err error
accept bool
}{
{
name: "nil error",
err: nil,
accept: true,
},
{
name: "deadline error",
err: status.Error(codes.DeadlineExceeded, "deadline"),
accept: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.accept, Acceptable(test.err))
})
}
}

View File

@@ -0,0 +1,159 @@
// Code generated by protoc-gen-go.
// source: deposit.proto
// DO NOT EDIT!
/*
Package mock is a generated protocol buffer package.
It is generated from these files:
deposit.proto
It has these top-level messages:
DepositRequest
DepositResponse
*/
package mock
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import (
context "golang.org/x/net/context"
grpc "google.golang.org/grpc"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type DepositRequest struct {
Amount float32 `protobuf:"fixed32,1,opt,name=amount" json:"amount,omitempty"`
}
func (m *DepositRequest) Reset() { *m = DepositRequest{} }
func (m *DepositRequest) String() string { return proto.CompactTextString(m) }
func (*DepositRequest) ProtoMessage() {}
func (*DepositRequest) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *DepositRequest) GetAmount() float32 {
if m != nil {
return m.Amount
}
return 0
}
type DepositResponse struct {
Ok bool `protobuf:"varint,1,opt,name=ok" json:"ok,omitempty"`
}
func (m *DepositResponse) Reset() { *m = DepositResponse{} }
func (m *DepositResponse) String() string { return proto.CompactTextString(m) }
func (*DepositResponse) ProtoMessage() {}
func (*DepositResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *DepositResponse) GetOk() bool {
if m != nil {
return m.Ok
}
return false
}
func init() {
proto.RegisterType((*DepositRequest)(nil), "mock.DepositRequest")
proto.RegisterType((*DepositResponse)(nil), "mock.DepositResponse")
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConn
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion4
// Client API for DepositService service
type DepositServiceClient interface {
Deposit(ctx context.Context, in *DepositRequest, opts ...grpc.CallOption) (*DepositResponse, error)
}
type depositServiceClient struct {
cc *grpc.ClientConn
}
func NewDepositServiceClient(cc *grpc.ClientConn) DepositServiceClient {
return &depositServiceClient{cc}
}
func (c *depositServiceClient) Deposit(ctx context.Context, in *DepositRequest, opts ...grpc.CallOption) (*DepositResponse, error) {
out := new(DepositResponse)
err := grpc.Invoke(ctx, "/mock.DepositService/Deposit", in, out, c.cc, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// Server API for DepositService service
type DepositServiceServer interface {
Deposit(context.Context, *DepositRequest) (*DepositResponse, error)
}
func RegisterDepositServiceServer(s *grpc.Server, srv DepositServiceServer) {
s.RegisterService(&_DepositService_serviceDesc, srv)
}
func _DepositService_Deposit_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DepositRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DepositServiceServer).Deposit(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/mock.DepositService/Deposit",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DepositServiceServer).Deposit(ctx, req.(*DepositRequest))
}
return interceptor(ctx, in, info, handler)
}
var _DepositService_serviceDesc = grpc.ServiceDesc{
ServiceName: "mock.DepositService",
HandlerType: (*DepositServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Deposit",
Handler: _DepositService_Deposit_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "deposit.proto",
}
func init() { proto.RegisterFile("deposit.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 139 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4d, 0x49, 0x2d, 0xc8,
0x2f, 0xce, 0x2c, 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0xc9, 0xcd, 0x4f, 0xce, 0x56,
0xd2, 0xe0, 0xe2, 0x73, 0x81, 0x08, 0x07, 0xa5, 0x16, 0x96, 0xa6, 0x16, 0x97, 0x08, 0x89, 0x71,
0xb1, 0x25, 0xe6, 0xe6, 0x97, 0xe6, 0x95, 0x48, 0x30, 0x2a, 0x30, 0x6a, 0x30, 0x05, 0x41, 0x79,
0x4a, 0x8a, 0x5c, 0xfc, 0x70, 0x95, 0xc5, 0x05, 0xf9, 0x79, 0xc5, 0xa9, 0x42, 0x7c, 0x5c, 0x4c,
0xf9, 0xd9, 0x60, 0x65, 0x1c, 0x41, 0x4c, 0xf9, 0xd9, 0x46, 0x1e, 0x70, 0xc3, 0x82, 0x53, 0x8b,
0xca, 0x32, 0x93, 0x53, 0x85, 0xcc, 0xb8, 0xd8, 0xa1, 0x22, 0x42, 0x22, 0x7a, 0x20, 0x0b, 0xf5,
0x50, 0x6d, 0x93, 0x12, 0x45, 0x13, 0x85, 0x98, 0x9c, 0xc4, 0x06, 0x76, 0xa3, 0x31, 0x20, 0x00,
0x00, 0xff, 0xff, 0x62, 0x37, 0xf2, 0x36, 0xb4, 0x00, 0x00, 0x00,
}

View File

@@ -0,0 +1,15 @@
syntax = "proto3";
package mock;
message DepositRequest {
float amount = 1;
}
message DepositResponse {
bool ok = 1;
}
service DepositService {
rpc Deposit(DepositRequest) returns (DepositResponse);
}

View File

@@ -0,0 +1,19 @@
package mock
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type DepositServer struct {
}
func (*DepositServer) Deposit(ctx context.Context, req *DepositRequest) (*DepositResponse, error) {
if req.GetAmount() < 0 {
return nil, status.Errorf(codes.InvalidArgument, "cannot deposit %v", req.GetAmount())
}
return &DepositResponse{Ok: true}, nil
}

View File

@@ -0,0 +1,32 @@
package resolver
import (
"strings"
"google.golang.org/grpc/resolver"
)
type directBuilder struct{}
func (d *directBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
resolver.Resolver, error) {
var addrs []resolver.Address
endpoints := strings.FieldsFunc(target.Endpoint, func(r rune) bool {
return r == EndpointSepChar
})
for _, val := range subset(endpoints, subsetSize) {
addrs = append(addrs, resolver.Address{
Addr: val,
})
}
cc.UpdateState(resolver.State{
Addresses: addrs,
})
return &nopResolver{cc: cc}, nil
}
func (d *directBuilder) Scheme() string {
return DirectScheme
}

View File

@@ -0,0 +1,52 @@
package resolver
import (
"fmt"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/mathx"
"google.golang.org/grpc/resolver"
)
func TestDirectBuilder_Build(t *testing.T) {
tests := []int{
0,
1,
2,
subsetSize / 2,
subsetSize,
subsetSize * 2,
}
for _, test := range tests {
t.Run(strconv.Itoa(test), func(t *testing.T) {
var servers []string
for i := 0; i < test; i++ {
servers = append(servers, fmt.Sprintf("localhost:%d", i))
}
var b directBuilder
cc := new(mockedClientConn)
_, err := b.Build(resolver.Target{
Scheme: DirectScheme,
Endpoint: strings.Join(servers, ","),
}, cc, resolver.BuildOptions{})
assert.Nil(t, err)
size := mathx.MinInt(test, subsetSize)
assert.Equal(t, size, len(cc.state.Addresses))
m := make(map[string]lang.PlaceholderType)
for _, each := range cc.state.Addresses {
m[each.Addr] = lang.Placeholder
}
assert.Equal(t, size, len(m))
})
}
}
func TestDirectBuilder_Scheme(t *testing.T) {
var b directBuilder
assert.Equal(t, DirectScheme, b.Scheme())
}

View File

@@ -0,0 +1,41 @@
package resolver
import (
"strings"
"github.com/tal-tech/go-zero/core/discov"
"google.golang.org/grpc/resolver"
)
type discovBuilder struct{}
func (d *discovBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (
resolver.Resolver, error) {
hosts := strings.FieldsFunc(target.Authority, func(r rune) bool {
return r == EndpointSepChar
})
sub, err := discov.NewSubscriber(hosts, target.Endpoint)
if err != nil {
return nil, err
}
update := func() {
var addrs []resolver.Address
for _, val := range subset(sub.Values(), subsetSize) {
addrs = append(addrs, resolver.Address{
Addr: val,
})
}
cc.UpdateState(resolver.State{
Addresses: addrs,
})
}
sub.AddListener(update)
update()
return &nopResolver{cc: cc}, nil
}
func (d *discovBuilder) Scheme() string {
return DiscovScheme
}

View File

@@ -0,0 +1,35 @@
package resolver
import (
"fmt"
"google.golang.org/grpc/resolver"
)
const (
DirectScheme = "direct"
DiscovScheme = "discov"
EndpointSepChar = ','
subsetSize = 32
)
var (
EndpointSep = fmt.Sprintf("%c", EndpointSepChar)
dirBuilder directBuilder
disBuilder discovBuilder
)
func RegisterResolver() {
resolver.Register(&dirBuilder)
resolver.Register(&disBuilder)
}
type nopResolver struct {
cc resolver.ClientConn
}
func (r *nopResolver) Close() {
}
func (r *nopResolver) ResolveNow(options resolver.ResolveNowOptions) {
}

View File

@@ -0,0 +1,36 @@
package resolver
import (
"testing"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
)
func TestNopResolver(t *testing.T) {
// make sure ResolveNow & Close don't panic
var r nopResolver
r.ResolveNow(resolver.ResolveNowOptions{})
r.Close()
}
type mockedClientConn struct {
state resolver.State
}
func (m *mockedClientConn) UpdateState(state resolver.State) {
m.state = state
}
func (m *mockedClientConn) ReportError(err error) {
}
func (m *mockedClientConn) NewAddress(addresses []resolver.Address) {
}
func (m *mockedClientConn) NewServiceConfig(serviceConfig string) {
}
func (m *mockedClientConn) ParseServiceConfig(serviceConfigJSON string) *serviceconfig.ParseResult {
return nil
}

View File

@@ -0,0 +1,14 @@
package resolver
import "math/rand"
func subset(set []string, sub int) []string {
rand.Shuffle(len(set), func(i, j int) {
set[i], set[j] = set[j], set[i]
})
if len(set) <= sub {
return set
} else {
return set[:sub]
}
}

View File

@@ -0,0 +1,53 @@
package resolver
import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/mathx"
)
func TestSubset(t *testing.T) {
tests := []struct {
name string
set int
sub int
}{
{
name: "more vals to subset",
set: 100,
sub: 36,
},
{
name: "less vals to subset",
set: 100,
sub: 200,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var vals []string
for i := 0; i < test.set; i++ {
vals = append(vals, strconv.Itoa(i))
}
m := make(map[interface{}]int)
for i := 0; i < 1000; i++ {
set := subset(append([]string(nil), vals...), test.sub)
if test.sub < test.set {
assert.Equal(t, test.sub, len(set))
} else {
assert.Equal(t, test.set, len(set))
}
for _, val := range set {
m[val]++
}
}
assert.True(t, mathx.CalcEntropy(m) > 0.95)
})
}
}

View File

@@ -0,0 +1,73 @@
package internal
import (
"sync"
"github.com/tal-tech/go-zero/core/logx"
"google.golang.org/grpc/grpclog"
)
// because grpclog.errorLog is not exported, we need to define our own.
const errorLevel = 2
var once sync.Once
type Logger struct{}
func InitLogger() {
once.Do(func() {
grpclog.SetLoggerV2(new(Logger))
})
}
func (l *Logger) Error(args ...interface{}) {
logx.Error(args...)
}
func (l *Logger) Errorf(format string, args ...interface{}) {
logx.Errorf(format, args...)
}
func (l *Logger) Errorln(args ...interface{}) {
logx.Error(args...)
}
func (l *Logger) Fatal(args ...interface{}) {
logx.Error(args...)
}
func (l *Logger) Fatalf(format string, args ...interface{}) {
logx.Errorf(format, args...)
}
func (l *Logger) Fatalln(args ...interface{}) {
logx.Error(args...)
}
func (l *Logger) Info(args ...interface{}) {
// ignore builtin grpc info
}
func (l *Logger) Infoln(args ...interface{}) {
// ignore builtin grpc info
}
func (l *Logger) Infof(format string, args ...interface{}) {
// ignore builtin grpc info
}
func (l *Logger) V(v int) bool {
return v >= errorLevel
}
func (l *Logger) Warning(args ...interface{}) {
// ignore builtin grpc warning
}
func (l *Logger) Warningln(args ...interface{}) {
// ignore builtin grpc warning
}
func (l *Logger) Warningf(format string, args ...interface{}) {
// ignore builtin grpc warning
}

View File

@@ -0,0 +1,29 @@
package internal
import "github.com/tal-tech/go-zero/core/discov"
func NewRpcPubServer(etcdEndpoints []string, etcdKey, listenOn string, opts ...ServerOption) (Server, error) {
registerEtcd := func() error {
pubClient := discov.NewPublisher(etcdEndpoints, etcdKey, listenOn)
return pubClient.KeepAlive()
}
server := keepAliveServer{
registerEtcd: registerEtcd,
Server: NewRpcServer(listenOn, opts...),
}
return server, nil
}
type keepAliveServer struct {
registerEtcd func() error
Server
}
func (ags keepAliveServer) Start(fn RegisterFn) error {
if err := ags.registerEtcd(); err != nil {
return err
}
return ags.Server.Start(fn)
}

View File

@@ -0,0 +1,81 @@
package internal
import (
"net"
"github.com/tal-tech/go-zero/core/proc"
"github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/zrpc/internal/serverinterceptors"
"google.golang.org/grpc"
)
type (
ServerOption func(options *rpcServerOptions)
rpcServerOptions struct {
metrics *stat.Metrics
}
rpcServer struct {
*baseRpcServer
}
)
func init() {
InitLogger()
}
func NewRpcServer(address string, opts ...ServerOption) Server {
var options rpcServerOptions
for _, opt := range opts {
opt(&options)
}
if options.metrics == nil {
options.metrics = stat.NewMetrics(address)
}
return &rpcServer{
baseRpcServer: newBaseRpcServer(address, options.metrics),
}
}
func (s *rpcServer) SetName(name string) {
s.baseRpcServer.SetName(name)
}
func (s *rpcServer) Start(register RegisterFn) error {
lis, err := net.Listen("tcp", s.address)
if err != nil {
return err
}
unaryInterceptors := []grpc.UnaryServerInterceptor{
serverinterceptors.UnaryCrashInterceptor(),
serverinterceptors.UnaryStatInterceptor(s.metrics),
serverinterceptors.UnaryPromMetricInterceptor(),
}
unaryInterceptors = append(unaryInterceptors, s.unaryInterceptors...)
streamInterceptors := []grpc.StreamServerInterceptor{
serverinterceptors.StreamCrashInterceptor,
}
streamInterceptors = append(streamInterceptors, s.streamInterceptors...)
options := append(s.options, WithUnaryServerInterceptors(unaryInterceptors...),
WithStreamServerInterceptors(streamInterceptors...))
server := grpc.NewServer(options...)
register(server)
// we need to make sure all others are wrapped up
// so we do graceful stop at shutdown phase instead of wrap up phase
shutdownCalled := proc.AddShutdownListener(func() {
server.GracefulStop()
})
err = server.Serve(lis)
shutdownCalled()
return err
}
func WithMetrics(metrics *stat.Metrics) ServerOption {
return func(options *rpcServerOptions) {
options.metrics = metrics
}
}

View File

@@ -0,0 +1,16 @@
package internal
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stat"
)
func TestWithMetrics(t *testing.T) {
metrics := stat.NewMetrics("foo")
opt := WithMetrics(metrics)
var options rpcServerOptions
opt(&options)
assert.Equal(t, metrics, options.metrics)
}

49
zrpc/internal/server.go Normal file
View File

@@ -0,0 +1,49 @@
package internal
import (
"github.com/tal-tech/go-zero/core/stat"
"google.golang.org/grpc"
)
type (
RegisterFn func(*grpc.Server)
Server interface {
AddOptions(options ...grpc.ServerOption)
AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor)
AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor)
SetName(string)
Start(register RegisterFn) error
}
baseRpcServer struct {
address string
metrics *stat.Metrics
options []grpc.ServerOption
streamInterceptors []grpc.StreamServerInterceptor
unaryInterceptors []grpc.UnaryServerInterceptor
}
)
func newBaseRpcServer(address string, metrics *stat.Metrics) *baseRpcServer {
return &baseRpcServer{
address: address,
metrics: metrics,
}
}
func (s *baseRpcServer) AddOptions(options ...grpc.ServerOption) {
s.options = append(s.options, options...)
}
func (s *baseRpcServer) AddStreamInterceptors(interceptors ...grpc.StreamServerInterceptor) {
s.streamInterceptors = append(s.streamInterceptors, interceptors...)
}
func (s *baseRpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterceptor) {
s.unaryInterceptors = append(s.unaryInterceptors, interceptors...)
}
func (s *baseRpcServer) SetName(name string) {
s.metrics.SetName(name)
}

View File

@@ -0,0 +1,53 @@
package internal
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stat"
"google.golang.org/grpc"
)
func TestBaseRpcServer_AddOptions(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := newBaseRpcServer("foo", metrics)
server.SetName("bar")
var opt grpc.EmptyServerOption
server.AddOptions(opt)
assert.Contains(t, server.options, opt)
}
func TestBaseRpcServer_AddStreamInterceptors(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := newBaseRpcServer("foo", metrics)
server.SetName("bar")
var vals []int
f := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
vals = append(vals, 1)
return nil
}
server.AddStreamInterceptors(f)
for _, each := range server.streamInterceptors {
assert.Nil(t, each(nil, nil, nil, nil))
}
assert.ElementsMatch(t, []int{1}, vals)
}
func TestBaseRpcServer_AddUnaryInterceptors(t *testing.T) {
metrics := stat.NewMetrics("foo")
server := newBaseRpcServer("foo", metrics)
server.SetName("bar")
var vals []int
f := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
resp interface{}, err error) {
vals = append(vals, 1)
return nil, nil
}
server.AddUnaryInterceptors(f)
for _, each := range server.unaryInterceptors {
_, err := each(context.Background(), nil, nil, nil)
assert.Nil(t, err)
}
assert.ElementsMatch(t, []int{1}, vals)
}

View File

@@ -0,0 +1,30 @@
package serverinterceptors
import (
"context"
"github.com/tal-tech/go-zero/zrpc/internal/auth"
"google.golang.org/grpc"
)
func StreamAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) error {
if err := authenticator.Authenticate(stream.Context()); err != nil {
return err
}
return handler(srv, stream)
}
}
func UnaryAuthorizeInterceptor(authenticator *auth.Authenticator) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (interface{}, error) {
if err := authenticator.Authenticate(ctx); err != nil {
return nil, err
}
return handler(ctx, req)
}
}

View File

@@ -0,0 +1,200 @@
package serverinterceptors
import (
"context"
"testing"
"github.com/alicebob/miniredis"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stores/redis"
"github.com/tal-tech/go-zero/zrpc/internal/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
func TestStreamAuthorizeInterceptor(t *testing.T) {
tests := []struct {
name string
app string
token string
strict bool
hasError bool
}{
{
name: "strict=false",
strict: false,
hasError: false,
},
{
name: "strict=true",
strict: true,
hasError: true,
},
{
name: "strict=true,with token",
app: "foo",
token: "bar",
strict: true,
hasError: false,
},
{
name: "strict=true,with error token",
app: "foo",
token: "error",
strict: true,
hasError: true,
},
}
r := miniredis.NewMiniRedis()
assert.Nil(t, r.Start())
defer r.Close()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
store := redis.NewRedis(r.Addr(), redis.NodeType)
if len(test.app) > 0 {
assert.Nil(t, store.Hset("apps", test.app, test.token))
defer store.Hdel("apps", test.app)
}
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
assert.Nil(t, err)
interceptor := StreamAuthorizeInterceptor(authenticator)
md := metadata.New(map[string]string{
"app": "foo",
"token": "bar",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
stream := mockedStream{ctx: ctx}
err = interceptor(nil, stream, nil, func(srv interface{}, stream grpc.ServerStream) error {
return nil
})
if test.hasError {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
})
}
}
func TestUnaryAuthorizeInterceptor(t *testing.T) {
tests := []struct {
name string
app string
token string
strict bool
hasError bool
}{
{
name: "strict=false",
strict: false,
hasError: false,
},
{
name: "strict=true",
strict: true,
hasError: true,
},
{
name: "strict=true,with token",
app: "foo",
token: "bar",
strict: true,
hasError: false,
},
{
name: "strict=true,with error token",
app: "foo",
token: "error",
strict: true,
hasError: true,
},
}
r := miniredis.NewMiniRedis()
assert.Nil(t, r.Start())
defer r.Close()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
store := redis.NewRedis(r.Addr(), redis.NodeType)
if len(test.app) > 0 {
assert.Nil(t, store.Hset("apps", test.app, test.token))
defer store.Hdel("apps", test.app)
}
authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
assert.Nil(t, err)
interceptor := UnaryAuthorizeInterceptor(authenticator)
md := metadata.New(map[string]string{
"app": "foo",
"token": "bar",
})
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
if test.hasError {
assert.NotNil(t, err)
} else {
assert.Nil(t, err)
}
if test.strict {
_, err = interceptor(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
var md metadata.MD
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
md = metadata.New(map[string]string{
"app": "",
"token": "",
})
ctx = metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.NotNil(t, err)
}
})
}
}
type mockedStream struct {
ctx context.Context
}
func (m mockedStream) SetHeader(md metadata.MD) error {
return nil
}
func (m mockedStream) SendHeader(md metadata.MD) error {
return nil
}
func (m mockedStream) SetTrailer(md metadata.MD) {
}
func (m mockedStream) Context() context.Context {
return m.ctx
}
func (m mockedStream) SendMsg(v interface{}) error {
return nil
}
func (m mockedStream) RecvMsg(v interface{}) error {
return nil
}

View File

@@ -0,0 +1,42 @@
package serverinterceptors
import (
"context"
"runtime/debug"
"github.com/tal-tech/go-zero/core/logx"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func StreamCrashInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler) (err error) {
defer handleCrash(func(r interface{}) {
err = toPanicError(r)
})
return handler(srv, stream)
}
func UnaryCrashInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {
defer handleCrash(func(r interface{}) {
err = toPanicError(r)
})
return handler(ctx, req)
}
}
func handleCrash(handler func(interface{})) {
if r := recover(); r != nil {
handler(r)
}
}
func toPanicError(r interface{}) error {
logx.Errorf("%+v %s", r, debug.Stack())
return status.Errorf(codes.Internal, "panic: %v", r)
}

View File

@@ -0,0 +1,31 @@
package serverinterceptors
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/logx"
"google.golang.org/grpc"
)
func init() {
logx.Disable()
}
func TestStreamCrashInterceptor(t *testing.T) {
err := StreamCrashInterceptor(nil, nil, nil, func(
srv interface{}, stream grpc.ServerStream) error {
panic("mock panic")
})
assert.NotNil(t, err)
}
func TestUnaryCrashInterceptor(t *testing.T) {
interceptor := UnaryCrashInterceptor()
_, err := interceptor(context.Background(), nil, nil,
func(ctx context.Context, req interface{}) (interface{}, error) {
panic("mock panic")
})
assert.NotNil(t, err)
}

View File

@@ -0,0 +1,44 @@
package serverinterceptors
import (
"context"
"strconv"
"time"
"github.com/tal-tech/go-zero/core/metric"
"github.com/tal-tech/go-zero/core/timex"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
)
const serverNamespace = "rpc_server"
var (
metricServerReqDur = metric.NewHistogramVec(&metric.HistogramVecOpts{
Namespace: serverNamespace,
Subsystem: "requests",
Name: "duration_ms",
Help: "rpc server requests duration(ms).",
Labels: []string{"method"},
Buckets: []float64{5, 10, 25, 50, 100, 250, 500, 1000},
})
metricServerReqCodeTotal = metric.NewCounterVec(&metric.CounterVecOpts{
Namespace: serverNamespace,
Subsystem: "requests",
Name: "code_total",
Help: "rpc server requests code count.",
Labels: []string{"method", "code"},
})
)
func UnaryPromMetricInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (
interface{}, error) {
startTime := timex.Now()
resp, err := handler(ctx, req)
metricServerReqDur.Observe(int64(timex.Since(startTime)/time.Millisecond), info.FullMethod)
metricServerReqCodeTotal.Inc(info.FullMethod, strconv.Itoa(int(status.Code(err))))
return resp, err
}
}

View File

@@ -0,0 +1,19 @@
package serverinterceptors
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestUnaryPromMetricInterceptor(t *testing.T) {
interceptor := UnaryPromMetricInterceptor()
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Nil(t, err)
}

View File

@@ -0,0 +1,52 @@
package serverinterceptors
import (
"context"
"sync"
"github.com/tal-tech/go-zero/core/load"
"github.com/tal-tech/go-zero/core/stat"
"google.golang.org/grpc"
)
const serviceType = "rpc"
var (
sheddingStat *load.SheddingStat
lock sync.Mutex
)
func UnarySheddingInterceptor(shedder load.Shedder, metrics *stat.Metrics) grpc.UnaryServerInterceptor {
ensureSheddingStat()
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (val interface{}, err error) {
sheddingStat.IncrementTotal()
var promise load.Promise
promise, err = shedder.Allow()
if err != nil {
metrics.AddDrop()
sheddingStat.IncrementDrop()
return
}
defer func() {
if err == context.DeadlineExceeded {
promise.Fail()
} else {
sheddingStat.IncrementPass()
promise.Pass()
}
}()
return handler(ctx, req)
}
}
func ensureSheddingStat() {
lock.Lock()
if sheddingStat == nil {
sheddingStat = load.NewSheddingStat(serviceType)
}
lock.Unlock()
}

View File

@@ -0,0 +1,77 @@
package serverinterceptors
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/load"
"github.com/tal-tech/go-zero/core/stat"
"google.golang.org/grpc"
)
func TestUnarySheddingInterceptor(t *testing.T) {
tests := []struct {
name string
allow bool
handleErr error
expect error
}{
{
name: "allow",
allow: true,
handleErr: nil,
expect: nil,
},
{
name: "allow",
allow: true,
handleErr: context.DeadlineExceeded,
expect: context.DeadlineExceeded,
},
{
name: "reject",
allow: false,
handleErr: nil,
expect: load.ErrServiceOverloaded,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
shedder := mockedShedder{allow: test.allow}
metrics := stat.NewMetrics("mock")
interceptor := UnarySheddingInterceptor(shedder, metrics)
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, test.handleErr
})
assert.Equal(t, test.expect, err)
})
}
}
type mockedShedder struct {
allow bool
}
func (m mockedShedder) Allow() (load.Promise, error) {
if m.allow {
return mockedPromise{}, nil
} else {
return nil, load.ErrServiceOverloaded
}
}
type mockedPromise struct {
}
func (m mockedPromise) Pass() {
}
func (m mockedPromise) Fail() {
}

View File

@@ -0,0 +1,51 @@
package serverinterceptors
import (
"context"
"encoding/json"
"time"
"github.com/tal-tech/go-zero/core/logx"
"github.com/tal-tech/go-zero/core/stat"
"github.com/tal-tech/go-zero/core/timex"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
)
const serverSlowThreshold = time.Millisecond * 500
func UnaryStatInterceptor(metrics *stat.Metrics) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {
defer handleCrash(func(r interface{}) {
err = toPanicError(r)
})
startTime := timex.Now()
defer func() {
duration := timex.Since(startTime)
metrics.Add(stat.Task{
Duration: duration,
})
logDuration(ctx, info.FullMethod, req, duration)
}()
return handler(ctx, req)
}
}
func logDuration(ctx context.Context, method string, req interface{}, duration time.Duration) {
var addr string
client, ok := peer.FromContext(ctx)
if ok {
addr = client.Addr.String()
}
content, err := json.Marshal(req)
if err != nil {
logx.Errorf("%s - %s", addr, err.Error())
} else if duration > serverSlowThreshold {
logx.WithDuration(duration).Slowf("[RPC] slowcall - %s - %s - %s", addr, method, string(content))
} else {
logx.WithDuration(duration).Infof("%s - %s - %s", addr, method, string(content))
}
}

View File

@@ -0,0 +1,32 @@
package serverinterceptors
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/stat"
"google.golang.org/grpc"
)
func TestUnaryStatInterceptor(t *testing.T) {
metrics := stat.NewMetrics("mock")
interceptor := UnaryStatInterceptor(metrics)
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Nil(t, err)
}
func TestUnaryStatInterceptor_crash(t *testing.T) {
metrics := stat.NewMetrics("mock")
interceptor := UnaryStatInterceptor(metrics)
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
panic("error")
})
assert.NotNil(t, err)
}

View File

@@ -0,0 +1,18 @@
package serverinterceptors
import (
"context"
"time"
"github.com/tal-tech/go-zero/core/contextx"
"google.golang.org/grpc"
)
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
defer cancel()
return handler(ctx, req)
}
}

View File

@@ -0,0 +1,41 @@
package serverinterceptors
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestUnaryTimeoutInterceptor(t *testing.T) {
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Nil(t, err)
}
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
const timeout = time.Millisecond * 10
interceptor := UnaryTimeoutInterceptor(timeout)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var wg sync.WaitGroup
wg.Add(1)
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
defer wg.Done()
tm, ok := ctx.Deadline()
assert.True(t, ok)
assert.True(t, tm.Before(time.Now().Add(timeout+time.Millisecond)))
return nil, nil
})
wg.Wait()
assert.Nil(t, err)
}

View File

@@ -0,0 +1,28 @@
package serverinterceptors
import (
"context"
"github.com/tal-tech/go-zero/core/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
func UnaryTracingInterceptor(serviceName string) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler) (resp interface{}, err error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return handler(ctx, req)
}
carrier, err := trace.Extract(trace.GrpcFormat, md)
if err != nil {
return handler(ctx, req)
}
ctx, span := trace.StartServerSpan(ctx, carrier, serviceName, info.FullMethod)
defer span.Finish()
return handler(ctx, req)
}
}

View File

@@ -0,0 +1,48 @@
package serverinterceptors
import (
"context"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/tal-tech/go-zero/core/trace/tracespec"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
func TestUnaryTracingInterceptor(t *testing.T) {
interceptor := UnaryTracingInterceptor("foo")
var run int32
var wg sync.WaitGroup
wg.Add(1)
_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
defer wg.Done()
atomic.AddInt32(&run, 1)
return nil, nil
})
wg.Wait()
assert.Nil(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&run))
}
func TestUnaryTracingInterceptor_GrpcFormat(t *testing.T) {
interceptor := UnaryTracingInterceptor("foo")
var wg sync.WaitGroup
wg.Add(1)
var md metadata.MD
ctx := metadata.NewIncomingContext(context.Background(), md)
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
FullMethod: "/",
}, func(ctx context.Context, req interface{}) (interface{}, error) {
defer wg.Done()
assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).TraceId()) > 0)
assert.True(t, len(ctx.Value(tracespec.TracingKey).(tracespec.Trace).SpanId()) > 0)
return nil, nil
})
wg.Wait()
assert.Nil(t, err)
}

18
zrpc/internal/target.go Normal file
View File

@@ -0,0 +1,18 @@
package internal
import (
"fmt"
"strings"
"github.com/tal-tech/go-zero/zrpc/internal/resolver"
)
func BuildDirectTarget(endpoints []string) string {
return fmt.Sprintf("%s:///%s", resolver.DirectScheme,
strings.Join(endpoints, resolver.EndpointSep))
}
func BuildDiscovTarget(endpoints []string, key string) string {
return fmt.Sprintf("%s://%s/%s", resolver.DiscovScheme,
strings.Join(endpoints, resolver.EndpointSep), key)
}

View File

@@ -0,0 +1,17 @@
package internal
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestBuildDirectTarget(t *testing.T) {
target := BuildDirectTarget([]string{"localhost:123", "localhost:456"})
assert.Equal(t, "direct:///localhost:123,localhost:456", target)
}
func TestBuildDiscovTarget(t *testing.T) {
target := BuildDiscovTarget([]string{"localhost:123", "localhost:456"}, "foo")
assert.Equal(t, "discov://localhost:123,localhost:456/foo", target)
}