Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 17 additions & 281 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module ReactantCUDAExt
using Reactant: Reactant, TracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber
using Reactant.Compiler: raising
using Reactant.Ops: @opcall
import Reactant: CuTracedArray, CuTracedRNumber
using Reactant.CuTracedOverloads: alignment

using Adapt: Adapt, adapt
using CUDA: CUDA, CuDim, DenseCuArray, unsafe_cached_load
Expand All @@ -17,228 +19,26 @@ const KA = KernelAbstractions

Reactant.is_extension_loaded(::Val{:CUDA}) = true

struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N}
ptr::Core.LLVMPtr{T,A}

function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
push!(gc_vec, xs)
@assert gc_vec[end] === xs
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
return new(ptr)
end
function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size}
gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
push!(gc_vec, xs)
@assert gc_vec[end] === xs
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
return CuTracedArray{T,N,A,Size}(ptr)
end

# Code touching a CuTracedArray must run under Reactant's overlayed method table.
Reactant.use_overlayed_version(::CuTracedArray) = true

struct CuTracedRNumber{T,A} <: Number
ptr::Core.LLVMPtr{T,A}

function CuTracedRNumber{T,A}(xs::TracedRNumber) where {T,A}
gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
push!(gc_vec, xs)
@assert gc_vec[end] === xs
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
return new(ptr)
end
function CuTracedRNumber{T,A}(ptr::Core.LLVMPtr{T,A}) where {T,A}
return new(ptr)
end
function CuTracedRNumber{T,A}(xs::TracedRNumber) where {T,A}
gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()]
push!(gc_vec, xs)
@assert gc_vec[end] === xs
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))
return CuTracedRNumber{T,A}(ptr)
end

# Code touching a CuTracedRNumber must run under Reactant's overlayed method table.
Reactant.use_overlayed_version(::CuTracedRNumber) = true

# Mark CuTracedRNumber as a traced-number type for Reactant's tracing machinery.
Base.@nospecializeinfer Reactant.is_traced_number(
    @nospecialize(T::Type{<:CuTracedRNumber})
) = true
# Strip the traced wrapper: the plain element type is the first type parameter.
Reactant.unwrapped_eltype(::Type{<:CuTracedRNumber{T}}) where {T} = T

# Constructing from a plain number funnels through `convert` (the
# `llvmcall`-based `convert` methods below perform the device-side store).
@inline CuTracedRNumber{T,A}(val::Number) where {T,A} = convert(CuTracedRNumber{T,A}, val)

# Dereference: load the scalar stored behind the traced pointer, passing the
# element type's alignment so the backend can emit an aligned load.
function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A}
    align = alignment(RN)
    return @inbounds unsafe_load(RN.ptr, 1, Val(align))
end

# Numeric conversion goes through a load of the underlying value.
Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number} = convert(T, getindex(RN))

# Binary operations on traced scalars: load the wrapped value(s) via `a[]`
# and dispatch to the plain-number method. Mixed traced/plain methods unwrap
# only the traced operand, so ordinary numeric promotion applies afterwards.
for jlop in (
    :(Base.min),
    :(Base.mod),
    :(Base.max),
    :(Base.:+),
    :(Base.:-),
    :(Base.:*),
    :(Base.:/),
    :(Base.:^),
    :(Base.rem),
    :(Base.isless),
    :(Base.:(==)),
    :(Base.:(!=)),
)
    @eval begin
        @inline $jlop(a::CuTracedRNumber, b::CuTracedRNumber) = $jlop(a[], b[])
        @inline $jlop(a::CuTracedRNumber{T,A}, b::Number) where {T,A} = $jlop(a[], b)
        @inline $jlop(a::Number, b::CuTracedRNumber{T,A}) where {T,A} = $jlop(a, b[])
    end
end

# `ifelse` with traced operands: unwrap every CuTracedRNumber argument
# (condition and/or branches) before deferring to the plain `ifelse`.
@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = ifelse(cond, a, b[])
@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = ifelse(cond, a[], b)
@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) =
    ifelse(cond, a[], b[])
@inline Base.ifelse(cond::CuTracedRNumber, a, b) = ifelse(cond[], a, b)
@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = ifelse(cond[], a[], b)
@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = ifelse(cond[], a, b[])
@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) =
    ifelse(cond[], a[], b[])

# Integer power: aggressive constant propagation lets literal exponents
# specialize on the loaded value.
Base.@constprop :aggressive @inline Base.:^(
    a::CuTracedRNumber{T,A}, b::Integer
) where {T,A} = ^(a[], b)

# Truncating conversion without a domain check, applied to the loaded value.
@inline Base.unsafe_trunc(::Type{T}, a::CuTracedRNumber) where {T} =
    Base.unsafe_trunc(T, a[])

