From ec4188047615c6bc4272a74a83d0118630a67c63 Mon Sep 17 00:00:00 2001 From: chentong Date: Sat, 2 Mar 2024 00:32:39 +0800 Subject: [PATCH] fix: BatchError.Add() non thread safe (#3946) --- core/errorx/batcherror.go | 9 ++++++++- core/errorx/batcherror_test.go | 24 ++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/core/errorx/batcherror.go b/core/errorx/batcherror.go index 92ae644d..b63867cb 100644 --- a/core/errorx/batcherror.go +++ b/core/errorx/batcherror.go @@ -1,10 +1,14 @@ package errorx -import "bytes" +import ( + "bytes" + "sync" +) type ( // A BatchError is an error that can hold multiple errors. BatchError struct { + mu sync.Mutex errs errorArray } @@ -13,6 +17,9 @@ type ( // Add adds errs to be, nil errors are ignored. func (be *BatchError) Add(errs ...error) { + be.mu.Lock() + defer be.mu.Unlock() + for _, err := range errs { if err != nil { be.errs = append(be.errs, err) diff --git a/core/errorx/batcherror_test.go b/core/errorx/batcherror_test.go index ae5c8c3e..3e345343 100644 --- a/core/errorx/batcherror_test.go +++ b/core/errorx/batcherror_test.go @@ -3,6 +3,7 @@ package errorx import ( "errors" "fmt" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -33,7 +34,7 @@ func TestBatchErrorNilFromFunc(t *testing.T) { func TestBatchErrorOneError(t *testing.T) { var batch BatchError batch.Add(errors.New(err1)) - assert.NotNil(t, batch) + assert.NotNil(t, batch.Err()) assert.Equal(t, err1, batch.Err().Error()) assert.True(t, batch.NotNil()) } @@ -42,7 +43,26 @@ func TestBatchErrorWithErrors(t *testing.T) { var batch BatchError batch.Add(errors.New(err1)) batch.Add(errors.New(err2)) - assert.NotNil(t, batch) + assert.NotNil(t, batch.Err()) assert.Equal(t, fmt.Sprintf("%s\n%s", err1, err2), batch.Err().Error()) assert.True(t, batch.NotNil()) } + +func TestBatchErrorConcurrentAdd(t *testing.T) { + const count = 10000 + var batch BatchError + var wg sync.WaitGroup + + wg.Add(count) + for i := 0; i < count; i++ { + go func() { + defer wg.Done() + batch.Add(errors.New(err1)) + }() + } + wg.Wait() + + assert.NotNil(t, batch.Err()) + assert.Equal(t, count, len(batch.errs)) + assert.True(t, batch.NotNil()) +}