Skip to content

Commit ce1060e

Browse files
committed
feat: use cuda cc functions directly [skip ci]
1 parent bd15236 commit ce1060e

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
14601460
return newa
14611461
end
14621462

1463-
function __init__()
1464-
if CUDA.functional() && !Reactant.precompiling()
1465-
cap = CUDA.capability(CUDA.device())
1466-
Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
1467-
end
1468-
return nothing
1469-
end
1470-
14711463
# In Julia v1.11.3 precompiling this module caches bad code:
14721464
# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
14731465
@static if !Sys.isapple()

src/Compiler.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,8 @@ const cubinChip = Ref{String}("sm_60")
14291429
const cubinFormat = Ref{String}("bin")
14301430
const cuindexBitWidth = Ref{Int}(32)
14311431
const cuOptLevel = Ref{Int}(2)
1432+
const cuWarpSize = Ref{Int}(32)
1433+
14321434
# Wgatever the relevant highest version from our LLVM is within NVPTX.td
14331435
# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
14341436
# We see relevant ptx version is CUDA 12.6 -> 85

src/xla/XLA.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,15 @@ for runtime in (:PJRT, :IFRT)
234234
)
235235
state.clients["cuda"] = gpu
236236
state.default_client = gpu
237+
238+
# set values for cuda. This is being done here since we need cuda
239+
# to be initialized before we can use it. initializing the devices
240+
# implicitly initializes cuda.
241+
cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
242+
cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
243+
Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
244+
245+
Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
237246
catch e
238247
println(stdout, e)
239248
end

0 commit comments

Comments
 (0)