
Commit 7cdd96b

feat: updates for new JLL (#1721)
* Add new optimization patterns to Compiler.jl
* fix
* Update src/Compiler.jl
* fix: use new commit
* feat: use cuda cc functions directly [skip ci]
* fix: drop attrs + disable (post) pad_dot_general for now [skip ci]
* chore: bump JLL
* docs: update
* test: update buffer donation
1 parent 2952cb0 commit 7cdd96b

File tree (6 files changed: +24 −13 lines changed)

- Project.toml
- docs/src/tutorials/partial-evaluation.md
- ext/ReactantCUDAExt.jl
- src/Compiler.jl
- src/xla/XLA.jl
- test/buffer_donation.jl

Project.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.16"
-Reactant_jll = "0.0.246"
+Reactant_jll = "0.0.247"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"
```

docs/src/tutorials/partial-evaluation.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -57,7 +57,7 @@ The StableHLO IR code generated here is:
 # output
 
 module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
-  func.func @main(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
+  func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}, %arg1: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
     %0 = stablehlo.add %arg0, %arg1 : tensor<i64>
     return %0 : tensor<i64>
   }
@@ -101,7 +101,7 @@ variable input `%arg0`:
 # output
 
 module @reactant_add attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
-  func.func @main(%arg0: tensor<i64>) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
+  func.func @main(%arg0: tensor<i64> {enzymexla.memory_effects = []}) -> tensor<i64> attributes {enzymexla.memory_effects = []} {
     %c = stablehlo.constant dense<4> : tensor<i64>
     %0 = stablehlo.add %arg0, %c : tensor<i64>
     return %0 : tensor<i64>
```
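For context, the expected output updated here is produced by tracing a plain Julia function. A minimal sketch of how such a module is generated, assuming the tutorial's two-argument `add` over concrete Reactant scalars:

```julia
using Reactant

# Minimal sketch, assuming the tutorial's setup: a plain Julia add
# traced over concrete Reactant number inputs.
add(a, b) = a + b

a = Reactant.ConcreteRNumber(1)
b = Reactant.ConcreteRNumber(4)

# @code_hlo returns the StableHLO module; with the new JLL, per-argument
# attributes like {enzymexla.memory_effects = []} now appear on %arg0/%arg1.
hlo = @code_hlo add(a, b)
println(hlo)
```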

ext/ReactantCUDAExt.jl

Lines changed: 0 additions & 8 deletions
```diff
@@ -1460,14 +1460,6 @@ function Reactant.make_tracer(
     return newa
 end
 
-function __init__()
-    if CUDA.functional() && !Reactant.precompiling()
-        cap = CUDA.capability(CUDA.device())
-        Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
-    end
-    return nothing
-end
-
 # In Julia v1.11.3 precompiling this module caches bad code:
 # <https://github.com/EnzymeAD/Reactant.jl/issues/614>.
 @static if !Sys.isapple()
```
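The deleted `__init__` hook set `cubinChip` from CUDA.jl at extension load time; the commit replaces it with direct C calls during client initialization (see `src/xla/XLA.jl` below). A hedged sketch of the equivalent manual override, should you still want to derive the chip string from CUDA.jl yourself:

```julia
using Reactant, CUDA

# Hedged sketch of what the removed hook did: derive the cubin chip
# string from the active device's compute capability via CUDA.jl.
if CUDA.functional()
    cap = CUDA.capability(CUDA.device())  # e.g. v"8.0" on an A100
    Reactant.Compiler.cubinChip[] = "sm_$(cap.major)$(cap.minor)"
end
```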

src/Compiler.jl

Lines changed: 11 additions & 1 deletion
```diff
@@ -906,6 +906,8 @@ function optimization_passes(
         "trivial_reduce_window_to_reduce_op",
         "dot_general_add_distributive_simplify",
         "dot_general_subtract_distributive_simplify",
+        "dus_to_dynamic_pad",
+        "dynamic_pad_to_pad",
     ]
 
     if !compile_options.disable_auto_batching_passes
@@ -922,6 +924,7 @@ function optimization_passes(
         "concat_insert_dim_reduce",
         "concat_insert_dim_sort",
         "concat_insert_dim_reduce_window",
+        "concat_insert_dim_elementwise",
         "dot_general_slice_to_batch",
         "gather_slice_to_batch",
         "iota_slice_to_batch",
@@ -1071,6 +1074,7 @@ function optimization_passes(
         "const_prop_through_barrier<16>",
         "concat_const_prop<1>($max_constant_threshold)",
         "dynamic_update_slice_const_prop($max_constant_threshold)",
+        "clamp_const_prop",
     ],
 )
 
@@ -1105,7 +1109,8 @@ function optimization_passes(
         "reshape_dus",
         "dot_reshape_pad<1>",
         "pad_dot_general<1>(0)",
-        "pad_dot_general<1>(1)",
+        # XXX: see https://github.com/EnzymeAD/Enzyme-JAX/issues/1445
+        # "pad_dot_general<1>(1)",
         "reshape_pad",
         "reshape_wrap",
         "reshape_rotate",
@@ -1425,6 +1430,8 @@ const cubinChip = Ref{String}("sm_60")
 const cubinFormat = Ref{String}("bin")
 const cuindexBitWidth = Ref{Int}(32)
 const cuOptLevel = Ref{Int}(2)
+const cuWarpSize = Ref{Int}(32)
+
 # Whatever the relevant highest version from our LLVM is within NVPTX.td
 # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
 # We see relevant ptx version is CUDA 12.6 -> 85
@@ -3516,6 +3523,9 @@ function compile_xla(
         module_string = ""
     end
 
+    # Drop some of our attributes
+    run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")
+
     if before_xla_optimizations
         exec = nothing
         hlo_modules = XLA.HloModule(mod)
```
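The new `cuWarpSize` sits alongside the other CUDA lowering knobs, which are ordinary `Ref`s. A small sketch showing they can be inspected or overridden before compiling (the `sm_80` value is illustrative, not a recommendation):

```julia
using Reactant

# The CUDA lowering defaults are plain Refs in Reactant.Compiler and can
# be read or overridden before compilation; values shown are illustrative.
@show Reactant.Compiler.cubinChip[]   # "sm_60" until a CUDA client initializes it
@show Reactant.Compiler.cuWarpSize[]  # 32 by default

Reactant.Compiler.cubinChip[] = "sm_80"  # e.g. target an A100-class GPU
```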

src/xla/XLA.jl

Lines changed: 9 additions & 0 deletions
```diff
@@ -234,6 +234,15 @@ for runtime in (:PJRT, :IFRT)
             )
             state.clients["cuda"] = gpu
             state.default_client = gpu
+
+            # Set values for CUDA. This is done here since we need CUDA
+            # to be initialized before we can use it; initializing the
+            # devices implicitly initializes CUDA.
+            cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
+            cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
+            Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
+
+            Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
         catch e
             println(stdout, e)
         end
```
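With this change, the compute-capability query goes straight through `@ccall` into the bundled runtime library rather than via CUDA.jl. A hedged sketch of the same calls made interactively, assuming a CUDA client has already been initialized (device setup implicitly initializes CUDA):

```julia
using Reactant

# Hedged sketch: the same C entry points as above, called directly once
# a CUDA client exists. Note the identifier spelling ("Capalility")
# matches the diff verbatim.
major = @ccall Reactant.MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
minor = @ccall Reactant.MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
warp = @ccall Reactant.MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32

println("compute capability sm_$(major)$(minor), warp size $(warp)")
```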

test/buffer_donation.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -50,7 +50,7 @@ end
     hlo = @code_hlo(multiple_donated_args(a, b, c))
     @test contains(
         repr(hlo),
-        "@main(%arg0: tensor<2x2xf64> {tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3xf64> {tf.aliasing_output = 2 : i32}, %arg2: tensor<2x2xf64> {tf.aliasing_output = 1 : i32})",
+        "@main(%arg0: tensor<2x2xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}, %arg1: tensor<4x3xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 2 : i32}, %arg2: tensor<2x2xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 1 : i32}) -> (tensor<2x2xf64>, tensor<2x2xf64>, tensor<4x3xf64>) attributes {enzymexla.memory_effects = []} {",
     )
 end
```
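The test's `multiple_donated_args` is defined outside the quoted hunk. A hedged reconstruction of the setup, with hypothetical function bodies chosen only so each argument is mutated (and thus donatable), consistent with the aliasing map asserted above (arg0 → output 0, arg1 → output 2, arg2 → output 1):

```julia
using Reactant, Test

# Hypothetical bodies: the real definitions sit above the quoted hunk.
# Mutating an argument in place lets Reactant donate its buffer, which
# surfaces as tf.aliasing_output on the corresponding @main argument.
function multiple_donated_args(a, b, c)
    a .= a .* 2
    b .= b .+ 1
    c .= c .- a
    return a, c, b  # output shapes (2x2, 2x2, 4x3), matching the test string
end

a = Reactant.to_rarray(rand(2, 2))
b = Reactant.to_rarray(rand(4, 3))
c = Reactant.to_rarray(rand(2, 2))

hlo = @code_hlo multiple_donated_args(a, b, c)
@test contains(repr(hlo), "tf.aliasing_output")
```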
