Skip to content

Commit 43b3e63

Browse files
feat: handle multiple buffers in overloaded_mapreduce
1 parent 7855f19 commit 43b3e63

File tree

1 file changed

+64
-31
lines changed

1 file changed

+64
-31
lines changed

src/TracedRArray.jl

Lines changed: 64 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -587,42 +587,55 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
587587
return T(__default_init(Float16, op))
588588
end
589589

590-
function overloaded_mapreduce(
591-
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
592-
)
593-
res = unwrapped_broadcast(f, A)
594-
# This means we are unable to use the optimized dispatches. For now we will
595-
# unroll the mapreduce.
596-
if typeof(res) == typeof(A)
597-
@assert dims == Colon() "dims not supported for mapreduce currently."
598-
return foldl(op, res; init)
599-
end
600-
return overloaded_mapreduce(identity, op, res; dims=:, init)
601-
end
590+
# Array arguments are handed to `materialize_traced_array`, and whatever that call
# returns is passed back to the caller.
function _maybe_materialize_traced_array(x::AbstractArray)
    return materialize_traced_array(x)
end
591+
# Generic fallback: the argument is returned untouched.
function _maybe_materialize_traced_array(x)
    return x
end
592+
593+
# For an `AnyTracedRArray`, apply `T` elementwise (a broadcast of the type over `x`)
# and return the broadcast result.
function _change_traced_type(::Type{T}, x::AnyTracedRArray) where {T}
    return broadcast(T, x)
end
594+
# Generic fallback: the requested type is ignored and `x` is returned unchanged.
function _change_traced_type(::Type{T}, x) where {T}
    return x
end
602595

603596
function overloaded_mapreduce(
604597
@nospecialize(f),
605598
@nospecialize(op),
606-
@nospecialize(A::AnyTracedRArray{T,N});
599+
@nospecialize(A...);
607600
dims=:,
608601
init=nothing,
609-
) where {T,N}
610-
A = materialize_traced_array(A)
602+
)
603+
if all(x -> !(x isa AnyTracedRArray), A)
604+
res = unwrapped_broadcast(f, A...)
605+
# This means we are unable to use the optimized dispatches. For now we will
606+
# unroll the mapreduce.
607+
if typeof(res) == typeof(A[1])
608+
@assert dims == Colon() "dims not supported for mapreduce currently."
609+
return foldl(op, res; init)
610+
end
611+
return overloaded_mapreduce(identity, op, res; dims=:, init)
612+
end
613+
614+
A = _maybe_materialize_traced_array.(A)
615+
mapped_shape = allequal(map(size, A)) ? size(A[1]) : (minimum(length, A),)
616+
N = length(mapped_shape)
617+
A = map(x -> reshape(x, length(x)), A)
611618

612619
original_dims = dims
613620
dims isa Int && (dims = Int64[dims])
614621
dims isa Colon && (dims = collect(Int64, 1:N))
615622
dims isa Vector{Int64} || (dims = collect(Int64, dims))
616623

617-
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
624+
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Broadcast.eltypes(A)))
618625
reduce_init = __default_init(op_in_T, op)
619626
if unwrapped_eltype(typeof(reduce_init)) != op_in_T
620627
op_in_T = typeof(reduce_init)
621-
A = typeof(reduce_init).(A)
628+
A = _change_traced_type.(typeof(reduce_init), A)
622629
end
623630
reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init)
624631

625-
reduce_input = materialize_traced_array(broadcast(f, A))
632+
res = reshape(f.(A...), mapped_shape)
633+
if !(res isa AnyTracedRArray)
634+
@assert dims == Colon() "dims not supported for mapreduce currently."
635+
return foldl(op, res; init)
636+
end
637+
638+
reduce_input = materialize_traced_array(res)
626639

627640
res = @opcall reduce(reduce_input, reduce_init, dims, op)
628641

