diff --git a/reduce.go b/reduce.go new file mode 100644 index 0000000..d05a17a --- /dev/null +++ b/reduce.go @@ -0,0 +1,109 @@ +package rill + +import ( + "sync" + + "github.com/destel/rill/internal/core" +) + +// Reduce combines all elements from the input channel into a single value +// using a binary function f. The function f must be commutative, meaning +// f(x,y) == f(y,x). It is applied to pairs of elements concurrently, using n +// goroutines, progressively reducing the channel's contents until only one value remains. +// The order in which the function f is applied is not guaranteed due to concurrent processing. +// +// Reduce blocks until all items are processed or an error is encountered, +// either from the function f itself or from the upstream. In case of an error +// leading to early termination, Reduce ensures the input channel is drained to +// avoid goroutine leaks, making it safe for use in environments where cleanup +// is crucial. +// +// The function returns the first encountered error, if any, or the reduction result. +// The second return value is false if the input channel is empty, and true otherwise. +func Reduce[A any](in <-chan Try[A], n int, f func(A, A) (A, error)) (A, bool, error) { + in, earlyExit := core.Breakable(in) + + res, ok := core.Reduce(in, n, func(a1, a2 Try[A]) Try[A] { + if err := a1.Error; err != nil { + earlyExit() + return a1 + } + + if err := a2.Error; err != nil { + earlyExit() + return a2 + } + + res, err := f(a1.Value, a2.Value) + if err != nil { + earlyExit() + return Try[A]{Error: err} + } + + return Try[A]{Value: res} + }) + + return res.Value, ok, res.Error +} + +// MapReduce reduces the input channel to a map using a mapper and a reducer functions. +// Reduction is done in two phases happening concurrently. In the first phase, +// the mapper function transforms each input item into a key-value pair using +// nm goroutines. As a result of this phase, we can get multiple values for the +// same key. In the second phase, the reducer function reduces values for the +// same key into a single value, using nr goroutines. The order in which the reducer +// is applied is not guaranteed due to concurrent processing. See [Reduce] documentation +// for more details on reduction phase semantics. +// +// MapReduce blocks until all items are processed or an error is encountered, +// either from the mapper, reducer, or the upstream. In case of an error +// leading to early termination, MapReduce ensures the input channel is drained +// to avoid goroutine leaks, making it safe for use in environments where +// cleanup is crucial. +// +// The function returns the first encountered error, if any, or a map where +// each key is associated with a single reduced value +func MapReduce[A any, K comparable, V any](in <-chan Try[A], nm int, mapper func(A) (K, V, error), nr int, reducer func(V, V) (V, error)) (map[K]V, error) { + var zeroKey K + var zeroVal V + + in, earlyExit := core.Breakable(in) + + var retErr error + var once sync.Once + + reportError := func(err error) { + earlyExit() + once.Do(func() { + retErr = err + }) + } + + res := core.MapReduce(in, + nm, func(a Try[A]) (K, V) { + if a.Error != nil { + reportError(a.Error) + return zeroKey, zeroVal + } + + k, v, err := mapper(a.Value) + if err != nil { + reportError(err) + return zeroKey, zeroVal + } + + return k, v + }, + nr, func(v1, v2 V) V { + res, err := reducer(v1, v2) + if err != nil { + reportError(err) + return zeroVal + } + + return res + }, + ) + + return res, retErr +} diff --git a/reduce_test.go b/reduce_test.go new file mode 100644 index 0000000..e85e26e --- /dev/null +++ b/reduce_test.go @@ -0,0 +1,207 @@ +package rill + +import ( + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/destel/rill/internal/th" +) + +func TestReduce(t *testing.T) { + for _, n := range []int{1, 4} { + t.Run(th.Name("empty", n), func(t *testing.T) { + in := FromSlice([]int{}, nil) + + _, ok, err := Reduce(in, n, func(x, y int) (int, error) { + + return x + y, nil + }) + + th.ExpectNoError(t, err) + th.ExpectValue(t, ok, false) + }) + + t.Run(th.Name("no errors", n), func(t *testing.T) { + in := FromChan(th.FromRange(0, 100), nil) + + cnt := int64(0) + out, ok, err := Reduce(in, n, func(x, y int) (int, error) { + atomic.AddInt64(&cnt, 1) + return x + y, nil + }) + + th.ExpectNoError(t, err) + th.ExpectValue(t, out, 99*100/2) + th.ExpectValue(t, ok, true) + th.ExpectValue(t, cnt, 99) + }) + + t.Run(th.Name("error in input", n), func(t *testing.T) { + in := FromChan(th.FromRange(0, 1000), nil) + in = replaceWithError(in, 100, fmt.Errorf("err100")) + + cnt := int64(0) + _, _, err := Reduce(in, n, func(x, y int) (int, error) { + atomic.AddInt64(&cnt, 1) + return x + y, nil + }) + + th.ExpectError(t, err, "err100") + if cnt == 999 { + t.Errorf("early exit did not happen") + } + + time.Sleep(1 * time.Second) + th.ExpectDrainedChan(t, in) + }) + + t.Run(th.Name("error in func", n), func(t *testing.T) { + in := FromChan(th.FromRange(0, 1000), nil) + + cnt := int64(0) + _, _, err := Reduce(in, n, func(x, y int) (int, error) { + if atomic.AddInt64(&cnt, 1) == 100 { + return 0, fmt.Errorf("err100") + } + + return x + y, nil + }) + + th.ExpectError(t, err, "err100") + if cnt == 999 { + t.Errorf("early exit did not happen") + } + }) + } +} + +func TestMapReduce(t *testing.T) { + for _, nm := range []int{1, 4} { + for _, nr := range []int{1, 4} { + t.Run(th.Name("empty", nm, nr), func(t *testing.T) { + in := FromSlice([]int{}, nil) + + out, err := MapReduce(in, + nm, func(x int) (string, int, error) { + s := fmt.Sprint(x) + return fmt.Sprintf("%d-digit", len(s)), x, nil + }, + nr, func(x, y int) (int, error) { + return x + y, nil + }) + + th.ExpectNoError(t, err) + th.ExpectMap(t, out, map[string]int{}) + }) + + t.Run(th.Name("no errors", nm, nr), func(t *testing.T) { + in := FromChan(th.FromRange(0, 1000), nil) + + var cntMap, cntReduce int64 + out, err := MapReduce(in, + nm, func(x int) (string, int, error) { + atomic.AddInt64(&cntMap, 1) + s := fmt.Sprint(x) + return fmt.Sprintf("%d-digit", len(s)), x, nil + }, + nr, func(x, y int) (int, error) { + atomic.AddInt64(&cntReduce, 1) + return x + y, nil + }, + ) + + th.ExpectNoError(t, err) + th.ExpectMap(t, out, map[string]int{ + "1-digit": (0 + 9) * 10 / 2, + "2-digit": (10 + 99) * 90 / 2, + "3-digit": (100 + 999) * 900 / 2, + }) + th.ExpectValue(t, cntMap, 1000) + th.ExpectValue(t, cntReduce, 9+89+899) + }) + + t.Run(th.Name("error in input", nm, nr), func(t *testing.T) { + in := FromChan(th.FromRange(0, 1000), nil) + in = replaceWithError(in, 100, fmt.Errorf("err100")) + + var cntMap, cntReduce int64 + _, err := MapReduce(in, + nm, func(x int) (string, int, error) { + atomic.AddInt64(&cntMap, 1) + s := fmt.Sprint(x) + return fmt.Sprintf("%d-digit", len(s)), x, nil + }, + nr, func(x, y int) (int, error) { + atomic.AddInt64(&cntReduce, 1) + return x + y, nil + }, + ) + + th.ExpectError(t, err, "err100") + if cntMap == 1000 { + t.Errorf("early exit did not happen") + } + if cntReduce == 9+89+899 { + t.Errorf("early exit did not happen") + } + }) + + t.Run(th.Name("error in mapper", nm, nr), func(t *testing.T) { + in := FromChan(th.FromRange(0, 1000), nil) + + var cntMap, cntReduce int64 + _, err := MapReduce(in, + nm, func(x int) (string, int, error) { + if atomic.AddInt64(&cntMap, 1) == 100 { + return "", 0, fmt.Errorf("err100") + } + s := fmt.Sprint(x) + return fmt.Sprintf("%d-digit", len(s)), x, nil + }, + nr, func(x, y int) (int, error) { + atomic.AddInt64(&cntReduce, 1) + return x + y, nil + }, + ) + + th.ExpectError(t, err, "err100") + if cntMap == 1000 { + t.Errorf("early exit did not happen") + } + if cntReduce == 9+89+899 { + t.Errorf("early exit did not happen") + } + }) + + t.Run(th.Name("error in reducer", nm, nr), func(t *testing.T) { + in := FromChan(th.FromRange(0, 1000), nil) + + var cntMap, cntReduce int64 + _, err := MapReduce(in, + nm, func(x int) (string, int, error) { + atomic.AddInt64(&cntMap, 1) + s := fmt.Sprint(x) + return fmt.Sprintf("%d-digit", len(s)), x, nil + }, + nr, func(x, y int) (int, error) { + if atomic.AddInt64(&cntReduce, 1) == 100 { + return 0, fmt.Errorf("err100") + } + return x + y, nil + }, + ) + + th.ExpectError(t, err, "err100") + if cntMap == 1000 { + t.Errorf("early exit did not happen") + } + if cntReduce == 9+89+899 { + t.Errorf("early exit did not happen") + } + }) + + } + } +}