From b95c2db1132cf3c508ef339ce1c1d4f06e3ae4d9 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 30 Sep 2025 22:46:28 -0400
Subject: [PATCH 1/9] Add new optimization patterns to Compiler.jl

---
 src/Compiler.jl | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 1226b7757d..ad1c73cee6 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -906,6 +906,8 @@ function optimization_passes(
         "trivial_reduce_window_to_reduce_op",
         "dot_general_add_distributive_simplify",
         "dot_general_subtract_distributive_simplify",
+        "dus_to_dynamic_pad",
+        "dynamic_pad_to_pad",
     ]
 
     if !compile_options.disable_auto_batching_passes
@@ -968,6 +970,7 @@ function optimization_passes(
                 "gather_elementwise",
                 ## const prop patterns
                 "gather_const_prop",
+                "clamp_const_prop",
             ],
         )
     end

From af2e65aef3df108e7983ff95c4c31199a03bb246 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 30 Sep 2025 22:48:31 -0400
Subject: [PATCH 2/9] fix

---
 src/Compiler.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index ad1c73cee6..64cf3879bc 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -924,6 +924,7 @@ function optimization_passes(
                 "concat_insert_dim_reduce",
                 "concat_insert_dim_sort",
                 "concat_insert_dim_reduce_window",
+                "concat_insert_dim_elementwise",
                 "dot_general_slice_to_batch",
                 "gather_slice_to_batch",
                 "iota_slice_to_batch",
@@ -970,7 +971,6 @@ function optimization_passes(
                 "gather_elementwise",
                 ## const prop patterns
                 "gather_const_prop",
-                "clamp_const_prop",
             ],
         )
     end
@@ -1074,6 +1074,7 @@
             "const_prop_through_barrier<16>",
             "concat_const_prop<1>($max_constant_threshold)",
             "dynamic_update_slice_const_prop($max_constant_threshold)",
+            "clamp_const_prop",
         ],
     )
 

From 29e312e244b8ad938fed019a36698713e8ca5297 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 1 Oct 2025 10:00:48 -0400
Subject: [PATCH 3/9] Update src/Compiler.jl

---
 src/Compiler.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 64cf3879bc..8d8b9f84a9 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -924,7 +924,8 @@ function optimization_passes(
                 "concat_insert_dim_reduce",
                 "concat_insert_dim_sort",
                 "concat_insert_dim_reduce_window",
-                "concat_insert_dim_elementwise",
+                # XXX: busted `map of slices` tests, needs upstream fix
+                # "concat_insert_dim_elementwise",
                 "dot_general_slice_to_batch",
                 "gather_slice_to_batch",
                 "iota_slice_to_batch",

From bd152362d3bef711d393466c12778af67442111e Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 3 Oct 2025 20:13:25 -0500
Subject: [PATCH 4/9] fix: use new commit

---
 src/Compiler.jl | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 8d8b9f84a9..64cf3879bc 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -924,8 +924,7 @@ function optimization_passes(
                 "concat_insert_dim_reduce",
                 "concat_insert_dim_sort",
                 "concat_insert_dim_reduce_window",
-                # XXX: busted `map of slices` tests, needs upstream fix
-                # "concat_insert_dim_elementwise",
+                "concat_insert_dim_elementwise",
                 "dot_general_slice_to_batch",
                 "gather_slice_to_batch",
                 "iota_slice_to_batch",

From ce1060e38c27c5b47442bfe9109cf65b683080d8 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 3 Oct 2025 20:25:45 -0500
Subject: [PATCH 5/9] feat: use cuda cc functions directly [skip ci]

---
 ext/ReactantCUDAExt.jl | 8 --------
 src/Compiler.jl        | 2 ++
 src/xla/XLA.jl         | 9 +++++++++
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index 388d99d148..bd9b55edde 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
     return newa
 end
 
-function __init__()
-    if CUDA.functional() && !Reactant.precompiling()
-        cap = CUDA.capability(CUDA.device())
-        Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
-    end
-    return nothing
-end
-
 # In Julia v1.11.3 precompiling this module caches bad code:
 # .
 @static if !Sys.isapple()
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 64cf3879bc..2a74133e25 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1429,6 +1429,8 @@ const cubinChip = Ref{String}("sm_60")
 const cubinFormat = Ref{String}("bin")
 const cuindexBitWidth = Ref{Int}(32)
 const cuOptLevel = Ref{Int}(2)
+const cuWarpSize = Ref{Int}(32)
+
 # Wgatever the relevant highest version from our LLVM is within NVPTX.td
 # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
 # We see relevant ptx version is CUDA 12.6 -> 85
diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl
index f14139b890..1a7ffc17f2 100644
--- a/src/xla/XLA.jl
+++ b/src/xla/XLA.jl
@@ -234,6 +234,15 @@ for runtime in (:PJRT, :IFRT)
             )
             state.clients["cuda"] = gpu
             state.default_client = gpu
+
+            # set values for cuda. This is being done here since we need cuda
+            # to be initialized before we can use it. initializing the devices
+            # implicitly initializes cuda.
+            cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
+            cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
+            Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
+
+            Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
         catch e
             println(stdout, e)
         end

From 9cb5c9ec63928dfad2e4a7c8d26584ff3b665cd4 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 5 Oct 2025 19:53:04 -0500
Subject: [PATCH 6/9] fix: drop attrs + disable (post) pad_dot_general for now [skip ci]

---
 src/Compiler.jl | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/Compiler.jl b/src/Compiler.jl
index 2a74133e25..d0356b3fb0 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1109,7 +1109,8 @@ function optimization_passes(
             "reshape_dus",
             "dot_reshape_pad<1>",
             "pad_dot_general<1>(0)",
-            "pad_dot_general<1>(1)",
+            # XXX: see https://github.com/EnzymeAD/Enzyme-JAX/issues/1445
+            # "pad_dot_general<1>(1)",
             "reshape_pad",
             "reshape_wrap",
             "reshape_rotate",
@@ -3522,6 +3523,9 @@ function compile_xla(
         module_string = ""
     end
 
+    # Drop some of our attributes
+    run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")
+
     if before_xla_optimizations
         exec = nothing
         hlo_modules = XLA.HloModule(mod)

From c47c2e1f17cb13bf69b34890cc53021a1b7c2b7b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 5 Oct 2025 20:53:05 -0500
Subject: [PATCH 7/9] chore: bump JLL

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 30831a2c12..6e15763cbe 100644
--- a/Project.toml
+++ b/Project.toml
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.246"
+Reactant_jll = "0.0.247"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"

From b116f17eefca021fc9150797326f81a794d90442 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 5 Oct 2025 21:10:57 -0500
Subject: [PATCH 8/9] docs: update

---
 docs/src/tutorials/partial-evaluation.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/src/tutorials/partial-evaluation.md b/docs/src/tutorials/partial-evaluation.md
index 16a535b1c9..7538aedfa9 100644
--- a/docs/src/tutorials/partial-evaluation.md
+++ b/docs/src/tutorials/partial-evaluation.md
@@ -57,7 +57,7 @@ The StableHLO IR code generated here is:
 # output
 
 module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
-  func.func @main(%arg0: tensor, %arg1: tensor) -> tensor attributes {enzymexla.memory_effects = []} {
+  func.func @main(%arg0: tensor {enzymexla.memory_effects = []}, %arg1: tensor {enzymexla.memory_effects = []}) -> tensor attributes {enzymexla.memory_effects = []} {
     %0 = stablehlo.add %arg0, %arg1 : tensor
     return %0 : tensor
   }
@@ -101,7 +101,7 @@ variable input `%arg0`:
 # output
 
 module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
-  func.func @main(%arg0: tensor) -> tensor attributes {enzymexla.memory_effects = []} {
+  func.func @main(%arg0: tensor {enzymexla.memory_effects = []}) -> tensor attributes {enzymexla.memory_effects = []} {
     %c = stablehlo.constant dense<4> : tensor
     %0 = stablehlo.add %arg0, %c : tensor
     return %0 : tensor

From a5a2b1d4ab8f3ebed2a82dc1c9161c1dae72416f Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sun, 5 Oct 2025 22:03:26 -0500
Subject: [PATCH 9/9] test: update buffer donation

---
 test/buffer_donation.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/buffer_donation.jl b/test/buffer_donation.jl
index 9afde797f3..a842ac26c8 100644
--- a/test/buffer_donation.jl
+++ b/test/buffer_donation.jl
@@ -50,7 +50,7 @@
     end
     hlo = @code_hlo(multiple_donated_args(a, b, c))
     @test contains(
         repr(hlo),
-        "@main(%arg0: tensor<2x2xf64> {tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3xf64> {tf.aliasing_output = 2 : i32}, %arg2: tensor<2x2xf64> {tf.aliasing_output = 1 : i32})",
+        "@main(%arg0: tensor<2x2xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 2 : i32}, %arg2: tensor<2x2xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 1 : i32}) -> (tensor<2x2xf64>, tensor<2x2xf64>, tensor<4x3xf64>) attributes {enzymexla.memory_effects = []} {",
     )
 end
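
Reviewer note: a minimal sketch of how the `enzymexla.memory_effects` argument attributes introduced by this series surface in generated modules. The toy function `f` and the `to_rarray` inputs below are illustrative only; the real coverage lives in `test/buffer_donation.jl` via `multiple_donated_args`, which is not shown here.

```julia
using Reactant, Test

# Hypothetical toy function standing in for the functions traced in the tests.
f(x, y) = x .+ y

a = Reactant.to_rarray(ones(2, 2))
b = Reactant.to_rarray(ones(2, 2))

# @code_hlo prints the StableHLO module Reactant generates for this call.
hlo = @code_hlo(f(a, b))

# With this series, function arguments carry `enzymexla.memory_effects`
# attributes in the printed module (they are stripped again before XLA by the
# drop-unsupported-attributes pipeline added in patch 6), so textual checks
# against repr(hlo) must account for them, as the updated docs and tests show.
@test contains(repr(hlo), "enzymexla.memory_effects")
```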