@@ -587,42 +587,55 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
587
587
return T (__default_init (Float16, op))
588
588
end
589
589
590
- function overloaded_mapreduce (
591
- @nospecialize (f), @nospecialize (op), @nospecialize (A); dims= :, init= nothing
592
- )
593
- res = unwrapped_broadcast (f, A)
594
- # This means we are unable to use the optimized dispatches. For now we will
595
- # unroll the mapreduce.
596
- if typeof (res) == typeof (A)
597
- @assert dims == Colon () " dims not supported for mapreduce currently."
598
- return foldl (op, res; init)
599
- end
600
- return overloaded_mapreduce (identity, op, res; dims= :, init)
601
- end
590
+ _maybe_materialize_traced_array (x:: AbstractArray ) = materialize_traced_array (x)
591
+ _maybe_materialize_traced_array (x) = x
592
+
593
+ _change_traced_type (:: Type{T} , x:: AnyTracedRArray ) where {T} = T .(x)
594
+ _change_traced_type (:: Type{T} , x) where {T} = x
602
595
603
596
function overloaded_mapreduce (
604
597
@nospecialize (f),
605
598
@nospecialize (op),
606
- @nospecialize (A:: AnyTracedRArray{T,N} );
599
+ @nospecialize (A... );
607
600
dims= :,
608
601
init= nothing ,
609
- ) where {T,N}
610
- A = materialize_traced_array (A)
602
+ )
603
+ if all (x -> ! (x isa AnyTracedRArray), A)
604
+ res = unwrapped_broadcast (f, A... )
605
+ # This means we are unable to use the optimized dispatches. For now we will
606
+ # unroll the mapreduce.
607
+ if typeof (res) == typeof (A[1 ])
608
+ @assert dims == Colon () " dims not supported for mapreduce currently."
609
+ return foldl (op, res; init)
610
+ end
611
+ return overloaded_mapreduce (identity, op, res; dims= :, init)
612
+ end
613
+
614
+ A = _maybe_materialize_traced_array .(A)
615
+ mapped_shape = allequal (map (size, A)) ? size (A[1 ]) : (minimum (length, A),)
616
+ N = length (mapped_shape)
617
+ A = map (x -> reshape (x, length (x)), A)
611
618
612
619
original_dims = dims
613
620
dims isa Int && (dims = Int64[dims])
614
621
dims isa Colon && (dims = collect (Int64, 1 : N))
615
622
dims isa Vector{Int64} || (dims = collect (Int64, dims))
616
623
617
- op_in_T = unwrapped_eltype (Core. Compiler. return_type (f, Tuple{T} ))
624
+ op_in_T = unwrapped_eltype (Core. Compiler. return_type (f, Broadcast . eltypes (A) ))
618
625
reduce_init = __default_init (op_in_T, op)
619
626
if unwrapped_eltype (typeof (reduce_init)) != op_in_T
620
627
op_in_T = typeof (reduce_init)
621
- A = typeof (reduce_init).( A)
628
+ A = _change_traced_type .( typeof (reduce_init), A)
622
629
end
623
630
reduce_init = TracedUtils. promote_to (TracedRNumber{op_in_T}, reduce_init)
624
631
625
- reduce_input = materialize_traced_array (broadcast (f, A))
632
+ res = reshape (f .(A... ), mapped_shape)
633
+ if ! (res isa AnyTracedRArray)
634
+ @assert dims == Colon () " dims not supported for mapreduce currently."
635
+ return foldl (op, res; init)
636
+ end
637
+
638
+ reduce_input = materialize_traced_array (res)
626
639
627
640
res = @opcall reduce (reduce_input, reduce_init, dims, op)
628
641
@@ -635,7 +648,7 @@ function overloaded_mapreduce(
635
648
if res isa TracedRNumber
636
649
res = TracedRArray {unwrapped_eltype(res),0} ((), res. mlir_data, ())
637
650
end
638
- return @opcall reshape (res, [ifelse (i in dims, 1 , size (A, i) ) for i in 1 : N])
651
+ return @opcall reshape (res, [ifelse (i in dims, 1 , mapped_shape[i] ) for i in 1 : N])
639
652
end
640
653
641
654
function Base. mapreducedim! (
@@ -789,7 +802,6 @@ function _copyto!(dest::Array{<:TracedRNumber}, bc::Broadcasted)
789
802
bc = Broadcast. preprocess (dest, bc)
790
803
791
804
args = (TracedUtils. broadcast_to_size (Base. materialize (a), size (bc)) for a in bc. args)
792
-
793
805
res = TracedUtils. elem_apply (bc. f, args... )
794
806
for I in 1 : length (dest)
795
807
dest[I] = Reactant. @allowscalar res[I]
@@ -1460,25 +1472,46 @@ end
1460
1472
1461
1473
(fn:: BroadcastIterator )(args... ) = Reactant. call_with_reactant (fn. f, (args... ,))
1462
1474
1463
- function unwrapped_broadcast (f :: F , x:: Base.Iterators.Zip ) where {F}
1475
+ function _canonicalize_iter ( x:: Base.Iterators.Zip )
1464
1476
min_length = Base. inferencebarrier (minimum)(length, x. is)
1465
- itrs = [length (itr) > min_length ? itr[1 : min_length] : itr for itr in x. is]
1466
- if any (Base. Fix2 (isa, AnyTracedRArray), itrs)
1467
- return (BroadcastIterator (f)). (itrs... )
1477
+ iters = last .(_canonicalize_iter .(x. is))
1478
+ itrs = [Base. Fix2 (getindex, i).(iters) for i in 1 : min_length]
1479
+ any_is_anytraced = any (Base. Fix2 (isa, AnyTracedRArray), x. is)
1480
+ return min_length, any_is_anytraced, itrs
1481
+ end
1482
+
1483
+ function _canonicalize_iter (x:: Base.Iterators.Enumerate )
1484
+ return _canonicalize_iter (zip (eachindex (x), x))
1485
+ end
1486
+
1487
+ _canonicalize_iter (x) = length (x), x isa AnyTracedRArray, x
1488
+
1489
+ function unwrapped_broadcast (f:: F , xs... ) where {F}
1490
+ len, any_is_anytraced, itrs = if length (xs) == 1
1491
+ _canonicalize_iter (xs[1 ])
1468
1492
else
1469
- fn = BroadcastIterator (f)
1470
- return [fn (Base. Fix2 (getindex, i).(itrs). .. ) for i in 1 : min_length]
1493
+ _canonicalize_iter (zip (xs... ))
1494
+ end
1495
+ fn = BroadcastIterator (f)
1496
+ if any_is_anytraced
1497
+ return splat (f).(itrs)
1498
+ else
1499
+ return [fn (x... ) for x in itrs]
1471
1500
end
1472
1501
end
1473
1502
1474
- function unwrapped_broadcast (f:: F , x:: Base.Iterators.Enumerate ) where {F}
1475
- if x. itr isa AnyTracedRArray
1476
- return (BroadcastIterator (f)). (1 : length (x. itr), x. itr)
1503
+ function unwrapped_broadcast (f:: F , xs:: Union{Base.Iterators.Zip, Base.Iterators.Enumerate} ) where {F}
1504
+ len, any_is_anytraced, itrs = _canonicalize_iter (xs)
1505
+ fn = BroadcastIterator (f)
1506
+ if any_is_anytraced
1507
+ return splat (f).(itrs)
1477
1508
else
1478
- return [f ((i, x . itr[i])) for i in 1 : length (x . itr) ]
1509
+ return [fn (x ... ) for x in itrs ]
1479
1510
end
1480
1511
end
1481
1512
1482
- unwrapped_broadcast (f:: F , xs:: Vector ) where {F} = [f (x) for x in xs]
1513
+ function unwrapped_broadcast (f:: F , xs) where {F}
1514
+ [f (x) for x in xs]
1515
+ end
1483
1516
1484
1517
end
0 commit comments