Skip to content

Commit

Permalink
implement do all (#226)
Browse files Browse the repository at this point in the history
* implement do all

* remove panic call
  • Loading branch information
Teddy Budiono Hermawan authored Jul 12, 2023
1 parent 4141cf0 commit 776b80f
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 0 deletions.
72 changes: 72 additions & 0 deletions ctask/do_all.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package ctask

import (
"context"
"runtime"

"golang.org/x/sync/errgroup"
)

type DoAllOpt func(cfg *DoAllConfig)
type DoAllConfig struct {
WorkerNum int
}

type DoAllResp[R any] struct {
Result R
Error error
}

// DoAll execute tasks using the given executor function for all the given tasks.
// It waits until all tasks are finished.
// The return value is a slice of Result or Error
//
// The max number of goroutines can optionally be specified using the option WithWorkerNum.
// By default, it is set to runtime.NumCPU()
func DoAll[Task any, Result any](
ctx context.Context,
tasks []Task,
executor func(ctx context.Context, t Task) (Result, error),
opts ...DoAllOpt,
) []DoAllResp[Result] {
cfg := getDoAllConfigWithOptions(opts...)

g, ctx := errgroup.WithContext(ctx)
g.SetLimit(cfg.WorkerNum)
results := make([]DoAllResp[Result], len(tasks))
for idx, task := range tasks {
idx, task := idx, task // retain current loop values to be used in goroutine
g.Go(func() error {
select {
case <-ctx.Done():
results[idx] = DoAllResp[Result]{Error: ctx.Err()}
return nil
default:
res, err := executor(ctx, task)
results[idx] = DoAllResp[Result]{
Result: res,
Error: err,
}
return nil
}
})
}
_ = g.Wait() // impossible to have error here
return results
}

func getDoAllConfigWithOptions(opts ...DoAllOpt) DoAllConfig {
cfg := DoAllConfig{
WorkerNum: runtime.NumCPU(),
}
for _, opt := range opts {
opt(&cfg)
}
return cfg
}

func WithDoAllWorkerNum(num int) DoAllOpt {
return func(cfg *DoAllConfig) {
cfg.WorkerNum = num
}
}
139 changes: 139 additions & 0 deletions ctask/do_all_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package ctask

import (
"context"
"errors"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestDoAll(t *testing.T) {
type T = int // task type
type R = int // result type

type args struct {
ctx context.Context
ctxTimeout time.Duration
tasks []T
executor func(ctx context.Context, t T) (R, error)
opts []DoAllOpt
}
tests := []struct {
name string
args args
want []DoAllResp[R]
}{
{
name: "happy path",
args: args{
ctx: context.Background(),
tasks: []T{0, 1, 2, 3, 4, 5, 6},
executor: fibonacci,
opts: nil,
},
want: []DoAllResp[R]{
{Result: 1},
{Result: 1},
{Result: 2},
{Result: 3},
{Result: 5},
{Result: 8},
{Result: 13},
},
},
{
name: "empty slice",
args: args{
ctx: context.Background(),
tasks: nil,
executor: fibonacci,
opts: nil,
},
want: []DoAllResp[R]{},
},
{
name: "error path",
args: args{
ctx: context.Background(),
tasks: []T{0, 1, 2, 1, -1, 5},
executor: fibonacci,
opts: []DoAllOpt{WithDoAllWorkerNum(1)},
},
want: []DoAllResp[R]{
{Result: 1},
{Result: 1},
{Result: 2},
{Result: 1},
{Error: errors.New("negative")},
{Result: 8},
},
},
{
name: "slow functions should return context deadline exceeded error",
args: args{
ctx: context.Background(),
ctxTimeout: 20 * time.Millisecond,
tasks: []T{
10, 1000, 10, 5000, 10,
},
executor: func(ctx context.Context, t T) (R, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
case <-time.After(time.Duration(t) * time.Millisecond):
return 1, nil
}
},
opts: []DoAllOpt{WithDoAllWorkerNum(5)},
},
want: []DoAllResp[R]{
{Result: 1}, {Error: context.DeadlineExceeded}, {Result: 1}, {Error: context.DeadlineExceeded}, {Result: 1},
},
},
{
name: "slow function with sleeps should run concurrently without context deadline error",
args: args{
ctx: context.Background(),
ctxTimeout: 50 * time.Millisecond,
tasks: []T{
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
},
executor: func(ctx context.Context, t T) (R, error) {
time.Sleep(time.Duration(t) * time.Millisecond)
return 1, nil
},
opts: []DoAllOpt{WithDoAllWorkerNum(20)},
},
want: []DoAllResp[R]{
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := tt.args.ctx
if tt.args.ctxTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, tt.args.ctxTimeout)
defer cancel()
}
got := DoAll(ctx, tt.args.tasks, tt.args.executor, tt.args.opts...)
require.Equal(t, tt.want, got)
})
}
}

0 comments on commit 776b80f

Please sign in to comment.