@@ -7,12 +7,13 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe
     )
 end

-function overlayed_pycall(f::Py, args...)
+function overlayed_pycall(f::Py, args...; kwargs...)
     @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[]
     # TODO: check for Autotuner and Heuristics as well
     if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction)
-        return overlayed_pycall_with_triton(f, args...)
+        return overlayed_pycall_with_triton(f, args...; kwargs...)
     else
+        @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions."
         return overlayed_pycall_with_jax_tracing(f, args...)
     end
 end
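
A minimal usage sketch of the dispatch above, assuming a hypothetical Triton JITFunction `add_kernel` called from Reactant-traced code (the names and call site are illustrative, not part of this diff):

    # `add_kernel` isa tritonptr[].JITFunction, so launch kwargs are forwarded
    add_kernel(x, y, out; grid=cld(length(x), 256), num_warps=4)
    # any other Python callable falls through to JAX tracing, where the new
    # @assert rejects stray kwargs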
@@ -46,6 +47,75 @@ function overlayed_pycall_with_jax_tracing(f::Py, args...)
     return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res)
 end

-function overlayed_pycall_with_triton(f::Py, args...)
-    error("TODO: implement triton")
+# TODO: support using metaparams here
+normalize_grid(grid::Integer) = normalize_grid((grid,))
+function normalize_grid(grid::Dims{N}) where {N}
+    @assert N <= 3
+    @assert all(grid .> 0)
+    return (grid..., ntuple(_ -> 1, 3 - N)...)
+end
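+# e.g. normalize_grid(32) == (32, 1, 1) and normalize_grid((4, 5)) == (4, 5, 1)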
+
+signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing
+signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing
+signature_string(x::T) where {T<:Number} = string(x), x
+signature_string(x) = error("Unsupported argument type: $(typeof(x))")
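+# each method returns (signature string, constexpr value): pointer signatures such as "*f32" for traced arrays, scalar signatures for traced numbers, stringified literals for plain numbers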
+
+function overlayed_pycall_with_triton(
+    kernel::Py, args...; grid, num_warps::Integer=1, num_stages::Integer=3, hints=nothing
+)
+    triton = tritonptr[]
+
+    grid = normalize_grid(grid)
+
+    mapped = map(signature_string, args)
+    signature = first.(mapped)
+    # TODO: are hints actually correctly set?
+    hints =
+        hints === nothing ? Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints)
+    constants = Dict(
+        kernel.arg_names[i - 1] => constant for
+        (i, constant) in enumerate(last.(mapped)) if constant !== nothing
+    )
+    for (k, v) in hints
+        v == 1 && (constants[k] = v)
+    end
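+    # hint 1 re-marks an argument as a constexpr; hint 16 becomes a "tt.divisibility" attribute below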
+
+    sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature))
+    for k in keys(constants)
+        sigmap[k] = "constexpr"
+    end
+
+    for h in values(hints)
+        @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h"
+    end
+    attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16)
+
+    src = triton.compiler.ASTSource(;
+        fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs
+    )
+
+    # TODO: check that we are using CUDA. Get compute_capability from the target
+    target = triton.backends.compiler.GPUTarget("cuda", 80, 32)
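+    # GPUTarget(backend, arch, warp_size): hardcoded to sm_80 CUDA with 32-thread warps for now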
+    backend = triton.compiler.make_backend(target)
+    options = backend.parse_options(
+        pydict(
+            "num_warps" => num_warps,
+            "num_stages" => num_stages,
+            "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)),
+        ),
+    )
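+    # extern_libs links NVIDIA's libdevice bitcode (math intrinsics) shipped via Reactant_jll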
+
+    ccinfo = triton.compile(src; target=target, options=options.__dict__)
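+    # ccinfo.asm holds the lowering stages; "source" and "ttir" are used here, typically also ttgir/llir/ptx/cubin on CUDA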
+
+    # temporary debug output; the actual kernel launch is still unimplemented
+    println(pyconvert(String, ccinfo.asm["source"]))
+    println(pyconvert(String, ccinfo.asm["ttir"]))
+
+    return error("TODO: implement triton")
 end
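
For illustration, a hedged sketch of how the new Triton path fits together; `add_kernel`, `x`, `y`, and `out` are placeholders, not part of this change:

    # hints maps 1-based argument indices to 1 (fold the argument as a constexpr)
    # or 16 (attach a "tt.divisibility" = 16 attribute to it)
    overlayed_pycall_with_triton(
        add_kernel,                 # placeholder triton JITFunction
        x, y, out, 1024;            # traced arrays lower to "*<dtype>" pointers, 1024 to a constexpr
        grid=cld(length(x), 1024),  # normalize_grid pads this to (g, 1, 1)
        num_warps=4,
        hints=Dict(1 => 16, 2 => 16, 3 => 16),
    )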