diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 388d99d148..95a59e1c0a 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -3,6 +3,8 @@ module ReactantCUDAExt using Reactant: Reactant, TracedRArray, AnyConcretePJRTArray, MLIR, TracedRNumber using Reactant.Compiler: raising using Reactant.Ops: @opcall +import Reactant: CuTracedArray, CuTracedRNumber +using Reactant.CuTracedOverloads: alignment using Adapt: Adapt, adapt using CUDA: CUDA, CuDim, DenseCuArray, unsafe_cached_load @@ -17,228 +19,26 @@ const KA = KernelAbstractions Reactant.is_extension_loaded(::Val{:CUDA}) = true -struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} - ptr::Core.LLVMPtr{T,A} - - function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size} - gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()] - push!(gc_vec, xs) - @assert gc_vec[end] === xs - ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) - return new(ptr) - end +function CuTracedArray{T,N,A,Size}(xs::TracedRArray) where {T,N,A,Size} + gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()] + push!(gc_vec, xs) + @assert gc_vec[end] === xs + ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) + return CuTracedArray{T,N,A,Size}(ptr) end -Reactant.use_overlayed_version(::CuTracedArray) = true - -struct CuTracedRNumber{T,A} <: Number - ptr::Core.LLVMPtr{T,A} - - function CuTracedRNumber{T,A}(xs::TracedRNumber) where {T,A} - gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()] - push!(gc_vec, xs) - @assert gc_vec[end] === xs - ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) - return new(ptr) - end - function CuTracedRNumber{T,A}(ptr::Core.LLVMPtr{T,A}) where {T,A} - return new(ptr) - end +function CuTracedRNumber{T,A}(xs::TracedRNumber) where {T,A} + gc_vec = Reactant.Compiler.context_gc_vector[MLIR.IR.context()] + push!(gc_vec, xs) + @assert gc_vec[end] === xs + 
ptr = Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs)) + return CuTracedRNumber{T,A}(ptr) end -Reactant.use_overlayed_version(::CuTracedRNumber) = true - -Base.@nospecializeinfer Reactant.is_traced_number( - @nospecialize(T::Type{<:CuTracedRNumber}) -) = true -Reactant.unwrapped_eltype(::Type{<:CuTracedRNumber{T}}) where {T} = T - @inline CuTracedRNumber{T,A}(val::Number) where {T,A} = convert(CuTracedRNumber{T,A}, val) -function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A} - align = alignment(RN) - return @inbounds unsafe_load(RN.ptr, 1, Val(align)) -end - Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number} = convert(T, getindex(RN)) -for jlop in ( - :(Base.min), - :(Base.mod), - :(Base.max), - :(Base.:+), - :(Base.:-), - :(Base.:*), - :(Base.:/), - :(Base.:^), - :(Base.rem), - :(Base.isless), - :(Base.:(==)), - :(Base.:(!=)), -) - @eval begin - @inline $jlop(a::CuTracedRNumber, b::CuTracedRNumber) = $jlop(a[], b[]) - @inline $jlop(a::CuTracedRNumber{T,A}, b::Number) where {T,A} = $jlop(a[], b) - @inline $jlop(a::Number, b::CuTracedRNumber{T,A}) where {T,A} = $jlop(a, b[]) - end -end - -@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = ifelse(cond, a, b[]) -@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = ifelse(cond, a[], b) -@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) = - ifelse(cond, a[], b[]) -@inline Base.ifelse(cond::CuTracedRNumber, a, b) = ifelse(cond[], a, b) -@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = ifelse(cond[], a[], b) -@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = ifelse(cond[], a, b[]) -@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) = - ifelse(cond[], a[], b[]) - -Base.@constprop :aggressive @inline Base.:^( - a::CuTracedRNumber{T,A}, b::Integer -) where {T,A} = ^(a[], b) - -@inline Base.unsafe_trunc(::Type{T}, a::CuTracedRNumber) where {T} = - Base.unsafe_trunc(T, a[]) - -for 
jlop in (:(Base.:+), :(Base.:-), :(Base.isnan), :(Base.isfinite), :(Base.isinf)) - @eval begin - @inline $jlop(a::CuTracedRNumber) = $jlop(a[]) - end -end - -Base.OneTo(x::CuTracedRNumber{<:Integer}) = Base.OneTo(x[]) - -@static if isdefined(Base, :unchecked_oneto) - function Base.unchecked_oneto(x::CuTracedRNumber{<:Integer}) - return Base.unchecked_oneto(x[]) - end -end - -@inline function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number) - return CT( - Base.reinterpret( - Core.LLVMPtr{Float64,1}, - Base.llvmcall( - ( - """define i8 addrspace(1)* @entry(double %d) alwaysinline { - %a = alloca double - store atomic double %d, double* %a release, align 8 - %bc = bitcast double* %a to i8* - %ac = addrspacecast i8* %bc to i8 addrspace(1)* - ret i8 addrspace(1)* %ac - } - """, - "entry", - ), - Core.LLVMPtr{UInt8,1}, - Tuple{Float64}, - convert(Float64, x), - ), - ), - ) -end - -@inline function Base.convert(CT::Type{CuTracedRNumber{Float32,1}}, x::Number) - return CT( - Base.reinterpret( - Core.LLVMPtr{Float32,1}, - Base.llvmcall( - ( - """define i8 addrspace(1)* @entry(float %d) alwaysinline { - %a = alloca float - store atomic float %d, float* %a release, align 4 - %bc = bitcast float* %a to i8* - %ac = addrspacecast i8* %bc to i8 addrspace(1)* - ret i8 addrspace(1)* %ac - } - """, - "entry", - ), - Core.LLVMPtr{UInt8,1}, - Tuple{Float32}, - convert(Float32, x), - ), - ), - ) -end - -Base.convert(::Type{<:CuTracedRNumber{T}}, x::CuTracedRNumber{T}) where {T} = x - -Base.one(a::CuTracedRNumber) = one(a[]) -Base.one(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = one(T) -Base.zero(a::CuTracedRNumber) = zero(a[]) -Base.zero(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = zero(T) - -Base.@nospecializeinfer function Base.promote_rule( - @nospecialize(a::Type{<:CuTracedRNumber{T}}), - @nospecialize(b::Type{<:CuTracedRNumber{T2}}) -) where {T,T2} - return promote_rule(T, T2) -end -Base.@nospecializeinfer function Base.promote_rule( - ::Type{Any}, 
@nospecialize(b::Type{<:CuTracedRNumber}) -) - return Any -end -Base.@nospecializeinfer function Base.promote_rule( - @nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any} -) - return Any -end -Base.@nospecializeinfer function Base.promote_rule( - @nospecialize(T2::Type), @nospecialize(b::Type{<:CuTracedRNumber{T}}) -) where {T} - if T == T2 - return T - else - return promote_rule(T, T2) - end -end -Base.@nospecializeinfer function Base.promote_rule( - @nospecialize(a::Type{<:CuTracedRNumber{T}}), @nospecialize(T2::Type) -) where {T} - if T == T2 - return T - else - return promote_rule(T, T2) - end -end - -Base.@nospecializeinfer function Reactant.promote_traced_type( - @nospecialize(a::Type{<:CuTracedRNumber{T,A}}), - @nospecialize(b::Type{<:CuTracedRNumber{T2,A}}) -) where {T,T2,A} - return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A} -end -Base.@nospecializeinfer function Reactant.promote_traced_type( - ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber}) -) - return Any -end -Base.@nospecializeinfer function Reactant.promote_traced_type( - @nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any} -) - return Any -end -Base.@nospecializeinfer function Reactant.promote_traced_type( - @nospecialize(T2::Type), ::Type{<:CuTracedRNumber{T,A}} -) where {T,A} - if T == T2 - return CuTracedRNumber{T,A} - else - return CuTracedRNumber{Reactant.promote_trace_type(T, T2),A} - end -end -Base.@nospecializeinfer function Reactant.promote_traced_type( - ::Type{<:CuTracedRNumber{T,A}}, @nospecialize(T2::Type) -) where {T,A} - if T == T2 - return CuTracedRNumber{T,A} - else - return CuTracedRNumber{Reactant.promote_trace_type(T, T2),A} - end -end - function Base.show(io::IO, a::AT) where {AT<:CuTracedArray} CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a))) end @@ -249,47 +49,6 @@ function Base.show(io::IO, a::AT) where {AT<:CuTracedRNumber} ) end -## array interface - -Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T) 
-Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size -Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x) -function Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} - return Base.unsafe_convert(Core.LLVMPtr{T,A}, x) -end -@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A} - return Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i) -end - -## conversions - -function Base.unsafe_convert( - ::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A} -) where {T,A} - return x.ptr -end - -# TODO: arrays as allocated by the CUDA APIs are 256-byte aligned. we should keep track of -# this information, because it enables optimizations like Load Store Vectorization -# (cfr. shared memory and its wider-than-datatype alignment) - -@generated function alignment(::CuTracedArray{T}) where {T} - if Base.isbitsunion(T) - _, sz, al = Base.uniontype_layout(T) - al - else - Base.datatype_alignment(T) - end -end -@generated function alignment(::CuTracedRNumber{T}) where {T} - if Base.isbitsunion(T) - _, sz, al = Base.uniontype_layout(T) - al - else - Base.datatype_alignment(T) - end -end - ## indexing intrinsics CUDA.@device_function @inline function arrayref( @@ -382,26 +141,11 @@ end ## indexing -Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear() - Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = arrayref(A, i1) Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = arrayset(A, convert(T, x)::T, i1) -# preserve the specific integer type when indexing device arrays, -# to avoid extending 32-bit hardware indices to 64-bit. -Base.to_index(::CuTracedArray, i::Integer) = i - -# Base doesn't like Integer indices, so we need our own ND get and setindex! routines. -# See also: https://github.com/JuliaLang/julia/pull/42289 -Base.@propagate_inbounds Base.getindex( - A::CuTracedArray, I::Union{Integer,CartesianIndex}... 
-) = A[Base._to_linear_index(A, to_indices(A, I)...)] -Base.@propagate_inbounds Base.setindex!( - A::CuTracedArray, x, I::Union{Integer,CartesianIndex}... -) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x - ## const indexing """ @@ -415,8 +159,8 @@ This API can only be used on devices with compute capability 3.5 or higher. !!! warning Experimental API. Subject to change without deprecation. """ -struct Const{T,N,AS} <: DenseArray{T,N} - a::CuTracedArray{T,N,AS} +struct Const{T,N,AS,Size} <: DenseArray{T,N} + a::CuTracedArray{T,N,AS,Size} end Base.Experimental.Const(A::CuTracedArray) = Const(A) @@ -430,14 +174,6 @@ Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, ## other -@inline function Base.iterate(A::CuTracedArray, i=1) - if (i % UInt) - 1 < length(A) - (@inbounds A[i], i + 1) - else - nothing - end -end - function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A} err = GPUArrays._reinterpret_exception(T, a) err === nothing || throw(err) @@ -467,7 +203,7 @@ function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M if N == M && dims == size(a) return a end - return _derived_array(a, T, dims) + return _derived_array(a, T, dims) # XXX: what is _derived_array? 
end struct ReactantKernelAdaptor end diff --git a/src/Compiler.jl b/src/Compiler.jl index 942b2bdfb8..bd3d48fbe4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -2087,7 +2087,7 @@ function compile_mlir!( pad_op = MLIR.Dialects.stablehlo.pad( results[i], - Reactant.TracedUtils.promote_to( + Reactant.promote_to( TracedRNumber{Reactant.unwrapped_eltype(res)}, 0 ).mlir_data; edge_padding_low=MLIR.IR.DenseArrayAttribute( diff --git a/src/CuTracedOverloads.jl b/src/CuTracedOverloads.jl new file mode 100644 index 0000000000..32c3af5e11 --- /dev/null +++ b/src/CuTracedOverloads.jl @@ -0,0 +1,220 @@ +module CuTracedOverloads + +using Reactant: CuTracedArray, CuTracedRNumber + +# TODO: arrays as allocated by the CUDA APIs are 256-byte aligned. we should keep track of +# this information, because it enables optimizations like Load Store Vectorization +# (cfr. shared memory and its wider-than-datatype alignment) + +@generated function alignment(::CuTracedArray{T}) where {T} + if Base.isbitsunion(T) + _, sz, al = Base.uniontype_layout(T) + al + else + Base.datatype_alignment(T) + end +end + +@generated function alignment(::CuTracedRNumber{T}) where {T} + if Base.isbitsunion(T) + _, sz, al = Base.uniontype_layout(T) + al + else + Base.datatype_alignment(T) + end +end + +Base.@nospecializeinfer function Base.promote_rule( + @nospecialize(a::Type{<:CuTracedRNumber{T}}), + @nospecialize(b::Type{<:CuTracedRNumber{T2}}) +) where {T,T2} + return promote_rule(T, T2) +end +Base.@nospecializeinfer function Base.promote_rule( + ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber}) +) + return Any +end +Base.@nospecializeinfer function Base.promote_rule( + @nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any} +) + return Any +end +Base.@nospecializeinfer function Base.promote_rule( + @nospecialize(T2::Type), @nospecialize(b::Type{<:CuTracedRNumber{T}}) +) where {T} + if T == T2 + return T + else + return promote_rule(T, T2) + end +end +Base.@nospecializeinfer function 
Base.promote_rule( + @nospecialize(a::Type{<:CuTracedRNumber{T}}), @nospecialize(T2::Type) +) where {T} + if T == T2 + return T + else + return promote_rule(T, T2) + end +end + +@inline function Base.convert(CT::Type{CuTracedRNumber{Float64,1}}, x::Number) + return CT( + Base.reinterpret( + Core.LLVMPtr{Float64,1}, + Base.llvmcall( + ( + """define i8 addrspace(1)* @entry(double %d) alwaysinline { + %a = alloca double + store atomic double %d, double* %a release, align 8 + %bc = bitcast double* %a to i8* + %ac = addrspacecast i8* %bc to i8 addrspace(1)* + ret i8 addrspace(1)* %ac + } + """, + "entry", + ), + Core.LLVMPtr{UInt8,1}, + Tuple{Float64}, + convert(Float64, x), + ), + ), + ) +end + +@inline function Base.convert(CT::Type{CuTracedRNumber{Float32,1}}, x::Number) + return CT( + Base.reinterpret( + Core.LLVMPtr{Float32,1}, + Base.llvmcall( + ( + """define i8 addrspace(1)* @entry(float %d) alwaysinline { + %a = alloca float + store atomic float %d, float* %a release, align 4 + %bc = bitcast float* %a to i8* + %ac = addrspacecast i8* %bc to i8 addrspace(1)* + ret i8 addrspace(1)* %ac + } + """, + "entry", + ), + Core.LLVMPtr{UInt8,1}, + Tuple{Float32}, + convert(Float32, x), + ), + ), + ) +end + +Base.convert(::Type{<:CuTracedRNumber{T}}, x::CuTracedRNumber{T}) where {T} = x + +## array interface + +Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T) +Base.size(::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size +Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x) +function Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} + return Base.unsafe_convert(Core.LLVMPtr{T,A}, x) +end +@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A} + return Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i) +end + +## conversions + +function Base.unsafe_convert( + ::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A} +) where {T,A} + return x.ptr +end + +## indexing + 
+Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear() + +function Base.getindex(RN::CuTracedRNumber{T,A}) where {T,A} + align = alignment(RN) + return @inbounds unsafe_load(RN.ptr, 1, Val(align)) +end + +# preserve the specific integer type when indexing device arrays, +# to avoid extending 32-bit hardware indices to 64-bit. +Base.to_index(::CuTracedArray, i::Integer) = i + +# Base doesn't like Integer indices, so we need our own ND get and setindex! routines. +# See also: https://github.com/JuliaLang/julia/pull/42289 +Base.@propagate_inbounds Base.getindex( + A::CuTracedArray, I::Union{Integer,CartesianIndex}... +) = A[Base._to_linear_index(A, to_indices(A, I)...)] +Base.@propagate_inbounds Base.setindex!( + A::CuTracedArray, x, I::Union{Integer,CartesianIndex}... +) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x + +@inline function Base.iterate(A::CuTracedArray, i=1) + if (i % UInt) - 1 < length(A) + (@inbounds A[i], i + 1) + else + nothing + end +end + +## ops +for jlop in ( + :(Base.min), + :(Base.mod), + :(Base.max), + :(Base.:+), + :(Base.:-), + :(Base.:*), + :(Base.:/), + :(Base.:^), + :(Base.rem), + :(Base.isless), + :(Base.:(==)), + :(Base.:(!=)), +) + @eval begin + @inline $jlop(a::CuTracedRNumber, b::CuTracedRNumber) = $jlop(a[], b[]) + @inline $jlop(a::CuTracedRNumber{T,A}, b::Number) where {T,A} = $jlop(a[], b) + @inline $jlop(a::Number, b::CuTracedRNumber{T,A}) where {T,A} = $jlop(a, b[]) + end +end + +@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = ifelse(cond, a, b[]) +@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = ifelse(cond, a[], b) +@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) = + ifelse(cond, a[], b[]) +@inline Base.ifelse(cond::CuTracedRNumber, a, b) = ifelse(cond[], a, b) +@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = ifelse(cond[], a[], b) +@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = ifelse(cond[], a, b[]) +@inline 
Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) = + ifelse(cond[], a[], b[]) + +Base.@constprop :aggressive @inline Base.:^( + a::CuTracedRNumber{T,A}, b::Integer +) where {T,A} = ^(a[], b) + +@inline Base.unsafe_trunc(::Type{T}, a::CuTracedRNumber) where {T} = + Base.unsafe_trunc(T, a[]) + +for jlop in (:(Base.:+), :(Base.:-), :(Base.isnan), :(Base.isfinite), :(Base.isinf)) + @eval begin + @inline $jlop(a::CuTracedRNumber) = $jlop(a[]) + end +end + +Base.OneTo(x::CuTracedRNumber{<:Integer}) = Base.OneTo(x[]) + +@static if isdefined(Base, :unchecked_oneto) + function Base.unchecked_oneto(x::CuTracedRNumber{<:Integer}) + return Base.unchecked_oneto(x[]) + end +end + +Base.one(a::CuTracedRNumber) = one(a[]) +Base.one(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = one(T) +Base.zero(a::CuTracedRNumber) = zero(a[]) +Base.zero(::Type{<:CuTracedRNumber{T,A}}) where {T,A} = zero(T) + +end diff --git a/src/Reactant.jl b/src/Reactant.jl index 9e38f92347..cd2431027f 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -106,6 +106,7 @@ isa_traced_soa(::AbstractRange{<:TracedRNumber}) = true unwrapped_eltype(::Type{T}) where {T<:Number} = T unwrapped_eltype(::Type{<:RNumber{T}}) where {T} = T unwrapped_eltype(::Type{TracedRNumber{T}}) where {T} = T +unwrapped_eltype(::Type{<:CuTracedRNumber{T}}) where {T} = T unwrapped_eltype(::T) where {T<:Number} = T unwrapped_eltype(::RNumber{T}) where {T} = T @@ -132,6 +133,40 @@ function TracedRArray(data::MLIR.IR.Value) end promote_traced_type(a::Type, b::Type) = Base.promote_type(a, b) +Base.@nospecializeinfer function promote_traced_type( + @nospecialize(a::Type{<:CuTracedRNumber{T,A}}), + @nospecialize(b::Type{<:CuTracedRNumber{T2,A}}) +) where {T,T2,A} + return CuTracedRNumber{promote_traced_type(T, T2),A} +end +Base.@nospecializeinfer function promote_traced_type( + ::Type{Any}, @nospecialize(b::Type{<:CuTracedRNumber}) +) + return Any +end +Base.@nospecializeinfer function promote_traced_type( + 
@nospecialize(a::Type{<:CuTracedRNumber}), ::Type{Any}
+)
+    return Any
+end
+Base.@nospecializeinfer function promote_traced_type(
+    @nospecialize(T2::Type), ::Type{<:CuTracedRNumber{T,A}}
+) where {T,A}
+    if T == T2
+        return CuTracedRNumber{T,A}
+    else
+        return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A}
+    end
+end
+Base.@nospecializeinfer function promote_traced_type(
+    ::Type{<:CuTracedRNumber{T,A}}, @nospecialize(T2::Type)
+) where {T,A}
+    if T == T2
+        return CuTracedRNumber{T,A}
+    else
+        return CuTracedRNumber{Reactant.promote_traced_type(T, T2),A}
+    end
+end
 
 aos_to_soa(x::AbstractArray) = x
 
@@ -189,6 +224,7 @@ include("TracedUtils.jl")
 include("TracedRNumber.jl")
 include("TracedRArray.jl")
 include("TracedRange.jl")
+include("CuTracedOverloads.jl")
 include("Indexing.jl")
 
 include("ConcreteRArray.jl")
@@ -208,7 +244,9 @@ use_overlayed_version(::MissingTracedValue) = true
 use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed)
 use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true
 use_overlayed_version(::TracedRArray) = true
+use_overlayed_version(::CuTracedArray) = true
 use_overlayed_version(::TracedRNumber) = true
+use_overlayed_version(::CuTracedRNumber) = true
 use_overlayed_version(::TracedStepRangeLen) = true
 use_overlayed_version(::TracedUnitRange) = true
 function use_overlayed_version(x::AbstractArray)
diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl
index 916d0d13ee..082b49c10b 100644
--- a/src/TracedRArray.jl
+++ b/src/TracedRArray.jl
@@ -4,7 +4,7 @@ using Base: Broadcast
 using Base.Broadcast: Broadcasted, AbstractArrayStyle, instantiate
 
 using ..Reactant: Reactant, TracedRArray, TracedRNumber, AnyTracedRArray, AnyTracedRVector
-using ..Reactant: MLIR, unwrapped_eltype
+using ..Reactant: MLIR, unwrapped_eltype, CuTracedRNumber, CuTracedArray
 using ..Ops: @opcall
 using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
 
@@ -1181,14 +1181,13 @@ end
 for (aType, xType) in (
(AbstractRange{<:Real}, TracedRNumber{<:Real}), + (AbstractRange{<:Real}, CuTracedRNumber{<:Real}), (AbstractRange{<:TracedRNumber}, Real), + (AbstractRange{<:CuTracedRNumber}, Real), (AbstractRange{<:TracedRNumber}, TracedRNumber{<:Real}), + (AbstractRange{<:CuTracedRNumber}, CuTracedRNumber{<:Real}), ) - @eval function Base.searchsortedfirst( - a::$(aType), x::$(xType), o::Base.DirectOrdering - )::TracedRNumber{keytype(a)} - x = TracedUtils.promote_to(TracedRNumber{Reactant.unwrapped_eltype(a)}, x) - + @eval function Base.searchsortedfirst(a::$(aType), x::$(xType), o::Base.DirectOrdering) f, h, l = first(a), step(a), last(a) n = round(Int, (x - f) / h + 1) @@ -1198,27 +1197,27 @@ for (aType, xType) in ( ifelse( (h == 0) | Base.Order.lt(o, l, x), length(a) + 1, - ifelse(Base.Order.lt(o, @allowscalar(a[n]), x), n + 1, n), + ifelse(Base.Order.lt(o, a[n], x), n + 1, n), ), ) end end function overloaded_searchsortedfirst(v, x, lo::T, hi::T, o::Base.Ordering) where {T} - v = TracedUtils.broadcast_to_size(v, size(v)) - x = TracedUtils.promote_to(TracedRNumber{Reactant.unwrapped_eltype(v)}, x) + v = Reactant.broadcast_to_size(v, size(v)) + x = Reactant.promote_to(TracedRNumber{Reactant.unwrapped_eltype(v)}, x) return sum(T.(__lt(o, v[lo:hi], x)); init=lo) end function overloaded_searchsortedlast(v, x, lo::T, hi::T, o::Base.Ordering) where {T} - v = TracedUtils.broadcast_to_size(v, size(v)) - x = TracedUtils.promote_to(TracedRNumber{Reactant.unwrapped_eltype(v)}, x) + v = Reactant.broadcast_to_size(v, size(v)) + x = Reactant.promote_to(TracedRNumber{Reactant.unwrapped_eltype(v)}, x) return sum(T.(.!(__lt(o, x, v[lo:hi]))); init=lo - 1) end function overloaded_searchsorted(v, x, lo::T, hi::T, o::Base.Ordering) where {T} - v = TracedUtils.broadcast_to_size(v, size(v)) - x = TracedUtils.promote_to(TracedRNumber{Reactant.unwrapped_eltype(v)}, x) + v = Reactant.broadcast_to_size(v, size(v)) + x = Reactant.promote_to(TracedRNumber{Reactant.unwrapped_eltype(v)}, x) firstidx = 
overloaded_searchsortedfirst(v, x, lo, hi, o) lastidx = overloaded_searchsortedlast(v, x, lo, hi, o) return Reactant.TracedUnitRange(firstidx, lastidx) @@ -1227,21 +1226,24 @@ end for op in (:searchsortedfirst, :searchsortedlast, :searchsorted) rop = Symbol(:overloaded_, op) - @eval begin - function Base.$(op)( - x::AnyTracedRVector, v, lo::T, hi::T, o::Base.Ordering - ) where {T<:Integer} - return $(rop)(x, v, lo, hi, o) - end - function Base.$(op)( - x::AbstractVector, v::TracedRNumber, lo::T, hi::T, o::Base.Ordering - ) where {T<:Integer} - return $(rop)(x, v, lo, hi, o) - end - function Base.$(op)( - x::AnyTracedRVector, v::TracedRNumber, lo::T, hi::T, o::Base.Ordering - ) where {T<:Integer} - return $(rop)(x, v, lo, hi, o) + @eval function Base.$(op)( + x::AnyTracedRVector, v, lo::T, hi::T, o::Base.Ordering + ) where {T<:Integer} + return $(rop)(x, v, lo, hi, o) + end + + for numType in (:TracedRNumber, :CuTracedRNumber) + @eval begin + function Base.$(op)( + x::AbstractVector, v::$(numType), lo::T, hi::T, o::Base.Ordering + ) where {T<:Integer} + return $(rop)(x, v, lo, hi, o) + end + function Base.$(op)( + x::AnyTracedRVector, v::$(numType), lo::T, hi::T, o::Base.Ordering + ) where {T<:Integer} + return $(rop)(x, v, lo, hi, o) + end end end end diff --git a/src/Tracing.jl b/src/Tracing.jl index 015a2ccff7..2c0953210e 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -14,6 +14,7 @@ end is_traced_number(x::Type) = false Base.@nospecializeinfer is_traced_number(@nospecialize(T::Type{<:TracedRNumber})) = true +Base.@nospecializeinfer is_traced_number(@nospecialize(T::Type{<:CuTracedRNumber})) = true function traced_type_inner end diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..f66429fce6 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -58,6 +58,13 @@ end @leaf TracedRNumber +## CuTracedRNumber +struct CuTracedRNumber{T,A} <: RNumber{T} + ptr::Core.LLVMPtr{T,A} +end + +@leaf CuTracedRNumber + ## TracedRArray mutable struct TracedRArray{T,N} <: 
RArray{TracedRNumber{T},N} paths::Tuple @@ -82,6 +89,16 @@ end @leaf TracedRArray Adapt.parent_type(::Type{TracedRArray{T,N}}) where {T,N} = TracedRArray{T,N} +## CuTracedArray +struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} + ptr::Core.LLVMPtr{T,A} +end + +@leaf CuTracedArray +function Adapt.parent_type(::Type{CuTracedArray{T,N,A,Size}}) where {T,N,A,Size} + return CuTracedArray{T,N,A,Size} +end + ## TracedStepRangeLen struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T} ref::R diff --git a/test/integration/cuda.jl b/test/integration/cuda.jl index 71c8254fa4..aeae8a2e4a 100644 --- a/test/integration/cuda.jl +++ b/test/integration/cuda.jl @@ -2,10 +2,8 @@ using Reactant using Test using CUDA -const ReactantCUDAExt = Base.get_extension(Reactant, :ReactantCUDAExt) - @testset "Promote CuTraced" begin - TFT = ReactantCUDAExt.CuTracedRNumber{Float64,1} + TFT = Reactant.CuTracedRNumber{Float64,1} FT = Float64 @test Reactant.promote_traced_type(TFT, FT) == TFT @test Base.promote_type(TFT, FT) == FT @@ -197,10 +195,16 @@ end oA = collect(Float64, 1:1:64) A = Reactant.to_rarray(oA) B = ConcreteRNumber(3.1) - @test begin + + @testset "raise = default" begin @jit searchsorted!(A, B) - all(Array(A) .≈ 311) - end broken = contains(string(Reactant.devices()[1]), "TPU") + @test all(Array(A) .≈ 311) + end + + @testset "raise = true" begin + @jit raise = true searchsorted!(A, B) + @test all(Array(A) .≈ 311) + end end function convert_mul_kernel!(Gu, w::FT) where {FT}