Skip to content

Commit

Permalink
wip simple fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 20, 2024
1 parent 05a2a66 commit b1604ee
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
4 changes: 2 additions & 2 deletions thunder/examine/memory_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def get_alloc_memory(trc: TraceCtx, *, annotate=False) -> tuple[int, dict[str, i
allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory)
if annotate:
if bsym.header:
bsym.header += ' '
bsym.header += f'mem after next op: ~{allocated/(2**30):2f}GB'
bsym.header += " "
bsym.header += f"mem after next op: ~{allocated/(2**30):2f}GB"
max_allocated = max(max_allocated, allocated)

return max_allocated, name_to_alloc_memory
15 changes: 15 additions & 0 deletions thunder/executors/data_dependent_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,18 @@ def fuse_bound_symbols(trace: TraceCtx, merge_func: Callable):
dataflow_merge(graph, merge_func)
ret = horizontal_merge(graph, merge_func)
return ret


def fuse_bound_symbols(trace: TraceCtx, can_fuse: Callable):
    """Partition the bound symbols of ``trace`` into consecutive groups.

    The symbols are walked once, in trace order. Every maximal run of
    consecutive symbols for which ``can_fuse(bsym)`` is truthy becomes one
    group; every symbol for which it is falsy becomes a group of its own.
    Symbols are never reordered, so concatenating the returned groups yields
    ``trace.bound_symbols`` again.

    NOTE(review): this shares its name with the dataflow-based
    ``fuse_bound_symbols(trace, merge_func)`` defined earlier in the module;
    being defined later, this version is the one the module exposes.

    Args:
        trace: object exposing an iterable ``bound_symbols`` attribute.
        can_fuse: predicate called with a single bound symbol.

    Returns:
        A list of non-empty lists of bound symbols. Empty when the trace has
        no bound symbols.
    """
    groups = []
    current_run = []  # consecutive fusible symbols collected so far
    for bsym in trace.bound_symbols:
        if can_fuse(bsym):
            current_run.append(bsym)
            continue
        # A non-fusible symbol ends the current run (if any) and is emitted
        # as a singleton group so it is never merged with its neighbours.
        if current_run:
            groups.append(current_run)
            current_run = []
        groups.append([bsym])
    # Flush a run that extends to the end of the trace.
    if current_run:
        groups.append(current_run)
    return groups
4 changes: 3 additions & 1 deletion thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,9 @@ def _can_fuse_node(n: Node):

return _can_fuse_node(a) and _can_fuse_node(b)

bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse)
bound_symbol_groups = fuse_bound_symbols(
trace, lambda bsym: self.can_fuse(bsym) and self.has_cuda_input_or_output(bsym)
) # _should_fuse)

# Counts how many fusions (per executor) have been constructed
# (Used to name fusions like nvFusion0, nvFusion1, ...)
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _can_fuse_node(n: Node):

return _can_fuse_node(a) and _can_fuse_node(b)

bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse)
bound_symbol_groups = fuse_bound_symbols(trace, self.can_fuse)

fused_bsyms = []
# Counts how many fusions (per executor) have been constructed
Expand Down

0 comments on commit b1604ee

Please sign in to comment.