
Commit 95598f9

feat: put the tt func in a separate module and use symbol ref
1 parent 5b37e06 commit 95598f9

File tree

3 files changed: +53 -45 lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
 
 NSYNC_SHA256 = ""
 
-ENZYMEXLA_COMMIT = "b59185c7586783a17d9486e682307ae89c713964"
+ENZYMEXLA_COMMIT = "52ae936cae8f7050adc26c4ed5e755200497dc86"
 
 ENZYMEXLA_SHA256 = ""
 
src/Compiler.jl

Lines changed: 37 additions & 38 deletions
@@ -1310,42 +1310,42 @@ function triton_optimization_passes()
             "convert-nvvm-to-llvm",
             # common passes
             "canonicalize",
-            # # ttir passes
-            # "triton-combine",
-            # "triton-reorder-broadcast",
-            # "triton-rewrite-tensor-pointer",
-            # "triton-rewrite-tensor-descriptor-to-pointer",
-            # "triton-loop-unroll",
-            # "triton-licm",
-            # "triton-loop-aware-cse",
-            # # TODO: should num-warps and num-ctas be set for each kernel?
-            # "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
-            # # ttgir passes
-            # "tritongpu-coalesce",
-            # "tritongpu-optimize-thread-locality",
-            # "tritongpu-hoist-tmem-alloc",
-            # "tritongpu-assign-latencies",
-            # "tritongpu-pipeline",
-            # "tritongpu-schedule-loops",
-            # "tritongpu-automatic-warp-specialization",
-            # "tritongpu-prefetch",
-            # "tritongpu-accelerate-matmul",
-            # "tritongpu-reorder-instructions",
-            # "tritongpu-F32DotTC",
-            # "tritongpu-optimize-dot-operands",
-            # "tritongpu-remove-layout-conversions",
-            # "tritongpu-reduce-data-duplication",
-            # "tritongpu-hoist-tmem-alloc",
-            # "tritongpu-fuse-nested-loops",
-            # "tritongpu-rewrite-partition-dependencies",
-            # "tritongpu-partition-loops",
-            # "tritongpu-combine-tensor-select-and-if",
-            # # ttgir to llvm passes
-            # "tritongpu-allocate-warp-groups",
-            # "allocate-shared-memory",
-            # "tritongpu-global-scratch-memory-allocation",
-            # "tritongpu-optimize-accumulator-init",
-            # "tritongpu-coalesce-async-copy",
+            # ttir passes
+            "triton-combine",
+            "triton-reorder-broadcast",
+            "triton-rewrite-tensor-pointer",
+            "triton-rewrite-tensor-descriptor-to-pointer",
+            "triton-loop-unroll",
+            "triton-licm",
+            "triton-loop-aware-cse",
+            # TODO: should num-warps and num-ctas be set for each kernel?
+            "convert-triton-to-tritongpu{target=cuda:$(cubinChip[][4:end]) num-warps=1 threads-per-warp=$(cuWarpSize[]) num-ctas=1}",
+            # ttgir passes
+            "tritongpu-coalesce",
+            "tritongpu-optimize-thread-locality",
+            "tritongpu-hoist-tmem-alloc",
+            "tritongpu-assign-latencies",
+            "tritongpu-pipeline",
+            "tritongpu-schedule-loops",
+            "tritongpu-automatic-warp-specialization",
+            "tritongpu-prefetch",
+            "tritongpu-accelerate-matmul",
+            "tritongpu-reorder-instructions",
+            "tritongpu-F32DotTC",
+            "tritongpu-optimize-dot-operands",
+            "tritongpu-remove-layout-conversions",
+            "tritongpu-reduce-data-duplication",
+            "tritongpu-hoist-tmem-alloc",
+            "tritongpu-fuse-nested-loops",
+            "tritongpu-rewrite-partition-dependencies",
+            "tritongpu-partition-loops",
+            "tritongpu-combine-tensor-select-and-if",
+            # ttgir to llvm passes
+            "tritongpu-allocate-warp-groups",
+            "allocate-shared-memory",
+            "tritongpu-global-scratch-memory-allocation",
+            "tritongpu-optimize-accumulator-init",
+            "tritongpu-coalesce-async-copy",
         ],
         ",",
     )
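Note on the hunk above: the trailing "],", ",", ")" lines suggest triton_optimization_passes() joins this vector of pass names into a single comma-separated pass-pipeline string, with the convert-triton-to-tritongpu entry interpolating the target chip and warp size. A minimal sketch of that pattern follows; "sm_80" and 32 are made-up example values for cubinChip[] and cuWarpSize[], not Reactant's actual defaults.

# Sketch only: hypothetical values standing in for cubinChip[] and cuWarpSize[].
chip = "sm_80"
warpsize = 32
ttgpu_pass = "convert-triton-to-tritongpu{target=cuda:$(chip[4:end]) num-warps=1 threads-per-warp=$(warpsize) num-ctas=1}"
# ttgpu_pass == "convert-triton-to-tritongpu{target=cuda:80 num-warps=1 threads-per-warp=32 num-ctas=1}"
pipeline = join(["canonicalize", "triton-combine", ttgpu_pass], ",")
# pipeline is one comma-separated string, the form an MLIR pass-pipeline runner expects.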
@@ -2303,8 +2303,7 @@ function compile_mlir!(
         end
     end
 
-    # XXX: re-enable this pass
-    # run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
+    run_pass_pipeline!(mod, "mark-func-memory-effects", "mark-func-memory-effects")
 
     func_op = MLIR.API.mlirSymbolTableLookup(
         MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), fnname

src/Ops.jl

Lines changed: 15 additions & 6 deletions
@@ -1707,12 +1707,19 @@ function _extract_function(
     nested_module::Bool=false,
 )
     module_suffix = string(hash(code); base=16)
-    name_to_call = _new_function_name(func_name, module_suffix)
+    name_to_call = func_name * "_call_" * module_suffix
+    mod_name = func_name * "_module_" * module_suffix
 
     current_module = MLIR.IR.mmodule()
     if nested_module
         new_module = MLIR.IR.Module()
-        push!(MLIR.IR.body(current_module), MLIR.IR.Operation(new_module, true))
+        moduleop = MLIR.IR.Operation(new_module, true)
+        MLIR.IR.attr!(
+            moduleop,
+            String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()),
+            MLIR.IR.Attribute(mod_name),
+        )
+        push!(MLIR.IR.body(current_module), moduleop)
         current_module = new_module
     end
     top_level_block = MLIR.IR.body(current_module)
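For intuition, the added MLIR.IR.attr! call gives the freshly created nested module a symbol name (in upstream MLIR the symbol attribute is "sym_name"), so the inner module becomes addressable from the outer one. A minimal illustration of the naming scheme follows; the kernel name and hash suffix are hypothetical examples, whereas the real suffix comes from string(hash(code); base=16).

# Hypothetical example values only.
func_name     = "my_kernel"
module_suffix = "1a2b3c4d"
name_to_call  = func_name * "_call_" * module_suffix    # "my_kernel_call_1a2b3c4d"
mod_name      = func_name * "_module_" * module_suffix  # "my_kernel_module_1a2b3c4d"
# mod_name is attached to the nested module op as its symbol-name attribute,
# which is what later lets a symbol reference point inside that module.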
@@ -1764,7 +1771,7 @@ function _extract_function(
         error("hlo_call: could not find function $func_name in the provided module")
     end
 
-    return fn, name_to_call
+    return fn, name_to_call, mod_name
 end
 
 function triton_call(
@@ -1778,7 +1785,7 @@ function triton_call(
     location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__),
     # TODO: other kwargs
 )
-    _, name_to_call = _extract_function(
+    _, name_to_call, mod_name = _extract_function(
         mlir_code; func_name, func_op_kind="tt.func", nested_module=true
     )
 
@@ -1788,7 +1795,9 @@ function triton_call(
         grid_z.mlir_data,
         shmem.mlir_data,
         [Reactant.TracedUtils.get_mlir_data(a) for a in args];
-        fn=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
+        fn=MLIR.IR.SymbolRefAttribute(
+            mod_name, MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute(name_to_call)]
+        ),
         result_0=MLIR.IR.Type[],
         location,
     )
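This hunk is the "use symbol ref" part of the commit title: the call target now lives inside the named nested module, so a flat symbol reference is replaced by a nested one. A minimal sketch using the hypothetical names from the earlier example; the comments describe how MLIR typically renders such references.

# Flat reference resolves in the enclosing module: @my_kernel_call_1a2b3c4d
flat_ref = MLIR.IR.FlatSymbolRefAttribute("my_kernel_call_1a2b3c4d")

# Nested reference resolves through the named nested module:
# @my_kernel_module_1a2b3c4d::@my_kernel_call_1a2b3c4d
nested_ref = MLIR.IR.SymbolRefAttribute(
    "my_kernel_module_1a2b3c4d",
    MLIR.IR.Attribute[MLIR.IR.FlatSymbolRefAttribute("my_kernel_call_1a2b3c4d")],
)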
@@ -1826,7 +1835,7 @@ julia> Reactant.@jit(
     func_name="main",
     location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
 )
-    fn, name_to_call = _extract_function(code; func_name, func_op_kind="func.func")
+    fn, name_to_call, _ = _extract_function(code; func_name, func_op_kind="func.func")
 
     ftype_attr = MLIR.IR.attr(fn, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
