Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ PythonCall = "0.9.25"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.16"
Reactant_jll = "0.0.246"
Reactant_jll = "0.0.247"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/tutorials/partial-evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ The StableHLO IR code generated here is:
# output

module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}, %arg1: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
return %0 : tensor<i64>
}
Expand Down Expand Up @@ -101,7 +101,7 @@ variable input `%arg0`:
# output

module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<i64>) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
%c = stablehlo.constant dense<4> : tensor<i64>
%0 = stablehlo.add %arg0, %c : tensor<i64>
return %0 : tensor<i64>
Expand Down
8 changes: 0 additions & 8 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
return newa
end

function __init__()
if CUDA.functional() && !Reactant.precompiling()
cap = CUDA.capability(CUDA.device())
Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
end
return nothing
end

# In Julia v1.11.3 precompiling this module caches bad code:
# <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
@static if !Sys.isapple()
Expand Down
12 changes: 11 additions & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,8 @@ function optimization_passes(
"trivial_reduce_window_to_reduce_op",
"dot_general_add_distributive_simplify",
"dot_general_subtract_distributive_simplify",
"dus_to_dynamic_pad",
"dynamic_pad_to_pad",
]

if !compile_options.disable_auto_batching_passes
Expand All @@ -922,6 +924,7 @@ function optimization_passes(
"concat_insert_dim_reduce",
"concat_insert_dim_sort",
"concat_insert_dim_reduce_window",
"concat_insert_dim_elementwise",
"dot_general_slice_to_batch",
"gather_slice_to_batch",
"iota_slice_to_batch",
Expand Down Expand Up @@ -1071,6 +1074,7 @@ function optimization_passes(
"const_prop_through_barrier<16>",
"concat_const_prop<1>($max_constant_threshold)",
"dynamic_update_slice_const_prop($max_constant_threshold)",
"clamp_const_prop",
],
)

Expand Down Expand Up @@ -1105,7 +1109,8 @@ function optimization_passes(
"reshape_dus",
"dot_reshape_pad<1>",
"pad_dot_general<1>(0)",
"pad_dot_general<1>(1)",
# XXX: see https://github.com/EnzymeAD/Enzyme-JAX/issues/1445
# "pad_dot_general<1>(1)",
"reshape_pad",
"reshape_wrap",
"reshape_rotate",
Expand Down Expand Up @@ -1425,6 +1430,8 @@ const cubinChip = Ref{String}("sm_60")
const cubinFormat = Ref{String}("bin")
const cuindexBitWidth = Ref{Int}(32)
const cuOptLevel = Ref{Int}(2)
const cuWarpSize = Ref{Int}(32)

# Wgatever the relevant highest version from our LLVM is within NVPTX.td
# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
# We see relevant ptx version is CUDA 12.6 -> 85
Expand Down Expand Up @@ -3516,6 +3523,9 @@ function compile_xla(
module_string = ""
end

# Drop some of our attributes
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")

if before_xla_optimizations
exec = nothing
hlo_modules = XLA.HloModule(mod)
Expand Down
9 changes: 9 additions & 0 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ for runtime in (:PJRT, :IFRT)
)
state.clients["cuda"] = gpu
state.default_client = gpu

# set values for cuda. This is being done here since we need cuda
# to be initialized before we can use it. initializing the devices
# implicitly initializes cuda.
cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"

Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
catch e
println(stdout, e)
end
Expand Down
2 changes: 1 addition & 1 deletion test/buffer_donation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
hlo = @code_hlo(multiple_donated_args(a, b, c))
@test contains(
repr(hlo),
"@main(%arg0: tensor<2x2xf64> {tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3xf64> {tf.aliasing_output = 2 : i32}, %arg2: tensor<2x2xf64> {tf.aliasing_output = 1 : i32})",
"@main(%arg0: tensor<2x2xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 2 : i32}, %arg2: tensor<2x2xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 1 : i32}) -> (tensor<2x2xf64>, tensor<2x2xf64>, tensor<4x3xf64>) attributes {enzymexla.memory_effects = []} {",
)
end

Expand Down
Loading