@@ -635,7 +648,7 @@ function overloaded_mapreduce(
635648
if res isa TracedRNumber
636649
res = TracedRArray{unwrapped_eltype(res),0}((), res.mlir_data, ())
637650
end
638-
return @opcall reshape(res, [ifelse(i in dims, 1, size(A, i)) for i in 1:N])
651+
return @opcall reshape(res, [ifelse(i in dims, 1, mapped_shape[i]) for i in 1:N])
639652
end
640653

641654
function Base.mapreducedim!(
@@ -789,7 +802,6 @@ function _copyto!(dest::Array{<:TracedRNumber}, bc::Broadcasted)
789802
bc = Broadcast.preprocess(dest, bc)
790803

791804
args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)
792-
793805
res = TracedUtils.elem_apply(bc.f, args...)
794806
for I in 1:length(dest)
795807
dest[I] = Reactant.@allowscalar res[I]
@@ -1460,25 +1472,46 @@ end
14601472

14611473
(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))
14621474

1463-
function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
1475+
# Canonicalize a `zip(...)` iterator.
#
# Returns a 3-tuple:
#   1. the smallest `length` among the zipped components,
#   2. whether any *direct* component of the zip is an `AnyTracedRArray`,
#   3. a vector with one entry per position `1:shortest`, each entry being the tuple
#      of the elements found at that position in every canonicalized component.
function _canonicalize_iter(x::Base.Iterators.Zip)
    components = x.is
    # `minimum` is called through `Base.inferencebarrier`, exactly as before.
    shortest = Base.inferencebarrier(minimum)(length, components)
    # Each component is canonicalized on its own; only the last field of the result
    # (the canonical iterable) is kept.
    canonical = map(c -> last(_canonicalize_iter(c)), components)
    rows = [map(c -> getindex(c, i), canonical) for i in 1:shortest]
    has_traced = any(c -> c isa AnyTracedRArray, components)
    return (shortest, has_traced, rows)
end
1482+
1483+
# Canonicalize an `enumerate(itr)` wrapper by rewriting it as a zip of the 1-based
# position counter with the wrapped iterator, then deferring to the `Zip` method.
#
# The wrapped iterator is reached through `x.itr`. Zipping `x` itself would pass the
# same `Enumerate` back to this method when the zip's components are canonicalized,
# so the call would never terminate.
function _canonicalize_iter(x::Base.Iterators.Enumerate)
    wrapped = x.itr
    # `enumerate` counts from 1 regardless of the wrapped iterator's own indices.
    return _canonicalize_iter(zip(1:length(wrapped), wrapped))
end
1486+
1487+
# Generic fallback. Returns, in order: `length(x)`, whether `x` is an
# `AnyTracedRArray`, and `x` itself.
function _canonicalize_iter(x)
    n = length(x)
    is_traced = x isa AnyTracedRArray
    return (n, is_traced, x)
end
1488+
1489+
# Apply `f` across one or more iterables.
#
# A single iterable is canonicalized directly; several iterables are zipped first.
# When the canonicalization reports an `AnyTracedRArray`, `splat(f)` is broadcast
# over the canonical items. Otherwise each item is splatted into
# `BroadcastIterator(f)` and the results are collected into a vector.
function unwrapped_broadcast(f::F, xs...) where {F}
    source = length(xs) == 1 ? xs[1] : zip(xs...)
    # The reported length (first field) is not needed here.
    _, has_traced, items = _canonicalize_iter(source)
    fn = BroadcastIterator(f)
    has_traced && return splat(f).(items)
    return [fn(item...) for item in items]
end
14731502

1474-
function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
1475-
if x.itr isa AnyTracedRArray
1476-
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
1503+
# Apply `f` across a `zip(...)` or `enumerate(...)` iterator.
#
# The iterator is canonicalized first. When that reports an `AnyTracedRArray`,
# `splat(f)` is broadcast over the canonical items; otherwise every item is splatted
# into `BroadcastIterator(f)` and the results are gathered into a vector.
function unwrapped_broadcast(
    f::F, xs::Union{Base.Iterators.Zip,Base.Iterators.Enumerate}
) where {F}
    # The first field (the length) is unused in this method.
    _, has_traced, items = _canonicalize_iter(xs)
    fn = BroadcastIterator(f)
    if has_traced
        return splat(f).(items)
    end
    return [fn(item...) for item in items]
end
14811512

1482-
unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
1513+
# Generic single-iterable fallback: call `f` on every element of `xs` and collect
# the results into an array.
function unwrapped_broadcast(f::F, xs) where {F}
    return collect(f(x) for x in xs)
end
14831516

14841517
end

0 commit comments

Comments
 (0)