Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ end
end

@reactant_overlay @noinline function Base.mapreduce(
f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs...
f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}...; kwargs...
)
if use_overlayed_version(A)
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A...; kwargs...)
else
return Base.inferencebarrier(Base.mapreduce)(f, op, A; kwargs...)
return Base.inferencebarrier(Base.mapreduce)(f, op, A...; kwargs...)
end
end

Expand Down
95 changes: 64 additions & 31 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -587,42 +587,55 @@ function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
return T(__default_init(Float16, op))
end

function overloaded_mapreduce(
@nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
)
res = unwrapped_broadcast(f, A)
# This means we are unable to use the optimized dispatches. For now we will
# unroll the mapreduce.
if typeof(res) == typeof(A)
@assert dims == Colon() "dims not supported for mapreduce currently."
return foldl(op, res; init)
end
return overloaded_mapreduce(identity, op, res; dims=:, init)
end
_maybe_materialize_traced_array(x::AbstractArray) = materialize_traced_array(x)
_maybe_materialize_traced_array(x) = x

_change_traced_type(::Type{T}, x::AnyTracedRArray) where {T} = T.(x)
_change_traced_type(::Type{T}, x) where {T} = x

function overloaded_mapreduce(
@nospecialize(f),
@nospecialize(op),
@nospecialize(A::AnyTracedRArray{T,N});
@nospecialize(A...);
dims=:,
init=nothing,
) where {T,N}
A = materialize_traced_array(A)
)
if all(x -> !(x isa AnyTracedRArray), A)
res = unwrapped_broadcast(f, A...)
# This means we are unable to use the optimized dispatches. For now we will
# unroll the mapreduce.
if typeof(res) == typeof(A[1])
@assert dims == Colon() "dims not supported for mapreduce currently."
return foldl(op, res; init)
end
return overloaded_mapreduce(identity, op, res; dims=:, init)
end

A = _maybe_materialize_traced_array.(A)
mapped_shape = allequal(map(size, A)) ? size(A[1]) : (minimum(length, A),)
N = length(mapped_shape)
A = map(x -> reshape(x, length(x)), A)

original_dims = dims
dims isa Int && (dims = Int64[dims])
dims isa Colon && (dims = collect(Int64, 1:N))
dims isa Vector{Int64} || (dims = collect(Int64, dims))

op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Broadcast.eltypes(A)))
reduce_init = __default_init(op_in_T, op)
if unwrapped_eltype(typeof(reduce_init)) != op_in_T
op_in_T = typeof(reduce_init)
A = typeof(reduce_init).(A)
A = _change_traced_type.(typeof(reduce_init), A)
end
reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init)

reduce_input = materialize_traced_array(broadcast(f, A))
res = reshape(f.(A...), mapped_shape)
if !(res isa AnyTracedRArray)
@assert dims == Colon() "dims not supported for mapreduce currently."
return foldl(op, res; init)
end

reduce_input = materialize_traced_array(res)

res = @opcall reduce(reduce_input, reduce_init, dims, op)

Expand All @@ -635,7 +648,7 @@ function overloaded_mapreduce(
if res isa TracedRNumber
res = TracedRArray{unwrapped_eltype(res),0}((), res.mlir_data, ())
end
return @opcall reshape(res, [ifelse(i in dims, 1, size(A, i)) for i in 1:N])
return @opcall reshape(res, [ifelse(i in dims, 1, mapped_shape[i]) for i in 1:N])
end

function Base.mapreducedim!(
Expand Down Expand Up @@ -789,7 +802,6 @@ function _copyto!(dest::Array{<:TracedRNumber}, bc::Broadcasted)
bc = Broadcast.preprocess(dest, bc)

args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args)

res = TracedUtils.elem_apply(bc.f, args...)
for I in 1:length(dest)
dest[I] = Reactant.@allowscalar res[I]
Expand Down Expand Up @@ -1460,25 +1472,46 @@ end

(fn::BroadcastIterator)(args...) = Reactant.call_with_reactant(fn.f, (args...,))

function unwrapped_broadcast(f::F, x::Base.Iterators.Zip) where {F}
function _canonicalize_iter(x::Base.Iterators.Zip)
min_length = Base.inferencebarrier(minimum)(length, x.is)
itrs = [length(itr) > min_length ? itr[1:min_length] : itr for itr in x.is]
if any(Base.Fix2(isa, AnyTracedRArray), itrs)
return (BroadcastIterator(f)).(itrs...)
iters = last.(_canonicalize_iter.(x.is))
itrs = [Base.Fix2(getindex, i).(iters) for i in 1:min_length]
any_is_anytraced = any(Base.Fix2(isa, AnyTracedRArray), x.is)
return min_length, any_is_anytraced, itrs
end

function _canonicalize_iter(x::Base.Iterators.Enumerate)
return _canonicalize_iter(zip(eachindex(x), x))
end

_canonicalize_iter(x) = length(x), x isa AnyTracedRArray, x

function unwrapped_broadcast(f::F, xs...) where {F}
len, any_is_anytraced, itrs = if length(xs) == 1
_canonicalize_iter(xs[1])
else
fn = BroadcastIterator(f)
return [fn(Base.Fix2(getindex, i).(itrs)...) for i in 1:min_length]
_canonicalize_iter(zip(xs...))
end
fn = BroadcastIterator(f)
if any_is_anytraced
return splat(f).(itrs)
else
return [fn(x...) for x in itrs]
end
end

function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
if x.itr isa AnyTracedRArray
return (BroadcastIterator(f)).(1:length(x.itr), x.itr)
function unwrapped_broadcast(f::F, xs::Union{Base.Iterators.Zip, Base.Iterators.Enumerate}) where {F}
len, any_is_anytraced, itrs = _canonicalize_iter(xs)
fn = BroadcastIterator(f)
if any_is_anytraced
return splat(f).(itrs)
else
return [f((i, x.itr[i])) for i in 1:length(x.itr)]
return [fn(x...) for x in itrs]
end
end

unwrapped_broadcast(f::F, xs::Vector) where {F} = [f(x) for x in xs]
function unwrapped_broadcast(f::F, xs) where {F}
[f(x) for x in xs]
end

end
16 changes: 16 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1431,6 +1431,8 @@ end
end

zip_iterator(a, b) = mapreduce(splat(*), +, zip(a, b))
nary_mapreduce(a, b) = mapreduce(*, +, a, b)
nary_mapreduce_dims(a, b) = mapreduce(*, +, a, b; dims = 2)
enumerate_iterator(a) = mapreduce(splat(*), +, enumerate(a))

function nested_mapreduce_zip(x, y)
Expand Down Expand Up @@ -1483,6 +1485,20 @@ end

@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) ≈ nested_mapreduce_hcat(x, y)
end

@testset "n-ary mapreduce" begin
x = rand(Float32, 12)
y = rand(Float32, 12)
z = rand(Float32, 4, 3)
w = rand(Float32, 4, 3)

rx, ry, rz, rw = Reactant.to_rarray.((x, y, z, w))
@test @jit(nary_mapreduce(rx, ry)) ≈ nary_mapreduce(x, y)
@test @jit(nary_mapreduce(rx, rz)) ≈ nary_mapreduce(x, z)
@test @jit(nary_mapreduce(rz, rw)) ≈ nary_mapreduce(z, w)
@test @jit(nary_mapreduce_dims(rz, rw)) ≈ nary_mapreduce_dims(z, w)
@test @jit(nary_mapreduce(rz, rx)) ≈ nary_mapreduce(z, x)
end
end

@testset "compilation cache" begin
Expand Down