@@ -543,42 +543,55 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
543
543
return T (__default_init (Float16, op))
544
544
end
545
545
546
- function overloaded_mapreduce (
547
- @nospecialize (f), @nospecialize (op), @nospecialize (A); dims= :, init= nothing
548
- )
549
- res = unwrapped_broadcast (f, A)
550
- # This means we are unable to use the optimized dispatches. For now we will
551
- # unroll the mapreduce.
552
- if typeof (res) == typeof (A)
553
- @assert dims == Colon () " dims not supported for mapreduce currently."
554
- return foldl (op, res; init)
555
- end
556
- return overloaded_mapreduce (identity, op, res; dims= :, init)
557
- end
546
+ _maybe_materialize_traced_array (x:: AbstractArray ) = materialize_traced_array (x)
547
+ _maybe_materialize_traced_array (x) = x
548
+
549
+ _change_traced_type (:: Type{T} , x:: AnyTracedRArray ) where {T} = T .(x)
550
+ _change_traced_type (:: Type{T} , x) where {T} = x
558
551
559
552
function overloaded_mapreduce (
560
553
@nospecialize (f),
561
554
@nospecialize (op),
562
- @nospecialize (A:: AnyTracedRArray{T,N} );
555
+ @nospecialize (A... );
563
556
dims= :,
564
557
init= nothing ,
565
- ) where {T,N}
566
- A = materialize_traced_array (A)
558
+ )
559
+ if all (x -> ! (x isa AnyTracedRArray), A)
560
+ res = f .(A... )
561
+ # This means we are unable to use the optimized dispatches. For now we will
562
+ # unroll the mapreduce.
563
+ if typeof (res) == typeof (A[1 ])
564
+ @assert dims == Colon () " dims not supported for mapreduce currently."
565
+ return foldl (op, res; init)
566
+ end
567
+ return overloaded_mapreduce (identity, op, res; dims= :, init)
568
+ end
569
+
570
+ A = _maybe_materialize_traced_array .(A)
571
+ mapped_shape = allequal (map (size, A)) ? size (A[1 ]) : (minimum (length, A),)
572
+ N = length (mapped_shape)
573
+ A = map (x -> reshape (x, length (x)), A)
567
574
568
575
original_dims = dims
569
576
dims isa Int && (dims = Int64[dims])
570
577
dims isa Colon && (dims = collect (Int64, 1 : N))
571
578
dims isa AbstractVector{<: Integer } || (dims = collect (Int64, dims))
572
579
573
- op_in_T = unwrapped_eltype (Core. Compiler. return_type (f, Tuple{T} ))
580
+ op_in_T = unwrapped_eltype (Core. Compiler. return_type (f, Broadcast . eltypes (A) ))
574
581
reduce_init = __default_init (op_in_T, op)
575
582
if unwrapped_eltype (typeof (reduce_init)) != op_in_T
576
583
op_in_T = typeof (reduce_init)
577
- A = typeof (reduce_init).( A)
584
+ A = _change_traced_type .( typeof (reduce_init), A)
578
585
end
579
586
reduce_init = TracedUtils. promote_to (TracedRNumber{op_in_T}, reduce_init)
580
587
581
- reduce_input = materialize_traced_array (broadcast (f, A))
588
+ res = f .(A... )
589
+ if ! (res isa AnyTracedRArray)
590
+ @assert dims == Colon () " dims not supported for mapreduce currently."
591
+ return foldl (op, res; init)
592
+ end
593
+
594
+ reduce_input = materialize_traced_array (res)
582
595
583
596
res = Ops. reduce (reduce_input, reduce_init, dims, op)
584
597
@@ -591,7 +604,7 @@ function overloaded_mapreduce(
591
604
if res isa TracedRNumber
592
605
res = TracedRArray {unwrapped_eltype(res),0} ((), res. mlir_data, ())
593
606
end
594
- return Ops. reshape (res, [ifelse (i in dims, 1 , size (A, i) ) for i in 1 : N])
607
+ return Ops. reshape (res, [ifelse (i in dims, 1 , mapped_shape[i] ) for i in 1 : N])
595
608
end
596
609
597
610
function Base. mapreducedim! (
0 commit comments