Skip to content

Commit 8a5189b

Browse files
test: add tests for multi-array mapreduce
1 parent 43b3e63 commit 8a5189b

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

test/basic.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,8 @@ end
14311431
end
14321432

14331433
zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
1434+
nary_mapreduce(a, b) = mapreduce(*, +, a, b)
1435+
nary_mapreduce_dims(a, b) = mapreduce(*, +, a, b; dims = 2)
14341436
enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))
14351437

14361438
function nested_mapreduce_zip(x, y)
@@ -1483,6 +1485,20 @@ end
14831485

14841486
@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) nested_mapreduce_hcat(x, y)
14851487
end
1488+
1489+
@testset "n-ary mapreduce" begin
1490+
x = rand(Float32, 12)
1491+
y = rand(Float32, 12)
1492+
z = rand(Float32, 4, 3)
1493+
w = rand(Float32, 4, 3)
1494+
1495+
rx, ry, rz, rw = Reactant.to_rarray.((x, y, z, w))
1496+
@test @jit(nary_mapreduce(rx, ry)) nary_mapreduce(x, y)
1497+
@test @jit(nary_mapreduce(rx, rz)) nary_mapreduce(x, z)
1498+
@test @jit(nary_mapreduce(rz, rw)) nary_mapreduce(z, w)
1499+
@test @jit(nary_mapreduce_dims(rz, rw)) nary_mapreduce_dims(z, w)
1500+
@test @jit(nary_mapreduce(rz, rx)) nary_mapreduce(z, x)
1501+
end
14861502
end
14871503

14881504
@testset "compilation cache" begin

0 commit comments

Comments
 (0)