-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* implement do all * remove panic call
- Loading branch information
Teddy Budiono Hermawan
authored
Jul 12, 2023
1 parent
4141cf0
commit 776b80f
Showing
2 changed files
with
211 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package ctask | ||
|
||
import ( | ||
"context" | ||
"runtime" | ||
|
||
"golang.org/x/sync/errgroup" | ||
) | ||
|
||
type DoAllOpt func(cfg *DoAllConfig) | ||
type DoAllConfig struct { | ||
WorkerNum int | ||
} | ||
|
||
type DoAllResp[R any] struct { | ||
Result R | ||
Error error | ||
} | ||
|
||
// DoAll execute tasks using the given executor function for all the given tasks. | ||
// It waits until all tasks are finished. | ||
// The return value is a slice of Result or Error | ||
// | ||
// The max number of goroutines can optionally be specified using the option WithWorkerNum. | ||
// By default, it is set to runtime.NumCPU() | ||
func DoAll[Task any, Result any]( | ||
ctx context.Context, | ||
tasks []Task, | ||
executor func(ctx context.Context, t Task) (Result, error), | ||
opts ...DoAllOpt, | ||
) []DoAllResp[Result] { | ||
cfg := getDoAllConfigWithOptions(opts...) | ||
|
||
g, ctx := errgroup.WithContext(ctx) | ||
g.SetLimit(cfg.WorkerNum) | ||
results := make([]DoAllResp[Result], len(tasks)) | ||
for idx, task := range tasks { | ||
idx, task := idx, task // retain current loop values to be used in goroutine | ||
g.Go(func() error { | ||
select { | ||
case <-ctx.Done(): | ||
results[idx] = DoAllResp[Result]{Error: ctx.Err()} | ||
return nil | ||
default: | ||
res, err := executor(ctx, task) | ||
results[idx] = DoAllResp[Result]{ | ||
Result: res, | ||
Error: err, | ||
} | ||
return nil | ||
} | ||
}) | ||
} | ||
_ = g.Wait() // impossible to have error here | ||
return results | ||
} | ||
|
||
func getDoAllConfigWithOptions(opts ...DoAllOpt) DoAllConfig { | ||
cfg := DoAllConfig{ | ||
WorkerNum: runtime.NumCPU(), | ||
} | ||
for _, opt := range opts { | ||
opt(&cfg) | ||
} | ||
return cfg | ||
} | ||
|
||
func WithDoAllWorkerNum(num int) DoAllOpt { | ||
return func(cfg *DoAllConfig) { | ||
cfg.WorkerNum = num | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
package ctask | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestDoAll(t *testing.T) { | ||
type T = int // task type | ||
type R = int // result type | ||
|
||
type args struct { | ||
ctx context.Context | ||
ctxTimeout time.Duration | ||
tasks []T | ||
executor func(ctx context.Context, t T) (R, error) | ||
opts []DoAllOpt | ||
} | ||
tests := []struct { | ||
name string | ||
args args | ||
want []DoAllResp[R] | ||
}{ | ||
{ | ||
name: "happy path", | ||
args: args{ | ||
ctx: context.Background(), | ||
tasks: []T{0, 1, 2, 3, 4, 5, 6}, | ||
executor: fibonacci, | ||
opts: nil, | ||
}, | ||
want: []DoAllResp[R]{ | ||
{Result: 1}, | ||
{Result: 1}, | ||
{Result: 2}, | ||
{Result: 3}, | ||
{Result: 5}, | ||
{Result: 8}, | ||
{Result: 13}, | ||
}, | ||
}, | ||
{ | ||
name: "empty slice", | ||
args: args{ | ||
ctx: context.Background(), | ||
tasks: nil, | ||
executor: fibonacci, | ||
opts: nil, | ||
}, | ||
want: []DoAllResp[R]{}, | ||
}, | ||
{ | ||
name: "error path", | ||
args: args{ | ||
ctx: context.Background(), | ||
tasks: []T{0, 1, 2, 1, -1, 5}, | ||
executor: fibonacci, | ||
opts: []DoAllOpt{WithDoAllWorkerNum(1)}, | ||
}, | ||
want: []DoAllResp[R]{ | ||
{Result: 1}, | ||
{Result: 1}, | ||
{Result: 2}, | ||
{Result: 1}, | ||
{Error: errors.New("negative")}, | ||
{Result: 8}, | ||
}, | ||
}, | ||
{ | ||
name: "slow functions should return context deadline exceeded error", | ||
args: args{ | ||
ctx: context.Background(), | ||
ctxTimeout: 20 * time.Millisecond, | ||
tasks: []T{ | ||
10, 1000, 10, 5000, 10, | ||
}, | ||
executor: func(ctx context.Context, t T) (R, error) { | ||
select { | ||
case <-ctx.Done(): | ||
return 0, ctx.Err() | ||
case <-time.After(time.Duration(t) * time.Millisecond): | ||
return 1, nil | ||
} | ||
}, | ||
opts: []DoAllOpt{WithDoAllWorkerNum(5)}, | ||
}, | ||
want: []DoAllResp[R]{ | ||
{Result: 1}, {Error: context.DeadlineExceeded}, {Result: 1}, {Error: context.DeadlineExceeded}, {Result: 1}, | ||
}, | ||
}, | ||
{ | ||
name: "slow function with sleeps should run concurrently without context deadline error", | ||
args: args{ | ||
ctx: context.Background(), | ||
ctxTimeout: 50 * time.Millisecond, | ||
tasks: []T{ | ||
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, | ||
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, | ||
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, | ||
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, | ||
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, | ||
}, | ||
executor: func(ctx context.Context, t T) (R, error) { | ||
time.Sleep(time.Duration(t) * time.Millisecond) | ||
return 1, nil | ||
}, | ||
opts: []DoAllOpt{WithDoAllWorkerNum(20)}, | ||
}, | ||
want: []DoAllResp[R]{ | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
{Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, {Result: 1}, | ||
}, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
ctx := tt.args.ctx | ||
if tt.args.ctxTimeout > 0 { | ||
var cancel context.CancelFunc | ||
ctx, cancel = context.WithTimeout(ctx, tt.args.ctxTimeout) | ||
defer cancel() | ||
} | ||
got := DoAll(ctx, tt.args.tasks, tt.args.executor, tt.args.opts...) | ||
require.Equal(t, tt.want, got) | ||
}) | ||
} | ||
} |