
Commit 790db52

feat: auto-trace triton code
1 parent 90776d2 commit 790db52

File tree: 4 files changed, +97 -9 lines

CondaPkg.toml

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 jax = ">= 0.6"
 tensorflow = ">= 2.17"
 numpy = ">= 2"
-triton = "" # TODO: version bound
+triton = ">= 3.4"

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 25 additions & 1 deletion
@@ -1,8 +1,10 @@
 module ReactantPythonCallExt

-using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance
+using PythonCall:
+    PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple
 using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
 using Reactant.Ops: @opcall
+using Reactant_jll: Reactant_jll

 const jaxptr = Ref{Py}()
 const jnpptr = Ref{Py}()
@@ -37,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict(
     ComplexF64 => :complex64,
 )

+const MLIR_TYPE_STRING = Dict(
+    Float64 => "fp64",
+    Float32 => "fp32",
+    Float16 => "fp16",
+    Int64 => "i64",
+    Int32 => "i32",
+    Int16 => "i16",
+    Int8 => "i8",
+    UInt64 => "ui64",
+    UInt32 => "ui32",
+    UInt16 => "ui16",
+    UInt8 => "ui8",
+    Bool => "i1",
+    Reactant.F8E4M3FN => "fp8e4nv",
+    Reactant.F8E5M2FNUZ => "fp8e5b16",
+    Reactant.F8E4M3FNUZ => "fp8e4b8",
+    Reactant.F8E5M2 => "fp8e5",
+)
+if isdefined(Core, :BFloat16)
+    MLIR_TYPE_STRING[Core.BFloat16] = "bf16"
+end
+
 function __init__()
     try
         jaxptr[] = pyimport("jax")
Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
-@reactant_overlay function PythonCall.pycall(f::Py, args...)
+@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...)
     if Reactant.looped_any(Reactant.use_overlayed_version, args)
-        return overlayed_pycall(f, args...)
+        return overlayed_pycall(f, args...; kwargs...)
     else
-        return Base.inferencebarrier(PythonCall.pycall)(f, args...)
+        return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...)
     end
 end
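
Forwarding `kwargs...` through the overlay is what lets a Triton launch carry its configuration (`grid`, `num_warps`, ...) from Julia. A hypothetical usage sketch, assuming the `triton` Python package is available via PythonCall and a CUDA GPU is present; the inline `add_kernel` is purely illustrative, and since this commit's Triton path still ends in `error("TODO: implement triton")`, the final `@jit` call is left commented out:

```julia
using Reactant, PythonCall

# Define a toy `triton.jit` kernel inline (illustrative only).
pyexec("""
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
""", Main)
add_kernel = pyeval("add_kernel", Main)

function vadd!(out, x, y)
    # Keyword arguments are forwarded by the overlay and consumed by
    # `overlayed_pycall_with_triton`; plain-number arguments (here `length(x)`
    # and `256`) become compile-time constants under `signature_string`.
    add_kernel(x, y, out, length(x), 256; grid=cld(length(x), 256), num_warps=4)
    return out
end

x = Reactant.to_rarray(rand(Float32, 1024))
y = Reactant.to_rarray(rand(Float32, 1024))
out = Reactant.to_rarray(zeros(Float32, 1024))
# Reactant.@jit vadd!(out, x, y)  # routes through the Triton path; currently stops at the TODO error
```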

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 68 additions & 4 deletions
@@ -7,12 +7,13 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
     )
 end

-function overlayed_pycall(f::Py, args...)
+function overlayed_pycall(f::Py, args...; kwargs...)
     @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
     # TODO: check for Autotuner and Heutistics as well
     if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
-        return overlayed_pycall_with_triton(f, args...)
+        return overlayed_pycall_with_triton(f, args...; kwargs...)
     else
+        @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions."
         return overlayed_pycall_with_jax_tracing(f, args...)
     end
 end
@@ -46,6 +47,69 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
     return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
 end

-function overlayed_pycall_with_triton(f::Py, args...)
-    error("TODO: implement triton")
+# TODO: support using metaparams here
+normalize_grid(grid::Integer) = normalize_grid((grid,))
+function normalize_grid(grid::Dims{N}) where {N}
+    @assert N <= 3
+    @assert all(grid .> 0)
+    return (grid..., ntuple(_ -> 1, 3 - N)...)
+end
+
+signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
+signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
+signature_string(x::T) where {T<:Number} = string(x), x
+signature_string(x) = error("Unsupported argument type: $(typeof(x))")
+
+function overlayed_pycall_with_triton(
+    kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
+)
+    triton = tritonptr[]
+
+    grid = normalize_grid(grid)
+
+    mapped = map(signature_string, args)
+    signature = first.(mapped)
+    # TODO: are hints actually correctly set?
+    hints =
+        hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
+    constants = Dict(
+        kernel.arg_names[i - 1] => constant for
+        (i, constant) in enumerate(last.(mapped)) if constant !== nothing
+    )
+    for (k, v) in hints
+        v == 1 && (constants[kernel.arg_names[k - 1]] = v)
+    end
+    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)
+
+    sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
+    for k in keys(constants)
+        sigmap[k] = "constexpr"
+    end
+
+    for h in values(hints)
+        @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
+    end
+    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)
+
+    src = triton.compiler.ASTSource(;
+        fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
+    )
+
+    # TODO: check that we are using CUDA. Get compute_capability from the target
+    target = triton.backends.compiler.GPUTarget("cuda", 80, 32)
+    backend = triton.compiler.make_backend(target)
+    options = backend.parse_options(
+        pydict(
+            "num_warps" => num_warps,
+            "num_stages" => num_stages,
+            "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
+        ),
+    )
+
+    ccinfo = triton.compile(src; target=target, options=options.__dict__)
+
+    println(pyconvert(String, ccinfo.asm["source"]))
+    println(pyconvert(String, ccinfo.asm["ttir"]))
+
+    return error("TODO: implement triton")
 end
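
Before `triton.compile` is invoked, the new helpers turn the Julia argument list into Triton's launch metadata: `normalize_grid` pads the grid to three dimensions (`normalize_grid(128) == (128, 1, 1)`), and `signature_string` splits each argument into a signature entry plus an optional compile-time constant. A self-contained sketch of that bookkeeping with plain Julia stand-ins (no Python objects; `arg_names` and the argument list below are illustrative):

```julia
# Sketch of the metadata built before calling `triton.compile`, restated with
# plain Julia stand-ins for this commit's `signature_string` rules.
const MLIR_TYPE_STRING = Dict(Float32 => "fp32", Int64 => "i64")

# Arrays -> pointer signature with no constant; plain numbers -> their value
# becomes a compile-time constant (constexpr).
sig(::Array{T}) where {T} = ("*$(MLIR_TYPE_STRING[T])", nothing)
sig(x::Number) = (string(x), x)

arg_names = ["x_ptr", "y_ptr", "out_ptr", "n", "BLOCK"]
args = (zeros(Float32, 8), zeros(Float32, 8), zeros(Float32, 8), 1024, 256)

mapped = map(sig, args)
signature = Dict(arg_names[i] => first(m) for (i, m) in enumerate(mapped))
constants = Dict(arg_names[i] => last(m) for (i, m) in enumerate(mapped) if last(m) !== nothing)
for k in keys(constants)
    signature[k] = "constexpr"   # constexpr arguments are erased from the runtime signature
end

# signature == Dict("x_ptr" => "*fp32", "y_ptr" => "*fp32", "out_ptr" => "*fp32",
#                   "n" => "constexpr", "BLOCK" => "constexpr")
# constants == Dict("n" => 1024, "BLOCK" => 256)
```

In the extension itself the names come from `kernel.arg_names`, a 0-indexed Python list, which is why the diff indexes it with `i - 1`.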
