@@ -1431,6 +1431,8 @@ end
1431
1431
end
1432
1432
1433
1433
zip_iterator (a, b) = mapreduce (splat (* ), + , zip (a, b))
1434
+ nary_mapreduce (a, b) = mapreduce (* , + , a, b)
1435
+ nary_mapreduce_dims (a, b) = mapreduce (* , + , a, b; dims = 2 )
1434
1436
enumerate_iterator (a) = mapreduce (splat (* ), + , enumerate (a))
1435
1437
1436
1438
function nested_mapreduce_zip (x, y)
@@ -1483,6 +1485,20 @@ end
1483
1485
1484
1486
@test @jit (nested_mapreduce_hcat (x_ra, y_ra)) ≈ nested_mapreduce_hcat (x, y)
1485
1487
end
1488
+
1489
+ @testset " n-ary mapreduce" begin
1490
+ x = rand (Float32, 12 )
1491
+ y = rand (Float32, 12 )
1492
+ z = rand (Float32, 4 , 3 )
1493
+ w = rand (Float32, 4 , 3 )
1494
+
1495
+ rx, ry, rz, rw = Reactant. to_rarray .((x, y, z, w))
1496
+ @test @jit (nary_mapreduce (rx, ry)) ≈ nary_mapreduce (x, y)
1497
+ @test @jit (nary_mapreduce (rx, rz)) ≈ nary_mapreduce (x, z)
1498
+ @test @jit (nary_mapreduce (rz, rw)) ≈ nary_mapreduce (z, w)
1499
+ @test @jit (nary_mapreduce_dims (rz, rw)) ≈ nary_mapreduce_dims (z, w)
1500
+ @test @jit (nary_mapreduce (rz, rx)) ≈ nary_mapreduce (z, x)
1501
+ end
1486
1502
end
1487
1503
1488
1504
@testset " compilation cache" begin
0 commit comments