diff --git a/core/rescue/recover.go b/core/rescue/recover.go index 7cde054e..5d2922bf 100644 --- a/core/rescue/recover.go +++ b/core/rescue/recover.go @@ -1,6 +1,12 @@ package rescue -import "github.com/zeromicro/go-zero/core/logx" +import ( + "context" + "runtime/debug" + + "github.com/zeromicro/go-zero/core/logc" + "github.com/zeromicro/go-zero/core/logx" +) // Recover is used with defer to do cleanup on panics. // Use it like: @@ -15,3 +21,13 @@ func Recover(cleanups ...func()) { logx.ErrorStack(p) } } + +func RecoverCtx(ctx context.Context, cleanups ...func()) { + for _, cleanup := range cleanups { + cleanup() + } + + if p := recover(); p != nil { + logc.Errorf(ctx, "%+v\n\n%s", p, debug.Stack()) + } +} diff --git a/core/threading/routines.go b/core/threading/routines.go index 900dcb3b..a165992f 100644 --- a/core/threading/routines.go +++ b/core/threading/routines.go @@ -2,6 +2,7 @@ package threading import ( "bytes" + "context" "runtime" "strconv" @@ -13,6 +14,11 @@ func GoSafe(fn func()) { go RunSafe(fn) } +// GoSafeCtx runs the given fn using another goroutine, recovers if fn panics with ctx. +func GoSafeCtx(ctx context.Context, fn func()) { + go RunSafeCtx(ctx, fn) +} + // RoutineId is only for debug, never use it in production. func RoutineId() uint64 { b := make([]byte, 64) @@ -31,3 +37,10 @@ func RunSafe(fn func()) { fn() } + +// RunSafeCtx runs the given fn, recovers if fn panics with ctx. +func RunSafeCtx(ctx context.Context, fn func()) { + defer rescue.RecoverCtx(ctx) + + fn() +} diff --git a/core/threading/routines_test.go b/core/threading/routines_test.go index 6a12dca0..30a524ee 100644 --- a/core/threading/routines_test.go +++ b/core/threading/routines_test.go @@ -1,12 +1,15 @@ package threading import ( + "bytes" + "context" "io" "log" "testing" "github.com/stretchr/testify/assert" "github.com/zeromicro/go-zero/core/lang" + "github.com/zeromicro/go-zero/core/logx" ) func TestRoutineId(t *testing.T) { @@ -34,3 +37,51 @@ func TestRunSafe(t *testing.T) { <-ch i++ } + +func TestRunSafeCtx(t *testing.T) { + var buf bytes.Buffer + logx.SetWriter(logx.NewWriter(&buf)) + ctx := context.Background() + ch := make(chan lang.PlaceholderType) + + i := 0 + + defer func() { + assert.Equal(t, 1, i) + }() + + go RunSafeCtx(ctx, func() { + defer func() { + ch <- lang.Placeholder + }() + + panic("panic") + }) + + <-ch + i++ +} + +func TestGoSafeCtx(t *testing.T) { + var buf bytes.Buffer + logx.SetWriter(logx.NewWriter(&buf)) + ctx := context.Background() + ch := make(chan lang.PlaceholderType) + + i := 0 + + defer func() { + assert.Equal(t, 1, i) + }() + + GoSafeCtx(ctx, func() { + defer func() { + ch <- lang.Placeholder + }() + + panic("panic") + }) + + <-ch + i++ +}