diff --git a/chans/core.go b/chans/core.go index c0524a8..ee9cb48 100644 --- a/chans/core.go +++ b/chans/core.go @@ -46,6 +46,7 @@ func ForEach[A any](in <-chan A, n int, f func(A) bool) { if n == 1 { for a := range in { if !f(a) { + defer DrainNB(in) break } } diff --git a/chans/core_test.go b/chans/core_test.go index f415b38..17ec56c 100644 --- a/chans/core_test.go +++ b/chans/core_test.go @@ -194,24 +194,27 @@ func TestForEach(t *testing.T) { t.Run(th.Name("early exit", n), func(t *testing.T) { th.ExpectNotHang(t, 10*time.Second, func() { - done := make(chan struct{}) - defer close(done) + cnt := int64(0) - sum := int64(0) - - in := th.InfiniteChan(done) + in := th.FromRange(0, 1000) ForEach(in, n, func(x int) bool { if x == 100 { return false } - atomic.AddInt64(&sum, int64(x)) + atomic.AddInt64(&cnt, 1) return true }) - if sum < 99*100/2 { + if cnt < 100 { t.Errorf("expected at least 100 iterations to complete") } + if cnt > 150 { + t.Errorf("early exit did not happen") + } + + time.Sleep(1 * time.Second) + th.ExpectDrainedChan(t, in) }) }) diff --git a/echans/core_test.go b/echans/core_test.go index 2959659..9b42e4c 100644 --- a/echans/core_test.go +++ b/echans/core_test.go @@ -335,48 +335,52 @@ func TestForEach(t *testing.T) { t.Run(th.Name("error in input", n), func(t *testing.T) { th.ExpectNotHang(t, 10*time.Second, func() { - done := make(chan struct{}) - defer close(done) - - in := Wrap(th.InfiniteChan(done), nil) + in := Wrap(th.FromRange(0, 1000), nil) in = replaceWithError(in, 100, fmt.Errorf("err100")) - sum := int64(0) + cnt := int64(0) err := ForEach(in, n, func(x int) error { - atomic.AddInt64(&sum, int64(x)) + atomic.AddInt64(&cnt, 1) return nil }) th.ExpectError(t, err, "err100") - - fmt.Println(sum, 99*100/2) - - if sum < 99*100/2 { + if cnt < 100 { t.Errorf("expected at least 100 iterations to complete") } + if cnt > 150 { + t.Errorf("early exit did not happen") + } + + time.Sleep(1 * time.Second) + th.ExpectDrainedChan(t, in) }) }) t.Run(th.Name("error in func", n), func(t *testing.T) { th.ExpectNotHang(t, 10*time.Second, func() { - done := make(chan struct{}) - defer close(done) - - in := Wrap(th.InfiniteChan(done), nil) + in := Wrap(th.FromRange(0, 1000), nil) - sum := int64(0) + cnt := int64(0) err := ForEach(in, n, func(x int) error { if x == 100 { return fmt.Errorf("err100") } - atomic.AddInt64(&sum, int64(x)) + atomic.AddInt64(&cnt, 1) return nil }) th.ExpectError(t, err, "err100") - if sum < 99*100/2 { + if cnt < 100 { t.Errorf("expected at least 100 iterations to complete") } + if cnt > 150 { + t.Errorf("early exit did not happen") + } + + // wait until it drained + time.Sleep(1 * time.Second) + th.ExpectDrainedChan(t, in) }) }) diff --git a/echans/util_test.go b/echans/util_test.go index 254fd61..16e0db2 100644 --- a/echans/util_test.go +++ b/echans/util_test.go @@ -38,6 +38,6 @@ func TestFromToSlice(t *testing.T) { th.ExpectError(t, err, "err15") time.Sleep(1 * time.Second) - th.ExpectClosedChan(t, in) + th.ExpectDrainedChan(t, in) }) } diff --git a/internal/common/loops_test.go b/internal/common/loops_test.go index c56ab1a..46ef39f 100644 --- a/internal/common/loops_test.go +++ b/internal/common/loops_test.go @@ -93,7 +93,7 @@ func TestBreakable(t *testing.T) { } th.ExpectValue(t, maxSeen, 9999) - th.ExpectClosedChan(t, in) + th.ExpectDrainedChan(t, in) }) t.Run("early exit", func(t *testing.T) { @@ -120,7 +120,7 @@ func TestBreakable(t *testing.T) { } - th.ExpectClosedChan(t, in) + th.ExpectDrainedChan(t, in) }) } diff --git a/internal/th/assertions.go b/internal/th/assertions.go index 86eb026..af90b84 100644 --- a/internal/th/assertions.go +++ b/internal/th/assertions.go @@ -83,7 +83,7 @@ func ExpectUnsorted[T ordered](t *testing.T, arr []T) { } } -func ExpectClosedChan[A any](t *testing.T, ch <-chan A) { +func ExpectDrainedChan[A any](t *testing.T, ch <-chan A) { t.Helper() select { case x, ok := <-ch: diff --git a/internal/th/helpers.go b/internal/th/helpers.go index e867722..9dd47bb 100644 --- a/internal/th/helpers.go +++ b/internal/th/helpers.go @@ -17,22 +17,6 @@ func FromRange(start, end int) <-chan int { return ch } -// InfiniteChan generates infinite sequence of natural numbers. It stops when stop channel is closed. -func InfiniteChan(stop <-chan struct{}) <-chan int { - ch := make(chan int) - go func() { - defer close(ch) - for i := 0; ; i++ { - select { - case <-stop: - return - case ch <- i: - } - } - }() - return ch -} - func Send[T any](ch chan<- T, items ...T) { for _, item := range items { ch <- item