From dac2864921d229fd68b93e4671ba6b5807e4ff3f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 22:16:04 -0500 Subject: [PATCH 01/16] feat: initial triton setup [skip ci] --- CondaPkg.toml | 1 + .../ReactantPythonCallExt.jl | 14 +++++++++++++- ext/ReactantPythonCallExt/overlays.jl | 2 +- ext/ReactantPythonCallExt/pycall.jl | 16 +++++++++++++++- 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/CondaPkg.toml b/CondaPkg.toml index b1db4f8e75..40cc769513 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" +triton = "" # TODO: version bound diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 1f10630808..3b4a85e951 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -1,6 +1,6 @@ module ReactantPythonCallExt -using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist +using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall @@ -9,6 +9,10 @@ const jnpptr = Ref{Py}() const JAX_TRACING_SUPPORTED = Ref{Bool}(false) +const tritonptr = Ref{Py}() + +const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false) + const tfptr = Ref{Py}() const tf2xlaptr = Ref{Py}() const npptr = Ref{Py}() @@ -43,6 +47,14 @@ function __init__() be supported." exception = (err, catch_backtrace()) end + try + tritonptr[] = pyimport("triton") + TRITON_COMPILE_SUPPORTED[] = true + catch err + @warn "Failed to import triton. Compiling jax functions with triton won't be \ + supported." exception = (err, catch_backtrace()) + end + try tfptr[] = pyimport("tensorflow") tfptr[].config.set_visible_devices(pylist(); device_type="GPU") diff --git a/ext/ReactantPythonCallExt/overlays.jl b/ext/ReactantPythonCallExt/overlays.jl index 20ffa7384f..20a9210023 100644 --- a/ext/ReactantPythonCallExt/overlays.jl +++ b/ext/ReactantPythonCallExt/overlays.jl @@ -1,6 +1,6 @@ @reactant_overlay function PythonCall.pycall(f::Py, args...) if Reactant.looped_any(Reactant.use_overlayed_version, args) - return pycall_with_jax_tracing(f, args...) + return overlayed_pycall(f, args...) else return Base.inferencebarrier(PythonCall.pycall)(f, args...) end diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 23674d9155..8f81b50049 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -7,7 +7,17 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe ) end -function pycall_with_jax_tracing(f::Py, args...) +function overlayed_pycall(f::Py, args...) + @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[] + # TODO: check for Autotuner and Heutistics as well + if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction) + return overlayed_pycall_with_triton(f, args...) + else + return overlayed_pycall_with_jax_tracing(f, args...) + end +end + +function overlayed_pycall_with_jax_tracing(f::Py, args...) JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.") seen_args = Reactant.OrderedIdDict() @@ -35,3 +45,7 @@ function pycall_with_jax_tracing(f::Py, args...) res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...) return length(res) == 0 ? nothing : (length(res) == 1 ? 
res[1] : res) end + +function overlayed_pycall_with_triton(f::Py, args...) + error("TODO: implement triton") +end From 6c1d287b091bc1a1dbde5141edb6bd1f3d3cc1d5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 23:30:22 -0500 Subject: [PATCH 02/16] feat: auto-trace triton code --- CondaPkg.toml | 2 +- .../ReactantPythonCallExt.jl | 26 ++++++- ext/ReactantPythonCallExt/overlays.jl | 6 +- ext/ReactantPythonCallExt/pycall.jl | 72 +++++++++++++++++-- 4 files changed, 97 insertions(+), 9 deletions(-) diff --git a/CondaPkg.toml b/CondaPkg.toml index 40cc769513..00aa12cb4a 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,4 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" -triton = "" # TODO: version bound +triton = ">= 3.4" diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 3b4a85e951..af3852ce2e 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -1,8 +1,10 @@ module ReactantPythonCallExt -using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance +using PythonCall: + PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall +using Reactant_jll: Reactant_jll const jaxptr = Ref{Py}() const jnpptr = Ref{Py}() @@ -37,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict( ComplexF64 => :complex64, ) +const MLIR_TYPE_STRING = Dict( + Float64 => "fp64", + Float32 => "fp32", + Float16 => "fp16", + Int64 => "i64", + Int32 => "i32", + Int16 => "i16", + Int8 => "i8", + UInt64 => "ui64", + UInt32 => "ui32", + UInt16 => "ui16", + UInt8 => "ui8", + Bool => "i1", + Reactant.F8E4M3FN => "fp8e4nv", + Reactant.F8E5M2FNUZ => "fp8e5b16", + Reactant.F8E4M3FNUZ => "fp8e4b8", + Reactant.F8E5M2 => "fp8e5", +) +if isdefined(Core, :BFloat16) + MLIR_TYPE_STRING[Core.BFloat16] = "bf16" +end + function __init__() try jaxptr[] = pyimport("jax") diff --git a/ext/ReactantPythonCallExt/overlays.jl b/ext/ReactantPythonCallExt/overlays.jl index 20a9210023..ca5bcfcea5 100644 --- a/ext/ReactantPythonCallExt/overlays.jl +++ b/ext/ReactantPythonCallExt/overlays.jl @@ -1,7 +1,7 @@ -@reactant_overlay function PythonCall.pycall(f::Py, args...) +@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...) if Reactant.looped_any(Reactant.use_overlayed_version, args) - return overlayed_pycall(f, args...) + return overlayed_pycall(f, args...; kwargs...) else - return Base.inferencebarrier(PythonCall.pycall)(f, args...) + return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...) end end diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 8f81b50049..7786c1b73a 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -7,12 +7,13 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe ) end -function overlayed_pycall(f::Py, args...) +function overlayed_pycall(f::Py, args...; kwargs...) @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[] # TODO: check for Autotuner and Heutistics as well if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction) - return overlayed_pycall_with_triton(f, args...) + return overlayed_pycall_with_triton(f, args...; kwargs...) else + @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions." 
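+        # Non-Triton callables fall back to JAX tracing: the Python function is lowered
+        # with `jax.jit` and the resulting StableHLO is spliced in via `hlo_call`.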
return overlayed_pycall_with_jax_tracing(f, args...) end end @@ -46,6 +47,69 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end -function overlayed_pycall_with_triton(f::Py, args...) - error("TODO: implement triton") +# TODO: support using metaparams here +normalize_grid(grid::Integer) = normalize_grid((grid,)) +function normalize_grid(grid::Dims{N}) where {N} + @assert N <= 3 + @assert all(grid .> 0) + return (grid..., ntuple(_ -> 1, 3 - N)...) +end + +signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing +signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing +signature_string(x::T) where {T<:Number} = string(x), x +signature_string(x) = error("Unsupported argument type: $(typeof(x))") + +function overlayed_pycall_with_triton( + kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing +) + triton = tritonptr[] + + grid = normalize_grid(grid) + + mapped = map(signature_string, args) + signature = first.(mapped) + # TODO: are hints actually correctly set? + hints = + hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints) + constants = Dict( + kernel.arg_names[i - 1] => constant for + (i, constant) in enumerate(last.(mapped)) if constant !== nothing + ) + for (k, v) in hints + v == 1 && (constants[kernel.arg_names[k - 1]] = v) + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature)) + for k in keys(constants) + sigmap[k] = "constexpr" + end + + for h in values(hints) + @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h" + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + src = triton.compiler.ASTSource(; + fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs + ) + + # TODO: check that we are using CUDA. 
Get compute_capability from the target + target = triton.backends.compiler.GPUTarget("cuda", 80, 32) + backend = triton.compiler.make_backend(target) + options = backend.parse_options( + pydict( + "num_warps" => num_warps, + "num_stages" => num_stages, + "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)), + ), + ) + + ccinfo = triton.compile(src; target=target, options=options.__dict__) + + println(pyconvert(String, ccinfo.asm["source"])) + println(pyconvert(String, ccinfo.asm["ttir"])) + + return error("TODO: implement triton") end From ba29516974bd430e8b4c2777ad2448343989e4ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 25 Sep 2025 00:04:41 -0500 Subject: [PATCH 03/16] feat: copy tt.func into main module [skip ci] --- ext/ReactantPythonCallExt/pycall.jl | 14 ++++- src/Ops.jl | 98 +++++++++++++++++------------ 2 files changed, 70 insertions(+), 42 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 7786c1b73a..c1c5662a67 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -108,8 +108,18 @@ function overlayed_pycall_with_triton( ccinfo = triton.compile(src; target=target, options=options.__dict__) - println(pyconvert(String, ccinfo.asm["source"])) - println(pyconvert(String, ccinfo.asm["ttir"])) + @show ccinfo.metadata + @show ccinfo.asm.keys() + # shared = ccinfo.metadata["shared"] + kernel_name = pyconvert(String, ccinfo.metadata.name) + # cluster_dims = ccinfo.metadata["cluster_dims"] + + # println(pyconvert(String, ccinfo.asm["source"])) + # println(pyconvert(String, ccinfo.asm["ttir"])) + + res = @opcall triton_call( + pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name + ) return error("TODO: implement triton") end diff --git a/src/Ops.jl b/src/Ops.jl index 39c9790dea..d5f9bc179f 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1743,51 +1743,22 @@ end end # Generate a unique name given a module hash and a function name. -function _hlo_call_name(orig_name, module_suffix) - return orig_name * "_hlo_call_" * module_suffix -end +_new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix -""" - hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray} - -Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main") -with the provided arguments and return a tuple for each result of the call. 
- -```julia-repl -julia> Reactant.@jit( - hlo_call( - \"\"\" - module { - func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { - %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> - return %0 : tensor<3xf32> - } - } - \"\"\", - Reactant.to_rarray(Float32[1, 2, 3]), - Reactant.to_rarray(Float32[1, 2, 3]), - ) - ) -(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) -``` -""" -@noinline function hlo_call( - code, - args...; - func_name="main", - location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), +function _extract_function( + code::String; func_name::String="main", func_op_kind::String="func.func" ) module_suffix = string(hash(code); base=16) - name_to_call = _hlo_call_name(func_name, module_suffix) + name_to_call = _new_function_name(func_name, module_suffix) current_module = MLIR.IR.mmodule() top_level_block = MLIR.IR.body(current_module) symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - fn = MLIR.IR.lookup( MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call ) + if isnothing(fn) new_mod = parse(MLIR.IR.Module, code) new_mod_op = MLIR.IR.Operation(new_mod) @@ -1795,16 +1766,15 @@ julia> Reactant.@jit( operations = collect(MLIR.IR.OperationIterator(body)) for op in operations - if MLIR.IR.name(op) == "func.func" + if MLIR.IR.name(op) == func_op_kind fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) if fn_name == func_name fn = op end - new_name = _hlo_call_name(fn_name, module_suffix) res = MLIR.IR.LogicalResult( MLIR.API.mlirSymbolTableReplaceAllSymbolUses( - fn_name, new_name, new_mod_op + fn_name, name_to_call, new_mod_op ), ) @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" @@ -1817,7 +1787,7 @@ julia> Reactant.@jit( ) # Change function name - MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name)) + MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) end end @@ -1831,11 +1801,59 @@ julia> Reactant.@jit( error("hlo_call: could not find function $func_name in the provided module") end + return name_to_call +end + +function triton_call( + mlir_code::String, + args::Union{TracedRArray,TracedRNumber,Number}...; + func_name::String="main", + location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), +) + name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") + + @show name_to_call + display(MLIR.IR.mmodule()) + + error("TODO: implement triton_call") +end + +""" + hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray} + +Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main") +with the provided arguments and return a tuple for each result of the call. 
+ +```julia-repl +julia> Reactant.@jit( + hlo_call( + \"\"\" + module { + func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> { + %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32> + return %0 : tensor<3xf32> + } + } + \"\"\", + Reactant.to_rarray(Float32[1, 2, 3]), + Reactant.to_rarray(Float32[1, 2, 3]), + ) + ) +(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) +``` +""" +@noinline function hlo_call( + code, + args::Union{TracedRArray,TracedRNumber}...; + func_name="main", + location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), +) + name_to_call = _extract_function(code; func_name, func_op_kind="func.func") + ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) - @assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers" - @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name" + @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name. Expected $(MLIR.IR.ninputs(ftype)), got $(length(args))" for (i, arg) in enumerate(args) expected_type = MLIR.IR.input(ftype, i) From c26da6ff69d05713d810d18a370ba99a09cedb5c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 08:17:27 -0500 Subject: [PATCH 04/16] chore: regen mlir bindings --- src/mlir/Dialects/EnzymeXLA.jl | 47 ++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index ca396710ca..d6a44439fe 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -837,6 +837,53 @@ function stream2token(source::Value; result::IR.Type, location=Location()) ) end +function triton_call( + gridx::Value, + gridy::Value, + gridz::Value, + shmem::Value, + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[gridx, gridy, gridz, shmem, inputs...] 
+ owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "enzymexla.triton_call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function wrap( operand::Value; result=nothing::Union{Nothing,IR.Type}, From 26d217f2100adb60cc4ade44de0e5a9e40716912 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 08:56:21 -0500 Subject: [PATCH 05/16] feat: tracing fully functional --- ext/ReactantPythonCallExt/pycall.jl | 23 +++++++++++------------ src/Ops.jl | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index c1c5662a67..4c9a8cba82 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -106,20 +106,19 @@ function overlayed_pycall_with_triton( ), ) + # Currently we are doing a double compilation here. can we do better? + # we are compiling here + lowering again inside enzymejax ccinfo = triton.compile(src; target=target, options=options.__dict__) - @show ccinfo.metadata - @show ccinfo.asm.keys() - # shared = ccinfo.metadata["shared"] - kernel_name = pyconvert(String, ccinfo.metadata.name) - # cluster_dims = ccinfo.metadata["cluster_dims"] - - # println(pyconvert(String, ccinfo.asm["source"])) - # println(pyconvert(String, ccinfo.asm["ttir"])) - - res = @opcall triton_call( - pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name + @opcall triton_call( + pyconvert(String, ccinfo.asm["ttir"]), + filter(x -> x isa Reactant.TracedType, args)...; + func_name=pyconvert(String, ccinfo.metadata.name), + grid_x=@opcall(constant(grid[1])), + grid_y=@opcall(constant(grid[2])), + grid_z=@opcall(constant(grid[3])), + shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))), ) - return error("TODO: implement triton") + return nothing end diff --git a/src/Ops.jl b/src/Ops.jl index d5f9bc179f..2a09399dc9 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1808,14 +1808,27 @@ function triton_call( mlir_code::String, args::Union{TracedRArray,TracedRNumber,Number}...; func_name::String="main", + grid_x::TracedRNumber{<:Integer}, + grid_y::TracedRNumber{<:Integer}, + grid_z::TracedRNumber{<:Integer}, + shmem::TracedRNumber{<:Integer}, location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), + # TODO: other kwargs ) name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") - @show name_to_call - display(MLIR.IR.mmodule()) + enzymexla.triton_call( + grid_x.mlir_data, + grid_y.mlir_data, + grid_z.mlir_data, + shmem.mlir_data, + [Reactant.TracedUtils.get_mlir_data(a) for a in args]; + fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + 
result_0=MLIR.IR.Type[], + location, + ) - error("TODO: implement triton_call") + return nothing end """ From 933f67a862f888acab075fcf2c877041a3609788 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 10:11:26 -0500 Subject: [PATCH 06/16] fix: hlo_call --- src/Ops.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 2a09399dc9..c3fe3bce58 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1801,7 +1801,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return name_to_call + return fn, name_to_call end function triton_call( @@ -1815,7 +1815,7 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") + _, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") enzymexla.triton_call( grid_x.mlir_data, @@ -1861,7 +1861,7 @@ julia> Reactant.@jit( func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - name_to_call = _extract_function(code; func_name, func_op_kind="func.func") + fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func") ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) From 013180cb2d68aed26a93810449f28cb26aa9c114 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Sep 2025 23:19:56 -0500 Subject: [PATCH 07/16] feat: more triton passes + keep triton func in a separate module --- deps/ReactantExtra/BUILD | 3 ++ ext/ReactantPythonCallExt/pycall.jl | 10 +++-- src/Compiler.jl | 60 ++++++++++++++++++++++++++++- src/Ops.jl | 14 ++++++- src/mlir/IR/Module.jl | 3 +- 5 files changed, 83 insertions(+), 7 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 36134e154f..83946212e2 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -979,6 +979,9 @@ cc_library( "-Wl,-exported_symbol,_ReactantFuncSetArgAttr", "-Wl,-exported_symbol,_ReactantHermeticCudaGetVersion", "-Wl,-exported_symbol,_ReactantCudaDriverGetVersion", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", + "-Wl,-exported_symbol,_ReactantCudaDeviceGetWarpSizeInThreads", "-Wl,-exported_symbol,_ReactantLLVMParseCommandLineOptions", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMajor", "-Wl,-exported_symbol,_ReactantCudaDeviceGetComputeCapalilityMinor", diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 4c9a8cba82..40026af81f 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -60,6 +60,7 @@ signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothi signature_string(x::T) where {T<:Number} = string(x), x signature_string(x) = error("Unsupported argument type: $(typeof(x))") +# TODO: better name for hints? function overlayed_pycall_with_triton( kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing ) @@ -95,8 +96,11 @@ function overlayed_pycall_with_triton( fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs ) - # TODO: check that we are using CUDA. 
Get compute_capability from the target - target = triton.backends.compiler.GPUTarget("cuda", 80, 32) + target = triton.backends.compiler.GPUTarget( + "cuda", + parse(Int, Reactant.Compiler.cubinChip[][4:end]), + Reactant.Compiler.cuWarpSize[], + ) backend = triton.compiler.make_backend(target) options = backend.parse_options( pydict( @@ -111,7 +115,7 @@ function overlayed_pycall_with_triton( ccinfo = triton.compile(src; target=target, options=options.__dict__) @opcall triton_call( - pyconvert(String, ccinfo.asm["ttir"]), + pyconvert(String, ccinfo.asm["source"]), filter(x -> x isa Reactant.TracedType, args)...; func_name=pyconvert(String, ccinfo.metadata.name), grid_x=@opcall(constant(grid[1])), diff --git a/src/Compiler.jl b/src/Compiler.jl index 46d577bf43..9f31757f27 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1298,9 +1298,66 @@ function optimization_passes( push!(passes, "remove-duplicate-func-def") end push!(passes, func_passes) + if backend == "cuda" + push!(passes, triton_optimization_passes()) + end return join(passes, ',') end +# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc +function triton_optimization_passes() + # TODO: check that all triton passes are included here + return join( + [ + # convert passes + "convert-scf-to-cf", + "convert-cf-to-llvm", + "convert-index-to-llvm", + "convert-arith-to-llvm", + "convert-nvvm-to-llvm", + # common passes + "canonicalize", + # # ttir passes + # "triton-combine", + # "triton-reorder-broadcast", + # "triton-rewrite-tensor-pointer", + # "triton-rewrite-tensor-descriptor-to-pointer", + # "triton-loop-unroll", + # "triton-licm", + # "triton-loop-aware-cse", + # # TODO: should num-warps and num-ctas be set for each kernel? + # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", + # # ttgir passes + # "tritongpu-coalesce", + # "tritongpu-optimize-thread-locality", + # "tritongpu-hoist-tmem-alloc", + # "tritongpu-assign-latencies", + # "tritongpu-pipeline", + # "tritongpu-schedule-loops", + # "tritongpu-automatic-warp-specialization", + # "tritongpu-prefetch", + # "tritongpu-accelerate-matmul", + # "tritongpu-reorder-instructions", + # "tritongpu-F32DotTC", + # "tritongpu-optimize-dot-operands", + # "tritongpu-remove-layout-conversions", + # "tritongpu-reduce-data-duplication", + # "tritongpu-hoist-tmem-alloc", + # "tritongpu-fuse-nested-loops", + # "tritongpu-rewrite-partition-dependencies", + # "tritongpu-partition-loops", + # "tritongpu-combine-tensor-select-and-if", + # # ttgir to llvm passes + # "tritongpu-allocate-warp-groups", + # "allocate-shared-memory", + # "tritongpu-global-scratch-memory-allocation", + # "tritongpu-optimize-accumulator-init", + # "tritongpu-coalesce-async-copy", + ], + ",", + ) +end + # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. 
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}" @@ -2254,7 +2311,8 @@ function compile_mlir!( end end - run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") + # XXX: re-enable this pass + # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") func_op = MLIR.API.mlirSymbolTableLookup( MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname diff --git a/src/Ops.jl b/src/Ops.jl index c3fe3bce58..1fe808171c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1746,12 +1746,20 @@ end _new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix function _extract_function( - code::String; func_name::String="main", func_op_kind::String="func.func" + code::String; + func_name::String="main", + func_op_kind::String="func.func", + nested_module::Bool=false, ) module_suffix = string(hash(code); base=16) name_to_call = _new_function_name(func_name, module_suffix) current_module = MLIR.IR.mmodule() + if nested_module + new_module = MLIR.IR.Module() + push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true)) + current_module = new_module + end top_level_block = MLIR.IR.body(current_module) symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) @@ -1815,7 +1823,9 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - _, name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func") + _, name_to_call = _extract_function( + mlir_code; func_name, func_op_kind="tt.func", nested_module=true + ) enzymexla.triton_call( grid_x.mlir_data, diff --git a/src/mlir/IR/Module.jl b/src/mlir/IR/Module.jl index 12794b30ba..c7f938d5b8 100644 --- a/src/mlir/IR/Module.jl +++ b/src/mlir/IR/Module.jl @@ -52,7 +52,8 @@ body(module_) = Block(API.mlirModuleGetBody(module_), false) Views the module as a generic operation. """ -Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false) +Operation(module_::Module, owned::Bool=false) = + Operation(API.mlirModuleGetOperation(module_), owned) function Base.show(io::IO, module_::Module) return show(io, Operation(module_)) From e29cd0a8d8400dbba013a586dff71bbd5b0df274 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 28 Sep 2025 00:14:35 -0500 Subject: [PATCH 08/16] feat: put the tt func in a separate module and use symbol ref --- src/Compiler.jl | 75 ++++++++++++++++++++--------------------- src/Ops.jl | 90 +++++++++++++++++++++++++++---------------------- 2 files changed, 87 insertions(+), 78 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 9f31757f27..4bb540d83f 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1317,42 +1317,42 @@ function triton_optimization_passes() "convert-nvvm-to-llvm", # common passes "canonicalize", - # # ttir passes - # "triton-combine", - # "triton-reorder-broadcast", - # "triton-rewrite-tensor-pointer", - # "triton-rewrite-tensor-descriptor-to-pointer", - # "triton-loop-unroll", - # "triton-licm", - # "triton-loop-aware-cse", - # # TODO: should num-warps and num-ctas be set for each kernel? 
- # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", - # # ttgir passes - # "tritongpu-coalesce", - # "tritongpu-optimize-thread-locality", - # "tritongpu-hoist-tmem-alloc", - # "tritongpu-assign-latencies", - # "tritongpu-pipeline", - # "tritongpu-schedule-loops", - # "tritongpu-automatic-warp-specialization", - # "tritongpu-prefetch", - # "tritongpu-accelerate-matmul", - # "tritongpu-reorder-instructions", - # "tritongpu-F32DotTC", - # "tritongpu-optimize-dot-operands", - # "tritongpu-remove-layout-conversions", - # "tritongpu-reduce-data-duplication", - # "tritongpu-hoist-tmem-alloc", - # "tritongpu-fuse-nested-loops", - # "tritongpu-rewrite-partition-dependencies", - # "tritongpu-partition-loops", - # "tritongpu-combine-tensor-select-and-if", - # # ttgir to llvm passes - # "tritongpu-allocate-warp-groups", - # "allocate-shared-memory", - # "tritongpu-global-scratch-memory-allocation", - # "tritongpu-optimize-accumulator-init", - # "tritongpu-coalesce-async-copy", + # ttir passes + "triton-combine", + "triton-reorder-broadcast", + "triton-rewrite-tensor-pointer", + "triton-rewrite-tensor-descriptor-to-pointer", + "triton-loop-unroll", + "triton-licm", + "triton-loop-aware-cse", + # TODO: should num-warps and num-ctas be set for each kernel? + "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", + # ttgir passes + "tritongpu-coalesce", + "tritongpu-optimize-thread-locality", + "tritongpu-hoist-tmem-alloc", + "tritongpu-assign-latencies", + "tritongpu-pipeline", + "tritongpu-schedule-loops", + "tritongpu-automatic-warp-specialization", + "tritongpu-prefetch", + "tritongpu-accelerate-matmul", + "tritongpu-reorder-instructions", + "tritongpu-F32DotTC", + "tritongpu-optimize-dot-operands", + "tritongpu-remove-layout-conversions", + "tritongpu-reduce-data-duplication", + "tritongpu-hoist-tmem-alloc", + "tritongpu-fuse-nested-loops", + "tritongpu-rewrite-partition-dependencies", + "tritongpu-partition-loops", + "tritongpu-combine-tensor-select-and-if", + # ttgir to llvm passes + "tritongpu-allocate-warp-groups", + "allocate-shared-memory", + "tritongpu-global-scratch-memory-allocation", + "tritongpu-optimize-accumulator-init", + "tritongpu-coalesce-async-copy", ], ",", ) @@ -2311,8 +2311,7 @@ function compile_mlir!( end end - # XXX: re-enable this pass - # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") + run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects") func_op = MLIR.API.mlirSymbolTableLookup( MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname diff --git a/src/Ops.jl b/src/Ops.jl index 1fe808171c..e526371807 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -1750,22 +1750,32 @@ function _extract_function( func_name::String="main", func_op_kind::String="func.func", nested_module::Bool=false, + location::MLIR.IR.Location=MLIR.IR.Location(), ) module_suffix = string(hash(code); base=16) - name_to_call = _new_function_name(func_name, module_suffix) + name_to_call = func_name * "_call_" * module_suffix + mod_name = func_name * "_module_" * module_suffix + symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - current_module = MLIR.IR.mmodule() if nested_module - new_module = MLIR.IR.Module() - push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true)) - current_module = new_module - end - top_level_block = MLIR.IR.body(current_module) + region = MLIR.IR.Region() + 
push!(region, MLIR.IR.Block()) + moduleop = MLIR.Dialects.builtin.module_(; + location, bodyRegion=region, sym_name=mod_name + ) + MLIR.IR.rmfromparent!(moduleop) + push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module - symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - fn = MLIR.IR.lookup( - MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call - ) + top_level_block = MLIR.IR.Block( + MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false + ) + fn = nothing + else + current_module = MLIR.IR.mmodule() + moduleop = MLIR.IR.Operation(current_module) + top_level_block = MLIR.IR.body(current_module) + fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call) + end if isnothing(fn) new_mod = parse(MLIR.IR.Module, code) @@ -1773,31 +1783,27 @@ function _extract_function( body = MLIR.IR.body(new_mod) operations = collect(MLIR.IR.OperationIterator(body)) - for op in operations - if MLIR.IR.name(op) == func_op_kind - fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) - if fn_name == func_name - fn = op - end + idx = Base.findfirst(op -> MLIR.IR.name(op) == func_op_kind, operations) + @assert idx !== nothing + op = operations[idx] - res = MLIR.IR.LogicalResult( - MLIR.API.mlirSymbolTableReplaceAllSymbolUses( - fn_name, name_to_call, new_mod_op - ), - ) - @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" - - # Set function private - MLIR.IR.attr!( - op, - MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), - MLIR.IR.Attribute("private"), - ) - - # Change function name - MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) - end - end + fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) + fn_name == func_name && (fn = op) + + res = MLIR.IR.LogicalResult( + MLIR.API.mlirSymbolTableReplaceAllSymbolUses(fn_name, name_to_call, new_mod_op) + ) + @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" + + # Set function private + MLIR.IR.attr!( + op, + MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), + MLIR.IR.Attribute("private"), + ) + + # Change function name + MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) for op in operations MLIR.IR.rmfromparent!(op) @@ -1809,7 +1815,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return fn, name_to_call + return fn, name_to_call, mod_name end function triton_call( @@ -1823,8 +1829,8 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - _, name_to_call = _extract_function( - mlir_code; func_name, func_op_kind="tt.func", nested_module=true + _, name_to_call, mod_name = _extract_function( + mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location ) enzymexla.triton_call( @@ -1833,7 +1839,9 @@ function triton_call( grid_z.mlir_data, shmem.mlir_data, [Reactant.TracedUtils.get_mlir_data(a) for a in args]; - fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + fn=MLIR.IR.SymbolRefAttribute( + mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)] + ), result_0=MLIR.IR.Type[], location, ) @@ -1871,7 +1879,9 @@ julia> Reactant.@jit( func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func") + fn, name_to_call, _ = _extract_function( + code; func_name, func_op_kind="func.func", location + ) ftype_attr = MLIR.IR.attr(fn, "function_type") 
ftype = MLIR.IR.Type(ftype_attr) From 30976c08eab559298b8707832fa8c93be4a0df9d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Sep 2025 13:07:04 -0500 Subject: [PATCH 09/16] feat: new triton_ext dialect --- deps/ReactantExtra/BUILD | 18 +++++++ deps/ReactantExtra/make-bindings.jl | 1 + src/mlir/Dialects/EnzymeXLA.jl | 47 ----------------- src/mlir/Dialects/TritonExt.jl | 82 +++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 47 deletions(-) create mode 100644 src/mlir/Dialects/TritonExt.jl diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 83946212e2..7c8ad1a6ed 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -1438,6 +1438,24 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "TritonExtJLIncGen", + tbl_outs = [ + ( + [ + "--generator=jl-op-defs", + "--disable-module-wrap=0", + ], + "TritonExt.jl", + ), + ], + tblgen = "//:mlir-jl-tblgen", + td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td", + deps = [ + "@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles", + ], +) + gentbl_cc_library( name = "TPUJLIncGen", tbl_outs = [ diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f84309fef1..ebdb7cd9b0 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -42,6 +42,7 @@ for file in [ "MPI.jl", "MemRef.jl", "SparseTensor.jl", + "TritonExt.jl" ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index d6a44439fe..ca396710ca 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -837,53 +837,6 @@ function stream2token(source::Value; result::IR.Type, location=Location()) ) end -function triton_call( - gridx::Value, - gridy::Value, - gridz::Value, - shmem::Value, - inputs::Vector{Value}; - result_0::Vector{IR.Type}, - fn, - backend_config=nothing, - operand_layouts=nothing, - result_layouts=nothing, - arg_attrs=nothing, - res_attrs=nothing, - output_operand_aliases=nothing, - xla_side_effect_free=nothing, - location=Location(), -) - op_ty_results = IR.Type[result_0...,] - operands = Value[gridx, gridy, gridz, shmem, inputs...] 
- owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - !isnothing(backend_config) && - push!(attributes, namedattribute("backend_config", backend_config)) - !isnothing(operand_layouts) && - push!(attributes, namedattribute("operand_layouts", operand_layouts)) - !isnothing(result_layouts) && - push!(attributes, namedattribute("result_layouts", result_layouts)) - !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) - !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && - push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) && - push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - return create_operation( - "enzymexla.triton_call", - location; - operands, - owned_regions, - successors, - attributes, - results=op_ty_results, - result_inference=false, - ) -end - function wrap( operand::Value; result=nothing::Union{Nothing,IR.Type}, diff --git a/src/mlir/Dialects/TritonExt.jl b/src/mlir/Dialects/TritonExt.jl new file mode 100644 index 0000000000..f59822b909 --- /dev/null +++ b/src/mlir/Dialects/TritonExt.jl @@ -0,0 +1,82 @@ +module triton_ext +using ...IR +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType +import ..Dialects: namedattribute, operandsegmentsizes +import ...API + +function call( + gridx::Value, + gridy::Value, + gridz::Value, + shmem::Value, + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[gridx, gridy, gridz, shmem, inputs...] 
+ owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) + !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) + !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "triton_ext.call", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function module_(; sym_name, bodyRegion::Region, location=Location()) + op_ty_results = IR.Type[] + operands = Value[] + owned_regions = Region[bodyRegion,] + successors = Block[] + attributes = NamedAttribute[namedattribute("sym_name", sym_name),] + + return create_operation( + "triton_ext.module", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +end # triton_ext From 357f1c017cdacfe59492e0865286559080b28e7c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 29 Sep 2025 15:02:27 -0500 Subject: [PATCH 10/16] feat: triton tracing works now finally --- docs/src/.vitepress/config.mts | 2 + docs/src/api/dialects/tritonext.md | 11 ++++ src/Compiler.jl | 86 ++++++++++++++++++++---------- src/Ops.jl | 60 +++++++++++++-------- 4 files changed, 110 insertions(+), 49 deletions(-) create mode 100644 docs/src/api/dialects/tritonext.md diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index dacce466fb..2853e7abd5 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -131,6 +131,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], @@ -221,6 +222,7 @@ export default defineConfig({ { text: "SparseTensor", link: "/api/dialects/sparsetensor" }, { text: "StableHLO", link: "/api/dialects/stablehlo" }, { text: "Triton", link: "/api/dialects/triton" }, + { text: "TritonExt", link: "/api/dialects/tritonext" }, { text: "TPU", link: "/api/dialects/tpu" }, { text: "VHLO", link: "/api/dialects/vhlo" }, ], diff --git a/docs/src/api/dialects/tritonext.md b/docs/src/api/dialects/tritonext.md new file mode 100644 index 0000000000..a727f0dfbb --- /dev/null +++ b/docs/src/api/dialects/tritonext.md @@ -0,0 +1,11 @@ +```@meta +CollapsedDocStrings = true +``` + +# TritonExt Dialect + +Provides extensions to the Triton dialect. 
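+It wraps the `triton_ext.module` and `triton_ext.call` operations used to embed compiled Triton kernels and launch them.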
+ +```@autodocs +Modules = [Reactant.MLIR.Dialects.triton_ext] +``` diff --git a/src/Compiler.jl b/src/Compiler.jl index 4bb540d83f..7ee65b163f 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1305,57 +1305,89 @@ function optimization_passes( end # https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc +# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes function triton_optimization_passes() - # TODO: check that all triton passes are included here - return join( + all_passes = join( [ - # convert passes - "convert-scf-to-cf", - "convert-cf-to-llvm", - "convert-index-to-llvm", - "convert-arith-to-llvm", - "convert-nvvm-to-llvm", - # common passes "canonicalize", - # ttir passes + "triton-rewrite-tensor-pointer", + "canonicalize", "triton-combine", "triton-reorder-broadcast", - "triton-rewrite-tensor-pointer", - "triton-rewrite-tensor-descriptor-to-pointer", + "cse", + "symbol-dce", "triton-loop-unroll", - "triton-licm", - "triton-loop-aware-cse", - # TODO: should num-warps and num-ctas be set for each kernel? "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}", - # ttgir passes "tritongpu-coalesce", + "tritongpu-F32DotTC", + "triton-nvidia-gpu-plan-cta", + "tritongpu-remove-layout-conversions", "tritongpu-optimize-thread-locality", + "tritongpu-accelerate-matmul", + "tritongpu-remove-layout-conversions", + "tritongpu-optimize-dot-operands", + "canonicalize", + "triton-nvidia-optimize-descriptor-encoding", + "triton-loop-aware-cse", + "tritongpu-fuse-nested-loops", + "canonicalize", + "triton-licm", + "tritongpu-optimize-accumulator-init", "tritongpu-hoist-tmem-alloc", + "tritongpu-promote-lhs-to-tmem", "tritongpu-assign-latencies", - "tritongpu-pipeline", "tritongpu-schedule-loops", "tritongpu-automatic-warp-specialization", + "tritongpu-partition-scheduling", + "tritongpu-load-mma-specialization", + "tritongpu-rewrite-partition-dependencies", + "sccp", + "cse", + "tritongpu-partition-loops", + "tritongpu-optimize-partition-warps", + "tritongpu-schedule-loops", + "tritongpu-pipeline", + "tritongpu-combine-tensor-select-and-if", + "triton-nvidia-gpu-remove-tmem-tokens", + "canonicalize", + "triton-loop-aware-cse", "tritongpu-prefetch", - "tritongpu-accelerate-matmul", - "tritongpu-reorder-instructions", - "tritongpu-F32DotTC", "tritongpu-optimize-dot-operands", + "canonicalize", + "tritongpu-coalesce-async-copy", + "triton-nvidia-optimize-tmem-layouts", "tritongpu-remove-layout-conversions", + "triton-nvidia-interleave-tmem", "tritongpu-reduce-data-duplication", - "tritongpu-hoist-tmem-alloc", - "tritongpu-fuse-nested-loops", - "tritongpu-rewrite-partition-dependencies", - "tritongpu-partition-loops", + "tritongpu-reorder-instructions", + "triton-loop-aware-cse", + "symbol-dce", + "triton-nvidia-tma-lowering", + "triton-nvidia-gpu-fence-insertion", + "sccp", + "canonicalize", + "triton-nvidia-mma-lowering", "tritongpu-combine-tensor-select-and-if", - # ttgir to llvm passes "tritongpu-allocate-warp-groups", + "convert-scf-to-cf", "allocate-shared-memory", + "triton-tensor-memory-allocation", "tritongpu-global-scratch-memory-allocation", - "tritongpu-optimize-accumulator-init", - "tritongpu-coalesce-async-copy", + # TODO: register the commented out passes + # "convert-triton-gpu-to-llvm", + "canonicalize", + "cse", + # "convert-nv-gpu-to-llvm", + # "convert-warp-specialize-to-llvm", + "reconcile-unrealized-casts", + "canonicalize", + 
"cse", + "symbol-dce", + "enable-line-info", ], ",", ) + return "triton_ext.module(builtin.module($(all_passes)))" end # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate diff --git a/src/Ops.jl b/src/Ops.jl index e526371807..d426a91ea3 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3,7 +3,7 @@ # Julia and Reactant semantics should be considered on the higher abstractions that use these module Ops using ..MLIR: MLIR -using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla +using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, triton_ext using ..Reactant: Reactant, TracedRArray, @@ -1749,7 +1749,6 @@ function _extract_function( code::String; func_name::String="main", func_op_kind::String="func.func", - nested_module::Bool=false, location::MLIR.IR.Location=MLIR.IR.Location(), ) module_suffix = string(hash(code); base=16) @@ -1757,24 +1756,45 @@ function _extract_function( mod_name = func_name * "_module_" * module_suffix symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - if nested_module + use_ttext_module = split(func_op_kind, ".")[1] == "tt" + + if use_ttext_module + tt_mod_name = func_name * "_tt_module_" * module_suffix + tt_region = MLIR.IR.Region() + tt_block = MLIR.IR.Block() + push!(tt_region, tt_block) + triton_mod_op = triton_ext.module_(; + location, bodyRegion=tt_region, sym_name=tt_mod_name + ) + MLIR.IR.rmfromparent!(triton_mod_op) + push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module + region = MLIR.IR.Region() push!(region, MLIR.IR.Block()) moduleop = MLIR.Dialects.builtin.module_(; location, bodyRegion=region, sym_name=mod_name ) MLIR.IR.rmfromparent!(moduleop) - push!(MLIR.IR.body(MLIR.IR.mmodule()), moduleop) # insert into parent module + push!(tt_block, moduleop) # insert into triton module top_level_block = MLIR.IR.Block( MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false ) fn = nothing + + symref = MLIR.IR.SymbolRefAttribute( + tt_mod_name, + MLIR.IR.Attribute[ + MLIR.IR.FlatSymbolRefAttribute(mod_name), + MLIR.IR.FlatSymbolRefAttribute(name_to_call), + ], + ) else current_module = MLIR.IR.mmodule() moduleop = MLIR.IR.Operation(current_module) top_level_block = MLIR.IR.body(current_module) fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call) + symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call) end if isnothing(fn) @@ -1795,12 +1815,14 @@ function _extract_function( ) @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" - # Set function private - MLIR.IR.attr!( - op, - MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), - MLIR.IR.Attribute("private"), - ) + if !use_ttext_module + # Set function private + MLIR.IR.attr!( + op, + MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), + MLIR.IR.Attribute("private"), + ) + end # Change function name MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) @@ -1815,7 +1837,7 @@ function _extract_function( error("hlo_call: could not find function $func_name in the provided module") end - return fn, name_to_call, mod_name + return fn, symref end function triton_call( @@ -1829,19 +1851,15 @@ function triton_call( location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), # TODO: other kwargs ) - _, name_to_call, mod_name = _extract_function( - mlir_code; func_name, func_op_kind="tt.func", nested_module=true, location - ) + _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location) - enzymexla.triton_call( + 
     triton_ext.call(
         grid_x.mlir_data,
         grid_y.mlir_data,
         grid_z.mlir_data,
         shmem.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
-        fn=MLIR.IR.SymbolRefAttribute(
-            mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
-        ),
+        fn=symref,
         result_0=MLIR.IR.Type[],
         location,
     )
@@ -1879,9 +1897,7 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, name_to_call, _ = _extract_function(
-        code; func_name, func_op_kind="func.func", location
-    )
+    fn, symref = _extract_function(code; func_name, func_op_kind="func.func", location)
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
@@ -1898,7 +1914,7 @@ julia> Reactant.@jit(
     call = MLIR.Dialects.func.call(
         operands;
         result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
-        callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+        callee=symref,
         location,
     )

From e869448532b330b2d435989d295085adddb67524 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 29 Sep 2025 16:10:43 -0500
Subject: [PATCH 11/16] fix: kind of working

---
 deps/ReactantExtra/make-bindings.jl |  2 +-
 src/Compiler.jl                     | 70 ++++++++++++++++++-----------
 2 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl
index ebdb7cd9b0..9e4295e9cb 100644
--- a/deps/ReactantExtra/make-bindings.jl
+++ b/deps/ReactantExtra/make-bindings.jl
@@ -42,7 +42,7 @@ for file in [
     "MPI.jl",
     "MemRef.jl",
     "SparseTensor.jl",
-    "TritonExt.jl"
+    "TritonExt.jl",
 ]
     build_file(joinpath(src_dir, "mlir", "Dialects", file))
 end
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 7ee65b163f..c17d030c98 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -702,6 +702,7 @@ function optimization_passes(
     lower_comms::Bool=true,
     max_constant_threshold::Int=1024,
     backend::String="gpu",
+    enable_triton_passes::Bool=false,
 )
     transform_passes_list = [
         "patterns=compare_op_canon<16>",
@@ -1298,7 +1299,7 @@ function optimization_passes(
         push!(passes, "remove-duplicate-func-def")
     end
     push!(passes, func_passes)
-    if backend == "cuda"
+    if enable_triton_passes && backend == "cuda"
        push!(passes, triton_optimization_passes())
     end
     return join(passes, ',')
@@ -1373,12 +1374,11 @@ function triton_optimization_passes()
         "allocate-shared-memory",
         "triton-tensor-memory-allocation",
         "tritongpu-global-scratch-memory-allocation",
-        # TODO: register the commented out passes
-        # "convert-triton-gpu-to-llvm",
+        "convert-triton-gpu-to-llvm",
         "canonicalize",
         "cse",
-        # "convert-nv-gpu-to-llvm",
-        # "convert-warp-specialize-to-llvm",
+        "convert-nv-gpu-to-llvm",
+        "convert-warp-specialize-to-llvm",
         "reconcile-unrealized-casts",
         "canonicalize",
         "cse",
@@ -1774,10 +1774,28 @@ function compile_mlir!(
     end
 
     opt_passes = optimization_passes(
-        compile_options; sroa=true, recognize_comms, lower_comms, backend
+        compile_options;
+        sroa=true,
+        recognize_comms,
+        lower_comms,
+        backend,
+        enable_triton_passes=false,
     )
     opt_passes2 = optimization_passes(
-        compile_options; sroa=false, recognize_comms, lower_comms, backend
+        compile_options;
+        sroa=false,
+        recognize_comms,
+        lower_comms,
+        backend,
+        enable_triton_passes=false,
+    )
+    opt_passes3 = optimization_passes(
+        compile_options;
+        sroa=false,
+        recognize_comms,
+        lower_comms,
+        backend,
+        enable_triton_passes=true,
     )
 
     raise_passes = if raise isa String
@@ -1792,7 +1810,7 @@ function compile_mlir!(
         opt_passes2
 
         if DUS_TO_CONCAT[]
-            opt_passes3 = optimization_passes(
+            opt_passes_dus_to_concat = optimization_passes(
                 compile_options;
                 sroa=false,
                 dus_to_concat=true,
@@ -1800,7 +1818,7 @@ function compile_mlir!(
                 lower_comms,
                 backend,
             )
-            result = result * "," * opt_passes3
+            result = result * "," * opt_passes_dus_to_concat
         end
         result
     else
@@ -1831,12 +1849,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes2,
+                        opt_passes3,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                         lower_enzymexla_linalg_pass,
                         jit,
                     ]
@@ -1847,12 +1865,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes2,
+                        opt_passes3,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
@@ -1876,12 +1894,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes2,
+                        opt_passes3,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                     ]
                 end,
                 ',',
@@ -1901,12 +1919,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes2,
+                        opt_passes3,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                     ]
                 else
                     [
@@ -1915,12 +1933,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes2,
+                        opt_passes3,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                         kern,
                         raise_passes,
                     ]
@@ -1942,12 +1960,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes2,
+                        opt_passes3,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                         kern,
                     ]
                 end,
@@ -1965,12 +1983,12 @@ function compile_mlir!(
                     "enzyme-batch",
                     opt_passes2,
                     enzyme_pass,
-                    opt_passes2,
+                    opt_passes3,
                     "canonicalize",
                     "remove-unnecessary-enzyme-ops",
                     "enzyme-simplify-math",
                     legalize_chlo_to_stablehlo...,
-                    opt_passes2,
+                    opt_passes3,
                 ],
                 ',',
             ),
@@ -2007,7 +2025,7 @@ function compile_mlir!(
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                         lower_enzymexla_linalg_pass,
                         jit,
                     ]
@@ -2020,7 +2038,7 @@ function compile_mlir!(
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes2,
+                        opt_passes3,
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
@@ -2231,7 +2249,7 @@ function compile_mlir!(
         run_pass_pipeline!(
             mod,
             join(
-                [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
+                [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
                 ",",
             ),
             "mid_pad_opts",

From 9f1cb47e6b51b8c9b212bd6beae928bcefb8bf0b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 29 Sep 2025 20:41:07 -0500
Subject: [PATCH 12/16] fix: new API

---
 ext/ReactantPythonCallExt/pycall.jl | 20 +++++++++++-----
 src/Compiler.jl                     | 36 ++++++++++++++---------------
 src/Ops.jl                          |  8 +++++--
 src/mlir/Dialects/TritonExt.jl      |  6 +++--
 4 files changed, 42 insertions(+), 28 deletions(-)

diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl
index 40026af81f..0788daef7e 100644
--- a/ext/ReactantPythonCallExt/pycall.jl
+++ b/ext/ReactantPythonCallExt/pycall.jl
@@ -47,9 +47,8 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
     return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
 end
 
-# TODO: support using metaparams here
-normalize_grid(grid::Integer) = normalize_grid((grid,))
-function normalize_grid(grid::Dims{N}) where {N}
+normalize_grid_and_blocks(grid::Integer) = normalize_grid_and_blocks((grid,))
+function normalize_grid_and_blocks(grid::Dims{N}) where {N}
     @assert N <= 3
     @assert all(grid .> 0)
     return (grid..., ntuple(_ -> 1, 3 - N)...)
@@ -62,11 +61,18 @@ signature_string(x) = error("Unsupported argument type: $(typeof(x))")
 
 # TODO: better name for hints?
 function overlayed_pycall_with_triton(
-    kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
+    kernel::Py,
+    args...;
+    grid,
+    blocks,
+    num_warps::Integer=1,
+    num_stages::Integer=3,
+    hints=nothing,
 )
     triton = tritonptr[]
 
-    grid = normalize_grid(grid)
+    grid = normalize_grid_and_blocks(grid)
+    blocks = normalize_grid_and_blocks(blocks)
 
     mapped = map(signature_string, args)
     signature = first.(mapped)
@@ -121,7 +127,9 @@ function overlayed_pycall_with_triton(
         grid_x=@opcall(constant(grid[1])),
         grid_y=@opcall(constant(grid[2])),
         grid_z=@opcall(constant(grid[3])),
-        shmem=@opcall(constant(pyconvert(Int, ccinfo.metadata.shared))),
+        block_x=@opcall(constant(blocks[1])),
+        block_y=@opcall(constant(blocks[2])),
+        block_z=@opcall(constant(blocks[3])),
     )
 
     return nothing
diff --git a/src/Compiler.jl b/src/Compiler.jl
index c17d030c98..457100a7f0 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1789,7 +1789,7 @@ function compile_mlir!(
         backend,
         enable_triton_passes=false,
     )
-    opt_passes3 = optimization_passes(
+    opt_passes_with_triton = optimization_passes(
         compile_options;
         sroa=false,
         recognize_comms,
@@ -1849,12 +1849,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes2,
                         lower_enzymexla_linalg_pass,
                         jit,
                     ]
@@ -1865,12 +1865,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes2,
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
@@ -1894,12 +1894,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes2,
                     ]
                 end,
                 ',',
@@ -1919,12 +1919,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes2,
                     ]
                 else
                     [
@@ -1933,12 +1933,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes2,
                         kern,
                         raise_passes,
                     ]
@@ -1960,12 +1960,12 @@ function compile_mlir!(
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes2,
                         kern,
                     ]
                 end,
@@ -1983,12 +1983,12 @@ function compile_mlir!(
                     "enzyme-batch",
                     opt_passes2,
                     enzyme_pass,
-                    opt_passes3,
+                    opt_passes_with_triton,
                     "canonicalize",
                     "remove-unnecessary-enzyme-ops",
                     "enzyme-simplify-math",
                     legalize_chlo_to_stablehlo...,
-                    opt_passes3,
+                    opt_passes2,
                 ],
                 ',',
             ),
@@ -2025,7 +2025,7 @@ function compile_mlir!(
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         lower_enzymexla_linalg_pass,
                         jit,
                     ]
@@ -2038,7 +2038,7 @@ function compile_mlir!(
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
-                        opt_passes3,
+                        opt_passes_with_triton,
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
@@ -2249,7 +2249,7 @@ function compile_mlir!(
         run_pass_pipeline!(
             mod,
             join(
-                [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes3],
+                [opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
                 ",",
             ),
             "mid_pad_opts",
diff --git a/src/Ops.jl b/src/Ops.jl
index d426a91ea3..eac50b8194 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1847,7 +1847,9 @@ function triton_call(
     grid_x::TracedRNumber{<:Integer},
     grid_y::TracedRNumber{<:Integer},
     grid_z::TracedRNumber{<:Integer},
-    shmem::TracedRNumber{<:Integer},
+    block_x::TracedRNumber{<:Integer},
+    block_y::TracedRNumber{<:Integer},
+    block_z::TracedRNumber{<:Integer},
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
@@ -1857,7 +1859,9 @@ function triton_call(
         grid_x.mlir_data,
         grid_y.mlir_data,
         grid_z.mlir_data,
-        shmem.mlir_data,
+        block_x.mlir_data,
+        block_y.mlir_data,
+        block_z.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
         fn=symref,
         result_0=MLIR.IR.Type[],
         location,
     )
diff --git a/src/mlir/Dialects/TritonExt.jl b/src/mlir/Dialects/TritonExt.jl
index f59822b909..bb79bade44 100644
--- a/src/mlir/Dialects/TritonExt.jl
+++ b/src/mlir/Dialects/TritonExt.jl
@@ -17,7 +17,9 @@ function call(
     gridx::Value,
     gridy::Value,
     gridz::Value,
-    shmem::Value,
+    blockx::Value,
+    blocky::Value,
+    blockz::Value,
     inputs::Vector{Value};
     result_0::Vector{IR.Type},
     fn,
@@ -31,7 +33,7 @@ function call(
     location=Location(),
 )
     op_ty_results = IR.Type[result_0...,]
-    operands = Value[gridx, gridy, gridz, shmem, inputs...]
+    operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, inputs...]
     owned_regions = Region[]
     successors = Block[]
     attributes = NamedAttribute[namedattribute("fn", fn),]

From e8a3e1dde4d7c9cd9361c36fb5a27e2e61fe037c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 29 Sep 2025 21:48:39 -0500
Subject: [PATCH 13/16] feat: return values

---
 ext/ReactantPythonCallExt/pycall.jl |  4 +---
 src/Ops.jl                          | 37 +++++++++++++++++++++++++----
 2 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl
index 0788daef7e..f99a30b8ac 100644
--- a/ext/ReactantPythonCallExt/pycall.jl
+++ b/ext/ReactantPythonCallExt/pycall.jl
@@ -120,7 +120,7 @@ function overlayed_pycall_with_triton(
 
     # we are compiling here + lowering again inside enzymejax
     ccinfo = triton.compile(src; target=target, options=options.__dict__)
-    @opcall triton_call(
+    return @opcall triton_call(
         pyconvert(String, ccinfo.asm["source"]),
         filter(x -> x isa Reactant.TracedType, args)...;
         func_name=pyconvert(String, ccinfo.metadata.name),
@@ -131,6 +131,4 @@ function overlayed_pycall_with_triton(
         block_y=@opcall(constant(blocks[2])),
         block_z=@opcall(constant(blocks[3])),
     )
-
-    return nothing
 end
diff --git a/src/Ops.jl b/src/Ops.jl
index eac50b8194..7cebe233bb 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1851,11 +1851,28 @@ function triton_call(
     block_y::TracedRNumber{<:Integer},
     block_z::TracedRNumber{<:Integer},
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
-    # TODO: other kwargs
 )
     _, symref = _extract_function(mlir_code; func_name, func_op_kind="tt.func", location)
 
-    triton_ext.call(
+    result_types = MLIR.IR.Type[]
+    output_operand_aliases = MLIR.IR.Attribute[]
+    output_to_arg = Int[]
+    for (i, arg) in enumerate(args)
+        if arg isa TracedRArray
+            push!(result_types, mlir_type(typeof(arg), size(arg)))
+            push!(
+                output_operand_aliases,
+                MLIR.IR.Attribute(
+                    MLIR.API.stablehloOutputOperandAliasGet(
+                        MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
+                    ),
+                ),
+            )
+            push!(output_to_arg, i)
+        end
+    end
+
+    results = triton_ext.call(
         grid_x.mlir_data,
         grid_y.mlir_data,
         grid_z.mlir_data,
@@ -1864,11 +1881,23 @@ function triton_call(
         block_z.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
         fn=symref,
-        result_0=MLIR.IR.Type[],
+        result_0=result_types,
         location,
+        output_operand_aliases,
     )
 
-    return nothing
+    array_results = ()
+    for i in 1:MLIR.IR.nresults(results)
+        arg = args[output_to_arg[i]]
+        array_results = (
+            array_results...,
+            Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}(
+                (), MLIR.IR.result(results, i), size(arg)
+            ),
+        )
+    end
+    length(array_results) == 1 && return array_results[1]
+    return array_results
 end
 
 """

From d5438754dda97a6ec03cf6c0cf00835a5bd89e11 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 4 Oct 2025 15:41:53 -0500
Subject: [PATCH 14/16] feat: lowering triton now works

---
 src/CompileOptions.jl |  1 +
 src/Compiler.jl       | 44 ++++++++++++++++++++++++++++++++++++++-----
 src/Ops.jl            |  2 +-
 3 files changed, 41 insertions(+), 6 deletions(-)

diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl
index e8cac78be6..f70e63f0be 100644
--- a/src/CompileOptions.jl
+++ b/src/CompileOptions.jl
@@ -229,6 +229,7 @@ function CompileOptions(;
             :canonicalize,
             :just_batch,
             :none,
+            :no_triton,
         ]
     end
diff --git a/src/Compiler.jl b/src/Compiler.jl
index 457100a7f0..c004986aa0 100644
--- a/src/Compiler.jl
+++ b/src/Compiler.jl
@@ -1844,12 +1844,14 @@ function compile_mlir!(
                     [
                         "mark-func-memory-effects",
                         opt_passes,
+                        opt_passes_with_triton,
+                        "lower-triton",
                         kern,
                         raise_passes,
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes_with_triton,
+                        opt_passes2,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
@@ -1871,6 +1873,7 @@ function compile_mlir!(
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
                         opt_passes2,
+                        "lower-triton",
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
@@ -1881,6 +1884,31 @@ function compile_mlir!(
             ),
             "all",
         )
+    elseif compile_options.optimization_passes === :no_triton
+        run_pass_pipeline!(
+            mod,
+            join(
+                if compile_options.raise_first
+                    ["mark-func-memory-effects", opt_passes]
+                else
+                    [
+                        "mark-func-memory-effects",
+                        opt_passes,
+                        "enzyme-batch",
+                        opt_passes2,
+                        enzyme_pass,
+                        opt_passes2,
+                        "canonicalize",
+                        "remove-unnecessary-enzyme-ops",
+                        "enzyme-simplify-math",
+                        legalize_chlo_to_stablehlo...,
+                        opt_passes2,
+                    ]
+                end,
+                ',',
+            ),
+            "before_kernel",
+        )
     elseif compile_options.optimization_passes === :before_kernel
         run_pass_pipeline!(
             mod,
@@ -1913,13 +1941,14 @@ function compile_mlir!(
                 if compile_options.raise_first
                     [
                         "mark-func-memory-effects",
-                        opt_passes,
+                        opt_passes_with_triton,
+                        "lower-triton",
                         kern,
                         raise_passes,
                         "enzyme-batch",
                         opt_passes2,
                         enzyme_pass,
-                        opt_passes_with_triton,
+                        opt_passes2,
                         "canonicalize",
                         "remove-unnecessary-enzyme-ops",
                         "enzyme-simplify-math",
@@ -1939,6 +1968,7 @@ function compile_mlir!(
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
                         opt_passes2,
+                        "lower-triton",
                         kern,
                         raise_passes,
                     ]
@@ -1966,6 +1996,7 @@ function compile_mlir!(
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
                         opt_passes2,
+                        "lower-triton",
                         kern,
                     ]
                 end,
@@ -2039,6 +2070,7 @@ function compile_mlir!(
                         "enzyme-simplify-math",
                         legalize_chlo_to_stablehlo...,
                         opt_passes_with_triton,
+                        "lower-triton",
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
@@ -2056,7 +2088,8 @@ function compile_mlir!(
                 if compile_options.raise_first
                     [
                         "mark-func-memory-effects",
-                        opt_passes,
+                        opt_passes_with_triton,
+                        "lower-triton",
                         kern,
                         raise_passes,
                         "enzyme-batch",
@@ -2071,9 +2104,10 @@ function compile_mlir!(
                         "mark-func-memory-effects",
                         opt_passes,
                         "enzyme-batch",
-                        opt_passes2,
+                        opt_passes_with_triton,
                         enzyme_pass,
                         "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
+                        "lower-triton",
                         kern,
                         raise_passes,
                         lower_enzymexla_linalg_pass,
diff --git a/src/Ops.jl b/src/Ops.jl
index 7cebe233bb..2ea731cbe8 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1864,7 +1864,7 @@ function triton_call(
                 output_operand_aliases,
                 MLIR.IR.Attribute(
                     MLIR.API.stablehloOutputOperandAliasGet(
-                        MLIR.IR.context(), 0, C_NULL, Int64(i - 1), 0, C_NULL
+                        MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL
                     ),
                 ),
             )

From 47082f94a0a3d5756714549c26e2b68e33d691d3 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Sat, 4 Oct 2025 17:28:19 -0500
Subject: [PATCH 15/16] feat: triton working end to end

---
 src/Ops.jl | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/Ops.jl b/src/Ops.jl
index 2ea731cbe8..c84f4b67b6 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -1889,12 +1889,11 @@ function triton_call(
     array_results = ()
     for i in 1:MLIR.IR.nresults(results)
         arg = args[output_to_arg[i]]
-        array_results = (
-            array_results...,
-            Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}(
-                (), MLIR.IR.result(results, i), size(arg)
-            ),
+        res = Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}(
+            (), MLIR.IR.result(results, i), size(arg)
         )
+        copyto!(arg, res)
+        array_results = (array_results..., res)
     end
     length(array_results) == 1 && return array_results[1]
     return array_results

From 38bbe42a1fd7eea635a40bf3e7eec38b1848025a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 9 Oct 2025 12:32:20 -0400
Subject: [PATCH 16/16] chore: update commit

---
 deps/ReactantExtra/WORKSPACE | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index 0a5498a540..f3debba053 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
 
 NSYNC_SHA256 = ""
 
-ENZYMEXLA_COMMIT = "bb5eb26b2ddc5bbb77e8ff22b8ef2499473c5f5e"
+ENZYMEXLA_COMMIT = "fc05b4453eaab6f9941b296a38d2bb4770bd806e"
 
 ENZYMEXLA_SHA256 = ""
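
Usage note (reviewer sketch, not part of the patch series): the snippet below illustrates how a Triton kernel could be launched through the PythonCall overlay once these patches are applied. The `grid`, `blocks`, and `num_warps` keywords mirror `overlayed_pycall_with_triton` from PATCH 12; the Python module name `my_kernels`, the helper `vadd!`, and the exact launch semantics of `blocks` are illustrative assumptions rather than anything defined by the patches.

    # Hypothetical Python side (assumed file `my_kernels.py`, importable from Python):
    #
    #   import triton
    #   import triton.language as tl
    #
    #   @triton.jit
    #   def add_kernel(x_ptr, y_ptr, out_ptr):
    #       pid = tl.program_id(axis=0)  # one element per program instance
    #       tl.store(out_ptr + pid, tl.load(x_ptr + pid) + tl.load(y_ptr + pid))

    using Reactant, PythonCall

    add_kernel = pyimport("my_kernels").add_kernel  # a triton JITFunction

    function vadd!(out, x, y)
        # Under Reactant tracing, pycall on a JITFunction is routed to
        # overlayed_pycall_with_triton; the meaning of `blocks` is assumed here.
        add_kernel(x, y, out; grid=length(x), blocks=1)
        return out
    end

    x = Reactant.to_rarray(rand(Float32, 1024))
    y = Reactant.to_rarray(rand(Float32, 1024))
    out = Reactant.to_rarray(zeros(Float32, 1024))

    Reactant.@jit vadd!(out, x, y)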