Skip to content

Commit a7ece19

Browse files
committed
feat: copy tt.func into main module [skip ci]
1 parent 790db52 commit a7ece19

File tree

2 files changed

+70
-42
lines changed

2 files changed

+70
-42
lines changed

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,18 @@ function overlayed_pycall_with_triton(
108108

109109
ccinfo = triton.compile(src; target=target, options=options.__dict__)
110110

111-
println(pyconvert(String, ccinfo.asm["source"]))
112-
println(pyconvert(String, ccinfo.asm["ttir"]))
111+
@show ccinfo.metadata
112+
@show ccinfo.asm.keys()
113+
# shared = ccinfo.metadata["shared"]
114+
kernel_name = pyconvert(String, ccinfo.metadata.name)
115+
# cluster_dims = ccinfo.metadata["cluster_dims"]
116+
117+
# println(pyconvert(String, ccinfo.asm["source"]))
118+
# println(pyconvert(String, ccinfo.asm["ttir"]))
119+
120+
res = @opcall triton_call(
121+
pyconvert(String, ccinfo.asm["ttir"]), args...; func_name=kernel_name
122+
)
113123

114124
return error("TODO: implement triton")
115125
end

src/Ops.jl

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,68 +1698,38 @@ end
16981698
end
16991699

17001700
# Generate a unique name given a module hash and a function name.
1701-
function _hlo_call_name(orig_name, module_suffix)
1702-
return orig_name * "_hlo_call_" * module_suffix
1703-
end
1701+
_new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix
17041702

1705-
"""
1706-
hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1707-
1708-
Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1709-
with the provided arguments and return a tuple for each result of the call.
1710-
1711-
```julia-repl
1712-
julia> Reactant.@jit(
1713-
hlo_call(
1714-
\"\"\"
1715-
module {
1716-
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1717-
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1718-
return %0 : tensor<3xf32>
1719-
}
1720-
}
1721-
\"\"\",
1722-
Reactant.to_rarray(Float32[1, 2, 3]),
1723-
Reactant.to_rarray(Float32[1, 2, 3]),
1724-
)
1725-
)
1726-
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1727-
```
1728-
"""
1729-
@noinline function hlo_call(
1730-
code,
1731-
args...;
1732-
func_name="main",
1733-
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
1703+
function _extract_function(
1704+
code::String; func_name::String="main", func_op_kind::String="func.func"
17341705
)
17351706
module_suffix = string(hash(code); base=16)
1736-
name_to_call = _hlo_call_name(func_name, module_suffix)
1707+
name_to_call = _new_function_name(func_name, module_suffix)
17371708

17381709
current_module = MLIR.IR.mmodule()
17391710
top_level_block = MLIR.IR.body(current_module)
17401711

17411712
symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())
1742-
17431713
fn = MLIR.IR.lookup(
17441714
MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
17451715
)
1716+
17461717
if isnothing(fn)
17471718
new_mod = parse(MLIR.IR.Module, code)
17481719
new_mod_op = MLIR.IR.Operation(new_mod)
17491720
body = MLIR.IR.body(new_mod)
17501721

17511722
operations = collect(MLIR.IR.OperationIterator(body))
17521723
for op in operations
1753-
if MLIR.IR.name(op) == "func.func"
1724+
if MLIR.IR.name(op) == func_op_kind
17541725
fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
17551726
if fn_name == func_name
17561727
fn = op
17571728
end
17581729

1759-
new_name = _hlo_call_name(fn_name, module_suffix)
17601730
res = MLIR.IR.LogicalResult(
17611731
MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
1762-
fn_name, new_name, new_mod_op
1732+
fn_name, name_to_call, new_mod_op
17631733
),
17641734
)
17651735
@assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"
@@ -1772,7 +1742,7 @@ julia> Reactant.@jit(
17721742
)
17731743

17741744
# Change function name
1775-
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
1745+
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call))
17761746
end
17771747
end
17781748

@@ -1786,11 +1756,59 @@ julia> Reactant.@jit(
17861756
error("hlo_call: could not find function $func_name in the provided module")
17871757
end
17881758

1759+
return name_to_call
1760+
end
1761+
1762+
function triton_call(
1763+
mlir_code::String,
1764+
args::Union{TracedRArray,TracedRNumber,Number}...;
1765+
func_name::String="main",
1766+
location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
1767+
)
1768+
name_to_call = _extract_function(mlir_code; func_name, func_op_kind="tt.func")
1769+
1770+
@show name_to_call
1771+
display(MLIR.IR.mmodule())
1772+
1773+
error("TODO: implement triton_call")
1774+
end
1775+
1776+
"""
1777+
hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1778+
1779+
Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1780+
with the provided arguments and return a tuple for each result of the call.
1781+
1782+
```julia-repl
1783+
julia> Reactant.@jit(
1784+
hlo_call(
1785+
\"\"\"
1786+
module {
1787+
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1788+
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1789+
return %0 : tensor<3xf32>
1790+
}
1791+
}
1792+
\"\"\",
1793+
Reactant.to_rarray(Float32[1, 2, 3]),
1794+
Reactant.to_rarray(Float32[1, 2, 3]),
1795+
)
1796+
)
1797+
(ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1798+
```
1799+
"""
1800+
@noinline function hlo_call(
1801+
code,
1802+
args::Union{TracedRArray,TracedRNumber}...;
1803+
func_name="main",
1804+
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
1805+
)
1806+
name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
1807+
17891808
ftype_attr = MLIR.IR.attr(fn, "function_type")
17901809
ftype = MLIR.IR.Type(ftype_attr)
17911810

1792-
@assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
1793-
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"
1811+
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name. Expected $(MLIR.IR.ninputs(ftype)), got $(length(args))"
17941812

17951813
for (i, arg) in enumerate(args)
17961814
expected_type = MLIR.IR.input(ftype, i)

0 commit comments

Comments
 (0)