From b1604eebaf142d851f7a2a9192f3117f2f1c2966 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Fri, 20 Dec 2024 22:49:35 +0100 Subject: [PATCH] wip simple fusion --- thunder/examine/memory_calculation.py | 4 ++-- thunder/executors/data_dependent_partition.py | 15 +++++++++++++++ thunder/executors/nvfuserex_impl.py | 4 +++- thunder/executors/torch_compile.py | 2 +- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/thunder/examine/memory_calculation.py b/thunder/examine/memory_calculation.py index f97e908b7f..c4255773c6 100644 --- a/thunder/examine/memory_calculation.py +++ b/thunder/examine/memory_calculation.py @@ -191,8 +191,8 @@ def get_alloc_memory(trc: TraceCtx, *, annotate=False) -> tuple[int, dict[str, i allocated += impl(bsym, tensor_to_memory_data, name_to_alloc_memory) if annotate: if bsym.header: - bsym.header += ' ' - bsym.header += f'mem after next op: ~{allocated/(2**30):2f}GB' + bsym.header += " " + bsym.header += f"mem after next op: ~{allocated/(2**30):2f}GB" max_allocated = max(max_allocated, allocated) return max_allocated, name_to_alloc_memory diff --git a/thunder/executors/data_dependent_partition.py b/thunder/executors/data_dependent_partition.py index 1edc0cb033..7d37f4f327 100644 --- a/thunder/executors/data_dependent_partition.py +++ b/thunder/executors/data_dependent_partition.py @@ -301,3 +301,18 @@ def fuse_bound_symbols(trace: TraceCtx, merge_func: Callable): dataflow_merge(graph, merge_func) ret = horizontal_merge(graph, merge_func) return ret + + +def fuse_bound_symbols(trace: TraceCtx, can_fuse: Callable): + fusions = [[]] + for bsym in trace.bound_symbols: + if can_fuse(bsym): + fusions[-1].append(bsym) + else: + if fusions[-1]: + fusions.append([]) + fusions[-1].append(bsym) + fusions.append([]) + if not fusions[-1]: + del fusions[-1] + return fusions diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index d1e162dbae..62e97b80b9 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -865,7 +865,9 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) + bound_symbol_groups = fuse_bound_symbols( + trace, lambda bsym: self.can_fuse(bsym) and self.has_cuda_input_or_output(bsym) + ) # _should_fuse) # Counts how many fusions (per executor) have been constructed # (Used to name fusions like nvFusion0, nvFusion1, ...) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index ae4ae7a30a..35d5afb317 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -173,7 +173,7 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) + bound_symbol_groups = fuse_bound_symbols(trace, self.can_fuse) fused_bsyms = [] # Counts how many fusions (per executor) have been constructed