Skip to content

Commit 2216bdb

Browse files
feat: handle multiple buffers in overloaded_mapreduce
1 parent 8e30330 commit 2216bdb

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

src/TracedRArray.jl

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -543,42 +543,55 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
543543
return T(__default_init(Float16, op))
544544
end
545545

546-
function overloaded_mapreduce(
547-
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
548-
)
549-
res = unwrapped_broadcast(f, A)
550-
# This means we are unable to use the optimized dispatches. For now we will
551-
# unroll the mapreduce.
552-
if typeof(res) == typeof(A)
553-
@assert dims == Colon() "dims not supported for mapreduce currently."
554-
return foldl(op, res; init)
555-
end
556-
return overloaded_mapreduce(identity, op, res; dims=:, init)
557-
end
546+
_maybe_materialize_traced_array(x::AbstractArray) = materialize_traced_array(x)
547+
_maybe_materialize_traced_array(x) = x
548+
549+
_change_traced_type(::Type{T}, x::AnyTracedRArray) where {T} = T.(x)
550+
_change_traced_type(::Type{T}, x) where {T} = x
558551

559552
function overloaded_mapreduce(
560553
@nospecialize(f),
561554
@nospecialize(op),
562-
@nospecialize(A::AnyTracedRArray{T,N});
555+
@nospecialize(A...);
563556
dims=:,
564557
init=nothing,
565-
) where {T,N}
566-
A = materialize_traced_array(A)
558+
)
559+
if all(x -> !(x isa AnyTracedRArray), A)
560+
res = f.(A...)
561+
# This means we are unable to use the optimized dispatches. For now we will
562+
# unroll the mapreduce.
563+
if typeof(res) == typeof(A[1])
564+
@assert dims == Colon() "dims not supported for mapreduce currently."
565+
return foldl(op, res; init)
566+
end
567+
return overloaded_mapreduce(identity, op, res; dims=:, init)
568+
end
569+
570+
A = _maybe_materialize_traced_array.(A)
571+
mapped_shape = allequal(map(size, A)) ? size(A[1]) : (minimum(length, A),)
572+
N = length(mapped_shape)
573+
A = map(x -> reshape(x, length(x)), A)
567574

568575
original_dims = dims
569576
dims isa Int && (dims = Int64[dims])
570577
dims isa Colon && (dims = collect(Int64, 1:N))
571578
dims isa AbstractVector{<:Integer} || (dims = collect(Int64, dims))
572579

573-
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
580+
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Broadcast.eltypes(A)))
574581
reduce_init = __default_init(op_in_T, op)
575582
if unwrapped_eltype(typeof(reduce_init)) != op_in_T
576583
op_in_T = typeof(reduce_init)
577-
A = typeof(reduce_init).(A)
584+
A = _change_traced_type.(typeof(reduce_init), A)
578585
end
579586
reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init)
580587

581-
reduce_input = materialize_traced_array(broadcast(f, A))
588+
res = f.(A...)
589+
if !(res isa AnyTracedRArray)
590+
@assert dims == Colon() "dims not supported for mapreduce currently."
591+
return foldl(op, res; init)
592+
end
593+
594+
reduce_input = materialize_traced_array(res)
582595

583596
res = Ops.reduce(reduce_input, reduce_init, dims, op)
584597

@@ -591,7 +604,7 @@ function overloaded_mapreduce(
591604
if res isa TracedRNumber
592605
res = TracedRArray{unwrapped_eltype(res),0}((), res.mlir_data, ())
593606
end
594-
return Ops.reshape(res, [ifelse(i in dims, 1, size(A, i)) for i in 1:N])
607+
return Ops.reshape(res, [ifelse(i in dims, 1, mapped_shape[i]) for i in 1:N])
595608
end
596609

597610
function Base.mapreducedim!(

0 commit comments

Comments
 (0)