@@ -1698,68 +1698,38 @@ end
1698
1698
end
1699
1699
1700
1700
# Generate a unique name given a module hash and a function name.
1701
- function _hlo_call_name (orig_name, module_suffix)
1702
- return orig_name * " _hlo_call_" * module_suffix
1703
- end
1701
+ _new_function_name (orig_name, module_suffix) = orig_name * " _call_" * module_suffix
1704
1702
1705
- """
1706
- hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1707
-
1708
- Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1709
- with the provided arguments and return a tuple for each result of the call.
1710
-
1711
- ```julia-repl
1712
- julia> Reactant.@jit(
1713
- hlo_call(
1714
- \"\"\"
1715
- module {
1716
- func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1717
- %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1718
- return %0 : tensor<3xf32>
1719
- }
1720
- }
1721
- \"\"\" ,
1722
- Reactant.to_rarray(Float32[1, 2, 3]),
1723
- Reactant.to_rarray(Float32[1, 2, 3]),
1724
- )
1725
- )
1726
- (ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1727
- ```
1728
- """
1729
- @noinline function hlo_call (
1730
- code,
1731
- args... ;
1732
- func_name= " main" ,
1733
- location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1703
+ function _extract_function (
1704
+ code:: String ; func_name:: String = " main" , func_op_kind:: String = " func.func"
1734
1705
)
1735
1706
module_suffix = string (hash (code); base= 16 )
1736
- name_to_call = _hlo_call_name (func_name, module_suffix)
1707
+ name_to_call = _new_function_name (func_name, module_suffix)
1737
1708
1738
1709
current_module = MLIR. IR. mmodule ()
1739
1710
top_level_block = MLIR. IR. body (current_module)
1740
1711
1741
1712
symbol_attr_name = String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ())
1742
-
1743
1713
fn = MLIR. IR. lookup (
1744
1714
MLIR. IR. SymbolTable (MLIR. IR. Operation (current_module)), name_to_call
1745
1715
)
1716
+
1746
1717
if isnothing (fn)
1747
1718
new_mod = parse (MLIR. IR. Module, code)
1748
1719
new_mod_op = MLIR. IR. Operation (new_mod)
1749
1720
body = MLIR. IR. body (new_mod)
1750
1721
1751
1722
operations = collect (MLIR. IR. OperationIterator (body))
1752
1723
for op in operations
1753
- if MLIR. IR. name (op) == " func.func "
1724
+ if MLIR. IR. name (op) == func_op_kind
1754
1725
fn_name = String (MLIR. IR. attr (op, symbol_attr_name))
1755
1726
if fn_name == func_name
1756
1727
fn = op
1757
1728
end
1758
1729
1759
- new_name = _hlo_call_name (fn_name, module_suffix)
1760
1730
res = MLIR. IR. LogicalResult (
1761
1731
MLIR. API. mlirSymbolTableReplaceAllSymbolUses (
1762
- fn_name, new_name , new_mod_op
1732
+ fn_name, name_to_call , new_mod_op
1763
1733
),
1764
1734
)
1765
1735
@assert res == MLIR. IR. success () " hlo_call: failed to rename $fn_name "
@@ -1772,7 +1742,7 @@ julia> Reactant.@jit(
1772
1742
)
1773
1743
1774
1744
# Change function name
1775
- MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (new_name ))
1745
+ MLIR. IR. attr! (op, symbol_attr_name, MLIR. IR. Attribute (name_to_call ))
1776
1746
end
1777
1747
end
1778
1748
@@ -1786,11 +1756,59 @@ julia> Reactant.@jit(
1786
1756
error (" hlo_call: could not find function $func_name in the provided module" )
1787
1757
end
1788
1758
1759
+ return name_to_call
1760
+ end
1761
+
1762
+ function triton_call (
1763
+ mlir_code:: String ,
1764
+ args:: Union{TracedRArray,TracedRNumber,Number} ...;
1765
+ func_name:: String = " main" ,
1766
+ location= mlir_stacktrace (" triton_call" , @__FILE__ , @__LINE__ ),
1767
+ )
1768
+ name_to_call = _extract_function (mlir_code; func_name, func_op_kind= " tt.func" )
1769
+
1770
+ @show name_to_call
1771
+ display (MLIR. IR. mmodule ())
1772
+
1773
+ error (" TODO: implement triton_call" )
1774
+ end
1775
+
1776
+ """
1777
+ hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
1778
+
1779
+ Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
1780
+ with the provided arguments and return a tuple for each result of the call.
1781
+
1782
+ ```julia-repl
1783
+ julia> Reactant.@jit(
1784
+ hlo_call(
1785
+ \"\"\"
1786
+ module {
1787
+ func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
1788
+ %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
1789
+ return %0 : tensor<3xf32>
1790
+ }
1791
+ }
1792
+ \"\"\" ,
1793
+ Reactant.to_rarray(Float32[1, 2, 3]),
1794
+ Reactant.to_rarray(Float32[1, 2, 3]),
1795
+ )
1796
+ )
1797
+ (ConcretePJRTArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
1798
+ ```
1799
+ """
1800
+ @noinline function hlo_call (
1801
+ code,
1802
+ args:: Union{TracedRArray,TracedRNumber} ...;
1803
+ func_name= " main" ,
1804
+ location= mlir_stacktrace (" hlo_call" , @__FILE__ , @__LINE__ ),
1805
+ )
1806
+ name_to_call = _extract_function (code; func_name, func_op_kind= " func.func" )
1807
+
1789
1808
ftype_attr = MLIR. IR. attr (fn, " function_type" )
1790
1809
ftype = MLIR. IR. Type (ftype_attr)
1791
1810
1792
- @assert all (Base. Fix2 (isa, Union{TracedRArray,TracedRNumber}), args) " hlo_call: all inputs to hlo_call should be reactant arrays or numbers"
1793
- @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name "
1811
+ @assert MLIR. IR. ninputs (ftype) == length (args) " hlo_call: invalid number of arguments for function $func_name . Expected $(MLIR. IR. ninputs (ftype)) , got $(length (args)) "
1794
1812
1795
1813
for (i, arg) in enumerate (args)
1796
1814
expected_type = MLIR. IR. input (ftype, i)
0 commit comments