# Unary operations and float predicates: load, then apply.
for jlop in (:(Base.:+), :(Base.:-), :(Base.isnan), :(Base.isfinite), :(Base.isinf))
    @eval begin
        @inline $jlop(a::CuTracedRNumber) = $jlop(a[])
    end
end

# Range construction from a traced integer length loads the value first.
Base.OneTo(x::CuTracedRNumber{<:Integer}) = Base.OneTo(x[])

# `Base.unchecked_oneto` only exists on newer Julia versions, hence the guard.
@static if isdefined(Base, :unchecked_oneto)
    function Base.unchecked_oneto(x::CuTracedRNumber{<:Integer})
        return Base.unchecked_oneto(x[])
    end
end

# Materialize a host `Float64` as a device-visible traced number. The
# `llvmcall` stores the value to a stack slot (release-ordered atomic store),
# then addrspace-casts the slot to a global (addrspace 1) pointer, which is
# reinterpreted as the traced number's backing pointer.
# NOTE(review): the returned pointer refers to an `alloca` slot — presumably
# the raising pass rewrites this before the slot's lifetime ends; confirm.
@inline function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number)
    return CT(
        Base.reinterpret(
            Core.LLVMPtr{Float64,1},
            Base.llvmcall(
                (
                    """define i8 addrspace(1)* @entry(double %d) alwaysinline {
%a = alloca double
store atomic double %d, double* %a release, align 8
%bc = bitcast double* %a to i8*
%ac = addrspacecast i8* %bc to i8 addrspace(1)*
ret i8 addrspace(1)* %ac
}
""",
                    "entry",
                ),
                Core.LLVMPtr{UInt8,1},
                Tuple{Float64},
                convert(Float64, x),
            ),
        ),
    )
end

# Float32 analogue of the Float64 `convert` above: store to a stack slot with
# a release-ordered atomic store, addrspace-cast to global, and wrap the
# resulting pointer as the traced number's backing pointer.
@inline function Base.convert(CT::Type{CuTracedRNumber{Float32,1}}, x::Number)
    return CT(
        Base.reinterpret(
            Core.LLVMPtr{Float32,1},
            Base.llvmcall(
                (
                    """define i8 addrspace(1)* @entry(float %d) alwaysinline {
%a = alloca float
store atomic float %d, float* %a release, align 4
%bc = bitcast float* %a to i8*
%ac = addrspacecast i8* %bc to i8 addrspace(1)*
ret i8 addrspace(1)* %ac
}
""",
                    "entry",
                ),
                Core.LLVMPtr{UInt8,1},
                Tuple{Float32},
                convert(Float32, x),
            ),
        ),
    )
end

# Converting to the same traced element type is the identity.
Base.convert(::Type{<:CuTracedRNumber{T}}, x::CuTracedRNumber{T}) where {T} = x

# `one`/`zero`: on a value, load it and take one/zero of the result; on the
# wrapper type, delegate directly to the element type `T`.
for f in (:one, :zero)
    @eval begin
        Base.$f(val::CuTracedRNumber) = $f(val[])
        Base.$f(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = $f(T)
    end
end

# Host-side promotion rules for traced scalars. Two traced numbers promote
# via their element types; `Any` absorbs; a traced number paired with a plain
# type promotes the element types, short-circuiting when they already match.
Base.@nospecializeinfer function Base.promote_rule(
    @nospecialize(a::Type{<:CuTracedRNumber{T}}),
    @nospecialize(b::Type{<:CuTracedRNumber{T2}})
) where {T,T2}
    return promote_rule(T, T2)
end
Base.@nospecializeinfer function Base.promote_rule(
    ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
)
    return Any
end
Base.@nospecializeinfer function Base.promote_rule(
    @nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any}
)
    return Any
end
Base.@nospecializeinfer function Base.promote_rule(
    @nospecialize(T2::Type), @nospecialize(b::Type{<:CuTracedRNumber{T}})
) where {T}
    if T == T2
        return T
    else
        return promote_rule(T, T2)
    end
end
Base.@nospecializeinfer function Base.promote_rule(
    @nospecialize(a::Type{<:CuTracedRNumber{T}}), @nospecialize(T2::Type)
) where {T}
    if T == T2
        return T
    else
        return promote_rule(T, T2)
    end
end

# Traced-domain promotion: combining two traced numbers (or a traced number
# with a plain type) yields a CuTracedRNumber of the promoted element type,
# preserving the address space `A`. `Any` absorbs, mirroring `promote_rule`.
Base.@nospecializeinfer function Reactant.promote_traced_type(
    @nospecialize(a::Type{<:CuTracedRNumber{T,A}}),
    @nospecialize(b::Type{<:CuTracedRNumber{T2,A}})
) where {T,T2,A}
    return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A}
end
Base.@nospecializeinfer function Reactant.promote_traced_type(
    ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber})
)
    return Any
end
Base.@nospecializeinfer function Reactant.promote_traced_type(
    @nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any}
)
    return Any
