feat: support context in MapReduce (#1368)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package mr
|
package mr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -43,6 +44,7 @@ type (
|
|||||||
Option func(opts *mapReduceOptions)
|
Option func(opts *mapReduceOptions)
|
||||||
|
|
||||||
mapReduceOptions struct {
|
mapReduceOptions struct {
|
||||||
|
ctx context.Context
|
||||||
workers int
|
workers int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,14 +97,15 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{}
|
|||||||
collector := make(chan interface{}, options.workers)
|
collector := make(chan interface{}, options.workers)
|
||||||
done := syncx.NewDoneChan()
|
done := syncx.NewDoneChan()
|
||||||
|
|
||||||
go executeMappers(mapper, source, collector, done.Done(), options.workers)
|
go executeMappers(options.ctx, mapper, source, collector, done.Done(), options.workers)
|
||||||
|
|
||||||
return collector
|
return collector
|
||||||
}
|
}
|
||||||
|
|
||||||
// MapReduce maps all elements generated from given generate func,
|
// MapReduce maps all elements generated from given generate func,
|
||||||
// and reduces the output elements with given reducer.
|
// and reduces the output elements with given reducer.
|
||||||
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) {
|
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
|
||||||
|
opts ...Option) (interface{}, error) {
|
||||||
source := buildSource(generate)
|
source := buildSource(generate)
|
||||||
return MapReduceWithSource(source, mapper, reducer, opts...)
|
return MapReduceWithSource(source, mapper, reducer, opts...)
|
||||||
}
|
}
|
||||||
@@ -120,7 +123,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
|
|||||||
|
|
||||||
collector := make(chan interface{}, options.workers)
|
collector := make(chan interface{}, options.workers)
|
||||||
done := syncx.NewDoneChan()
|
done := syncx.NewDoneChan()
|
||||||
writer := newGuardedWriter(output, done.Done())
|
writer := newGuardedWriter(options.ctx, output, done.Done())
|
||||||
var closeOnce sync.Once
|
var closeOnce sync.Once
|
||||||
var retErr errorx.AtomicError
|
var retErr errorx.AtomicError
|
||||||
finish := func() {
|
finish := func() {
|
||||||
@@ -154,7 +157,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
|
|||||||
reducer(collector, writer, cancel)
|
reducer(collector, writer, cancel)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go executeMappers(func(item interface{}, w Writer) {
|
go executeMappers(options.ctx, func(item interface{}, w Writer) {
|
||||||
mapper(item, w, cancel)
|
mapper(item, w, cancel)
|
||||||
}, source, collector, done.Done(), options.workers)
|
}, source, collector, done.Done(), options.workers)
|
||||||
|
|
||||||
@@ -187,6 +190,13 @@ func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
|
|||||||
}, opts...))
|
}, opts...))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithContext returns an Option that makes a mapreduce processing use the given ctx.
// The ctx is stored in mapReduceOptions.ctx, replacing the context.Background()
// default that newOptions sets.
|
||||||
|
func WithContext(ctx context.Context) Option {
|
||||||
|
return func(opts *mapReduceOptions) {
|
||||||
|
opts.ctx = ctx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithWorkers customizes a mapreduce processing with given workers.
|
// WithWorkers customizes a mapreduce processing with given workers.
|
||||||
func WithWorkers(workers int) Option {
|
func WithWorkers(workers int) Option {
|
||||||
return func(opts *mapReduceOptions) {
|
return func(opts *mapReduceOptions) {
|
||||||
@@ -224,8 +234,8 @@ func drain(channel <-chan interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{},
|
func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{},
|
||||||
done <-chan lang.PlaceholderType, workers int) {
|
collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
defer func() {
|
defer func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
@@ -233,9 +243,11 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
pool := make(chan lang.PlaceholderType, workers)
|
pool := make(chan lang.PlaceholderType, workers)
|
||||||
writer := newGuardedWriter(collector, done)
|
writer := newGuardedWriter(ctx, collector, done)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
case pool <- lang.Placeholder:
|
case pool <- lang.Placeholder:
|
||||||
@@ -261,6 +273,7 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i
|
|||||||
|
|
||||||
func newOptions() *mapReduceOptions {
|
func newOptions() *mapReduceOptions {
|
||||||
return &mapReduceOptions{
|
return &mapReduceOptions{
|
||||||
|
ctx: context.Background(),
|
||||||
workers: defaultWorkers,
|
workers: defaultWorkers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -275,12 +288,15 @@ func once(fn func(error)) func(error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type guardedWriter struct {
|
type guardedWriter struct {
|
||||||
|
ctx context.Context
|
||||||
channel chan<- interface{}
|
channel chan<- interface{}
|
||||||
done <-chan lang.PlaceholderType
|
done <-chan lang.PlaceholderType
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter {
|
func newGuardedWriter(ctx context.Context, channel chan<- interface{},
|
||||||
|
done <-chan lang.PlaceholderType) guardedWriter {
|
||||||
return guardedWriter{
|
return guardedWriter{
|
||||||
|
ctx: ctx,
|
||||||
channel: channel,
|
channel: channel,
|
||||||
done: done,
|
done: done,
|
||||||
}
|
}
|
||||||
@@ -288,6 +304,8 @@ func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderTy
|
|||||||
|
|
||||||
func (gw guardedWriter) Write(v interface{}) {
|
func (gw guardedWriter) Write(v interface{}) {
|
||||||
select {
|
select {
|
||||||
|
case <-gw.ctx.Done():
|
||||||
|
return
|
||||||
case <-gw.done:
|
case <-gw.done:
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package mr
|
package mr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
@@ -410,6 +411,50 @@ func TestMapReduceWithoutReducerWrite(t *testing.T) {
|
|||||||
assert.Nil(t, res)
|
assert.Nil(t, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMapReduceVoidPanicInReducer runs MapReduceVoid with a single worker and a
// reducer that panics with message. It expects a non-nil error whose text equals
// message, and expects the generator to have pushed all defaultWorkers*2 items
// (done is set only after the generate loop finishes).
func TestMapReduceVoidPanicInReducer(t *testing.T) {
|
||||||
|
const message = "foo"
|
||||||
|
var done syncx.AtomicBool
|
||||||
|
err := MapReduceVoid(func(source chan<- interface{}) {
|
||||||
|
for i := 0; i < defaultWorkers*2; i++ {
|
||||||
|
source <- i
|
||||||
|
}
|
||||||
|
done.Set(true)
|
||||||
|
}, func(item interface{}, writer Writer, cancel func(error)) {
|
||||||
|
i := item.(int)
|
||||||
|
writer.Write(i)
|
||||||
|
}, func(pipe <-chan interface{}, cancel func(error)) {
|
||||||
|
panic(message)
|
||||||
|
}, WithWorkers(1))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, message, err.Error())
|
||||||
|
assert.True(t, done.True())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMapReduceWithContext passes a cancellable context via WithContext and
// cancels it from inside the mapper when the item equals defaultWorkers/2.
// It expects MapReduceVoid to return ErrReduceNoOutput.
// NOTE(review): done and result are written below but never asserted on;
// confirm whether assertions on them were intended.
func TestMapReduceWithContext(t *testing.T) {
|
||||||
|
var done syncx.AtomicBool
|
||||||
|
var result []int
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
err := MapReduceVoid(func(source chan<- interface{}) {
|
||||||
|
for i := 0; i < defaultWorkers*2; i++ {
|
||||||
|
source <- i
|
||||||
|
}
|
||||||
|
done.Set(true)
|
||||||
|
}, func(item interface{}, writer Writer, c func(error)) {
|
||||||
|
i := item.(int)
|
||||||
|
if i == defaultWorkers/2 {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
writer.Write(i)
|
||||||
|
}, func(pipe <-chan interface{}, cancel func(error)) {
|
||||||
|
for item := range pipe {
|
||||||
|
i := item.(int)
|
||||||
|
result = append(result, i)
|
||||||
|
}
|
||||||
|
}, WithContext(ctx))
|
||||||
|
assert.NotNil(t, err)
|
||||||
|
assert.Equal(t, ErrReduceNoOutput, err)
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkMapReduce(b *testing.B) {
|
func BenchmarkMapReduce(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user