chore: refactor gateway (#3157)

This commit is contained in:
Kevin Wan
2023-04-22 23:25:51 +08:00
committed by GitHub
parent de1e0f2410
commit 027193dc99
2 changed files with 24 additions and 32 deletions

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/fullstorydev/grpcurl" "github.com/fullstorydev/grpcurl"
"github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/jsonpb"
@@ -17,7 +16,6 @@ import (
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"github.com/zeromicro/go-zero/zrpc" "github.com/zeromicro/go-zero/zrpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
) )
type ( type (
@@ -26,7 +24,6 @@ type (
c GatewayConf c GatewayConf
*rest.Server *rest.Server
upstreams []*upstream upstreams []*upstream
timeout time.Duration
processHeader func(http.Header) []string processHeader func(http.Header) []string
} }
@@ -42,9 +39,8 @@ type (
// MustNewServer creates a new gateway server. // MustNewServer creates a new gateway server.
func MustNewServer(c GatewayConf, opts ...Option) *Server { func MustNewServer(c GatewayConf, opts ...Option) *Server {
svr := &Server{ svr := &Server{
c: c, c: c,
Server: rest.MustNewServer(c.RestConf), Server: rest.MustNewServer(c.RestConf),
timeout: time.Duration(c.Timeout) * time.Millisecond,
} }
for _, opt := range opts { for _, opt := range opts {
opt(svr) opt(svr)
@@ -68,6 +64,7 @@ func (s *Server) build() error {
if err := s.buildClient(); err != nil { if err := s.buildClient(); err != nil {
return err return err
} }
return s.buildUpstream() return s.buildUpstream()
} }
@@ -84,7 +81,9 @@ func (s *Server) buildClient() error {
target, err := up.Grpc.BuildTarget() target, err := up.Grpc.BuildTarget()
if err != nil { if err != nil {
cancel(err) cancel(err)
return
} }
up.Name = target up.Name = target
cli := zrpc.MustNewClient(up.Grpc) cli := zrpc.MustNewClient(up.Grpc)
writer.Write(&upstream{ writer.Write(&upstream{
@@ -160,13 +159,9 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A
return return
} }
timeout := internal.GetTimeout(r.Header, s.timeout)
ctx, can := context.WithTimeout(r.Context(), timeout)
defer can()
w.Header().Set(httpx.ContentType, httpx.JsonContentType) w.Header().Set(httpx.ContentType, httpx.JsonContentType)
handler := internal.NewEventHandler(w, resolver) handler := internal.NewEventHandler(w, resolver)
if err := grpcurl.InvokeRPC(ctx, source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header), if err := grpcurl.InvokeRPC(r.Context(), source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header),
handler, parser.Next); err != nil { handler, parser.Next); err != nil {
httpx.ErrorCtx(r.Context(), w, err) httpx.ErrorCtx(r.Context(), w, err)
} }
@@ -188,8 +183,7 @@ func (s *Server) createDescriptorSource(cli zrpc.Client, up Upstream) (grpcurl.D
return nil, err return nil, err
} }
} else { } else {
refCli := grpc_reflection_v1alpha.NewServerReflectionClient(cli.Conn()) client := grpcreflect.NewClientAuto(context.Background(), cli.Conn())
client := grpcreflect.NewClient(context.Background(), refCli)
source = grpcurl.DescriptorSourceFromServer(context.Background(), client) source = grpcurl.DescriptorSourceFromServer(context.Background(), client)
} }

View File

@@ -44,10 +44,12 @@ func dialer() func(context.Context, string) (net.Conn, error) {
func TestMustNewServer(t *testing.T) { func TestMustNewServer(t *testing.T) {
var c GatewayConf var c GatewayConf
assert.NoError(t, conf.FillDefault(&c)) assert.NoError(t, conf.FillDefault(&c))
// avoid popup alert on macos for asking permissions
c.DevServer.Host = "localhost"
c.Host = "localhost"
c.Port = 18881 c.Port = 18881
s := MustNewServer(c) s := MustNewServer(c)
s.upstreams = []*upstream{ s.upstreams = []*upstream{
{ {
Upstream: Upstream{ Upstream: Upstream{
@@ -59,22 +61,20 @@ func TestMustNewServer(t *testing.T) {
}, },
}, },
}, },
client: zrpc.MustNewClient(zrpc.RpcClientConf{ client: zrpc.MustNewClient(
Endpoints: []string{"foo"}, zrpc.RpcClientConf{
Timeout: 1000, Endpoints: []string{"foo"},
Middlewares: zrpc.ClientMiddlewaresConf{ Timeout: 1000,
Trace: true, Middlewares: zrpc.ClientMiddlewaresConf{
Duration: true, Trace: true,
Prometheus: true, Duration: true,
Breaker: true, Prometheus: true,
Timeout: true, Breaker: true,
Timeout: true,
},
}, },
},
zrpc.WithDialOption(grpc.WithContextDialer(dialer())), zrpc.WithDialOption(grpc.WithContextDialer(dialer())),
zrpc.WithUnaryClientInterceptor(func(ctx context.Context, method string, req, reply any, ),
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return invoker(ctx, method, req, reply, cc, opts...)
})),
}, },
} }
@@ -83,13 +83,11 @@ func TestMustNewServer(t *testing.T) {
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
ctx := context.Background() resp, err := httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit/100", nil)
resp, err := httpc.Do(ctx, http.MethodGet, "http://localhost:18881/deposit/100", nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode) assert.Equal(t, http.StatusOK, resp.StatusCode)
resp, err = httpc.Do(ctx, http.MethodGet, "http://localhost:18881/deposit_fail/100", nil) resp, err = httpc.Do(context.Background(), http.MethodGet, "http://localhost:18881/deposit_fail/100", nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusNotFound, resp.StatusCode) assert.Equal(t, http.StatusNotFound, resp.StatusCode)
} }