File tree Expand file tree Collapse file tree 3 files changed +11
-8
lines changed Expand file tree Collapse file tree 3 files changed +11
-8
lines changed Original file line number Diff line number Diff line change @@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
1460
1460
return newa
1461
1461
end
1462
1462
1463
- function __init__ ()
1464
- if CUDA. functional () && ! Reactant. precompiling ()
1465
- cap = CUDA. capability (CUDA. device ())
1466
- Reactant. Compiler. cubinChip[] = " sm_$(cap. major)$(cap. minor) "
1467
- end
1468
- return nothing
1469
- end
1470
-
1471
1463
# In Julia v1.11.3 precompiling this module caches bad code:
1472
1464
# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
1473
1465
@static if ! Sys. isapple ()
Original file line number Diff line number Diff line change @@ -1429,6 +1429,8 @@ const cubinChip = Ref{String}("sm_60")
1429
1429
const cubinFormat = Ref {String} (" bin" )
1430
1430
const cuindexBitWidth = Ref {Int} (32 )
1431
1431
const cuOptLevel = Ref {Int} (2 )
1432
+ const cuWarpSize = Ref {Int} (32 )
1433
+
1432
1434
# Wgatever the relevant highest version from our LLVM is within NVPTX.td
1433
1435
# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
1434
1436
# We see relevant ptx version is CUDA 12.6 -> 85
Original file line number Diff line number Diff line change @@ -234,6 +234,15 @@ for runtime in (:PJRT, :IFRT)
234
234
)
235
235
state. clients[" cuda" ] = gpu
236
236
state. default_client = gpu
237
+
238
+ # set values for cuda. This is being done here since we need cuda
239
+ # to be initialized before we can use it. initializing the devices
240
+ # implicitly initializes cuda.
241
+ cc_major = @ccall MLIR. API. mlir_c. ReactantCudaDeviceGetComputeCapalilityMajor ():: Int32
242
+ cc_minor = @ccall MLIR. API. mlir_c. ReactantCudaDeviceGetComputeCapalilityMinor ():: Int32
243
+ Reactant. Compiler. cubinChip[] = " sm_$(cc_major)$(cc_minor) "
244
+
245
+ Reactant. Compiler. cuWarpSize[] = @ccall MLIR. API. mlir_c. ReactantCudaDeviceGetWarpSizeInThreads ():: Int32
237
246
catch e
238
247
println (stdout , e)
239
248
end
You can’t perform that action at this time.
0 commit comments