Skip to content

Commit aa0d8a3

Browse files
authored
add ctx to executor param (#199)
1 parent bc8e921 commit aa0d8a3

File tree

3 files changed

+47
-12
lines changed

3 files changed

+47
-12
lines changed

ctask/doer.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ type DoConfig struct {
2121
func Do[Task any, Result any](
2222
ctx context.Context,
2323
tasks []Task,
24-
executor func(t Task) (Result, error),
24+
executor func(ctx context.Context, t Task) (Result, error),
2525
opts ...DoOpt,
2626
) ([]Result, error) {
2727
cfg := getConfigWithOptions(opts...)
2828

2929
g, ctx := errgroup.WithContext(ctx)
30-
g.SetLimit(int(cfg.WorkerNum))
30+
g.SetLimit(cfg.WorkerNum)
3131
results := make([]Result, len(tasks))
3232
for idx, task := range tasks {
3333
idx, task := idx, task // retain current loop values to be used in goroutine
@@ -36,7 +36,7 @@ func Do[Task any, Result any](
3636
case <-ctx.Done():
3737
return ctx.Err()
3838
default:
39-
res, err := executor(task)
39+
res, err := executor(ctx, task)
4040
if err != nil {
4141
return err
4242
}

ctask/doer_test.go

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"testing"
7+
"time"
78

89
"github.com/stretchr/testify/require"
910
)
@@ -13,10 +14,11 @@ func TestDo(t *testing.T) {
1314
type R = int // result type
1415

1516
type args struct {
16-
ctx context.Context
17-
tasks []T
18-
executor func(t T) (R, error)
19-
opts []DoOpt
17+
ctx context.Context
18+
ctxTimeout time.Duration
19+
tasks []T
20+
executor func(ctx context.Context, t T) (R, error)
21+
opts []DoOpt
2022
}
2123
tests := []struct {
2224
name string
@@ -59,29 +61,62 @@ func TestDo(t *testing.T) {
5961
require.Equal(t, errors.New("negative"), err)
6062
},
6163
},
64+
{
65+
name: "slow function with sleeps should run concurrently without context deadline error",
66+
args: args{
67+
ctx: context.Background(),
68+
ctxTimeout: 50 * time.Millisecond,
69+
tasks: []T{
70+
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
71+
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
72+
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
73+
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
74+
10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
75+
},
76+
executor: func(ctx context.Context, t T) (R, error) {
77+
time.Sleep(time.Duration(t) * time.Millisecond)
78+
return 1, nil
79+
},
80+
opts: []DoOpt{WithWorkerNum(20)},
81+
},
82+
want: []R{
83+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
84+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
85+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
86+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
87+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
88+
},
89+
requireErr: require.NoError,
90+
},
6291
}
6392
for _, tt := range tests {
6493
t.Run(tt.name, func(t *testing.T) {
65-
got, err := Do(tt.args.ctx, tt.args.tasks, tt.args.executor, tt.args.opts...)
94+
ctx := tt.args.ctx
95+
if tt.args.ctxTimeout > 0 {
96+
var cancel context.CancelFunc
97+
ctx, cancel = context.WithTimeout(ctx, tt.args.ctxTimeout)
98+
defer cancel()
99+
}
100+
got, err := Do(ctx, tt.args.tasks, tt.args.executor, tt.args.opts...)
66101
tt.requireErr(t, err)
67102
require.Equal(t, tt.want, got)
68103
})
69104
}
70105
}
71106

72-
func fibonacci(n int) (int, error) {
107+
func fibonacci(ctx context.Context, n int) (int, error) {
73108
if n < 0 {
74109
return 0, errors.New("negative")
75110
}
76111
if n < 2 {
77112
return 1, nil
78113
}
79-
r1, err := fibonacci(n - 1)
114+
r1, err := fibonacci(ctx, n-1)
80115
if err != nil {
81116
return 0, err
82117
}
83118

84-
r2, err := fibonacci(n - 2)
119+
r2, err := fibonacci(ctx, n-2)
85120
if err != nil {
86121
return 0, err
87122
}

worker/worker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (w *worker) hold(ctx context.Context, wg *sync.WaitGroup) {
119119
go func() {
120120
defer wg.Done()
121121

122-
_ = <-ctx.Done()
122+
<-ctx.Done()
123123

124124
if w.stopFn != nil {
125125
logger.Info("stopping...")

0 commit comments

Comments
 (0)