end
Base.@nospecializeinfer function Reactant.promote_traced_type(
    @nospecialize(T2::Type), ::Type{<:CuTracedRNumber{T,A}}
) where {T,A}
    if T == T2
        return CuTracedRNumber{T,A}
    else
        # BUG FIX: this previously called `Reactant.promote_trace_type`
        # (missing "d"), which does not match the helper used by the
        # traced/traced method above and would raise UndefVarError on the
        # mixed-type path.
        return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A}
    end
end
Base.@nospecializeinfer function Reactant.promote_traced_type(
    ::Type{<:CuTracedRNumber{T,A}}, @nospecialize(T2::Type)
) where {T,A}
    if T == T2
        return CuTracedRNumber{T,A}
    else
        # BUG FIX: same misspelling (`promote_trace_type`) as the method above.
        return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A}
    end
end

# Compact printing: dimensions joined with '×' plus the raw pointer address.
function Base.show(io::IO, a::AT) where {AT<:CuTracedArray}
    CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a)))
end
Expand All @@ -249,47 +49,6 @@ function Base.show(io::IO, a::AT) where {AT<:CuTracedRNumber}
)
end

## array interface

# Element storage size in bytes.
Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T)
# The shape is carried statically in the `Size` type parameter.
Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size
Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x)
# Raw device pointer; the two-argument form offsets to the `i`-th element.
function Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A}
    return Base.unsafe_convert(Core.LLVMPtr{T,A}, x)
end
@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A}
    return Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i)
end

## conversions

# Unwrapping a traced array to its raw device pointer just reads the stored
# LLVM pointer field.
Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} =
    x.ptr

# TODO: arrays as allocated by the CUDA APIs are 256-byte aligned. we should keep track of
# this information, because it enables optimizations like Load Store Vectorization
# (cfr. shared memory and its wider-than-datatype alignment)

# Alignment (in bytes) of the element type, computed at generation time so the
# result is a compile-time constant. A single method covers both traced
# wrappers; bits-union elements use the union layout's alignment, everything
# else the plain datatype alignment.
@generated function alignment(::Union{CuTracedArray{T},CuTracedRNumber{T}}) where {T}
    if Base.isbitsunion(T)
        _, _, al = Base.uniontype_layout(T)
        return al
    else
        return Base.datatype_alignment(T)
    end
end

## indexing intrinsics

CUDA.@device_function @inline function arrayref(
Expand Down Expand Up @@ -382,26 +141,11 @@ end

## indexing

# Linear indexing is the native style; ND indices are linearized below.
Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()

# Scalar load/store delegate to the `arrayref`/`arrayset` device intrinsics.
Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
    arrayref(A, i1)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
    arrayset(A, convert(T, x)::T, i1)

# preserve the specific integer type when indexing device arrays,
# to avoid extending 32-bit hardware indices to 64-bit.
Base.to_index(::CuTracedArray, i::Integer) = i

# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
# See also: https://github.com/JuliaLang/julia/pull/42289
Base.@propagate_inbounds Base.getindex(
    A::CuTracedArray, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)]
Base.@propagate_inbounds Base.setindex!(
    A::CuTracedArray, x, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x

## const indexing

"""
Expand All @@ -415,8 +159,8 @@ This API can only be used on devices with compute capability 3.5 or higher.
!!! warning
Experimental API. Subject to change without deprecation.
"""
struct Const{T,N,AS} <: DenseArray{T,N}
a::CuTracedArray{T,N,AS}
struct Const{T,N,AS,Size} <: DenseArray{T,N}
a::CuTracedArray{T,N,AS,Size}
end
Base.Experimental.Const(A::CuTracedArray) = Const(A)

Expand All @@ -430,14 +174,6 @@ Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A,

## other

# Linear iteration over the device array. The `% UInt` trick folds the
# `1 <= state <= length(A)` bounds check into one unsigned comparison.
@inline function Base.iterate(A::CuTracedArray, state=1)
    (state % UInt) - 1 < length(A) || return nothing
    return (@inbounds A[state], state + 1)
end

function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A}
err = GPUArrays._reinterpret_exception(T, a)
err === nothing || throw(err)
Expand Down Expand Up @@ -467,7 +203,7 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
if N == M && dims == size(a)
return a
end
return _derived_array(a, T, dims)
return _derived_array(a, T, dims) # XXX: what is _derived_array?
end

struct ReactantKernelAdaptor end
Expand Down
2 changes: 1 addition & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2087,7 +2087,7 @@ function compile_mlir!(

pad_op = MLIR.Dialects.stablehlo.pad(
results[i],
Reactant.TracedUtils.promote_to(
Reactant.promote_to(
TracedRNumber{Reactant.unwrapped_eltype(res)}, 0
).mlir_data;
edge_padding_low=MLIR.IR.DenseArrayAttribute(
Expand Down
Loading
Loading