diff --git a/core/logx/logtest/logtest.go b/core/logx/logtest/logtest.go index 53c2db43..6e583840 100644 --- a/core/logx/logtest/logtest.go +++ b/core/logx/logtest/logtest.go @@ -76,3 +76,14 @@ func (b *Buffer) Reset() { func (b *Buffer) String() string { return b.buf.String() } + +func PanicOnFatal(t *testing.T) { + ok := logx.ExitOnFatal.CompareAndSwap(true, false) + if !ok { + return + } + + t.Cleanup(func() { + logx.ExitOnFatal.CompareAndSwap(false, true) + }) +} diff --git a/core/logx/logtest/logtest_test.go b/core/logx/logtest/logtest_test.go index 1a61e07a..da7c3f79 100644 --- a/core/logx/logtest/logtest_test.go +++ b/core/logx/logtest/logtest_test.go @@ -1,6 +1,7 @@ package logtest import ( + "errors" "testing" "github.com/stretchr/testify/assert" @@ -15,8 +16,13 @@ func TestCollector(t *testing.T) { assert.Contains(t, c.String(), input) } -func TestDiscard(t *testing.T) { +func TestPanicOnFatal(t *testing.T) { const input = "hello" Discard(t) logx.Info(input) + + PanicOnFatal(t) + assert.Panics(t, func() { + logx.Must(errors.New("foo")) + }) } diff --git a/core/syncx/atomicbool.go b/core/syncx/atomicbool.go index 7e489d9b..751b8306 100644 --- a/core/syncx/atomicbool.go +++ b/core/syncx/atomicbool.go @@ -20,12 +20,14 @@ func ForAtomicBool(val bool) *AtomicBool { // CompareAndSwap compares current value with given old, if equals, set to given val. func (b *AtomicBool) CompareAndSwap(old, val bool) bool { var ov, nv uint32 + if old { ov = 1 } if val { nv = 1 } + return atomic.CompareAndSwapUint32((*uint32)(b), ov, nv) } diff --git a/gateway/server_test.go b/gateway/server_test.go index 6d21c2f1..74168559 100644 --- a/gateway/server_test.go +++ b/gateway/server_test.go @@ -10,7 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/conf" + "github.com/zeromicro/go-zero/core/discov" "github.com/zeromicro/go-zero/core/logx" + "github.com/zeromicro/go-zero/core/logx/logtest" "github.com/zeromicro/go-zero/internal/mock" "github.com/zeromicro/go-zero/rest/httpc" "github.com/zeromicro/go-zero/zrpc" @@ -51,6 +53,8 @@ func TestMustNewServer(t *testing.T) { s := MustNewServer(c, withDialer(func(conf zrpc.RpcClientConf) zrpc.Client { return zrpc.MustNewClient(conf, zrpc.WithDialOption(grpc.WithContextDialer(dialer()))) + }), WithHeaderProcessor(func(header http.Header) []string { + return []string{"foo"} })) s.upstreams = []Upstream{ { @@ -77,6 +81,7 @@ func TestMustNewServer(t *testing.T) { assert.NoError(t, s.build()) go s.Server.Start() + defer s.Stop() time.Sleep(time.Millisecond * 200) @@ -103,3 +108,20 @@ func TestServer_ensureUpstreamNames(t *testing.T) { assert.NoError(t, s.ensureUpstreamNames()) assert.Equal(t, "target", s.upstreams[0].Name) } + +func TestServer_ensureUpstreamNames_badEtcd(t *testing.T) { + var s = Server{ + upstreams: []Upstream{ + { + Grpc: zrpc.RpcClientConf{ + Etcd: discov.EtcdConf{}, + }, + }, + }, + } + + logtest.PanicOnFatal(t) + assert.Panics(t, func() { + s.Start() + }) +}