From 81f83f3549d0cde9fc012fb9ebfb9eb4a3254e61 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 25 Nov 2024 09:05:32 +0100 Subject: [PATCH] dce all-const returning symbols in subsymbols (#1465) --- thunder/core/transform_common.py | 6 +++++- thunder/tests/test_networks.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index a21d1b207c..c9449f9f87 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -54,7 +54,11 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None: for sbsym in bsym.subsymbols: if len(sbsym.subsymbols) == 0 and not sbsym.sym.is_prim: continue - + # if all outputs are constants, we elmininate the subsymbol + if not has_tags(bsym, {prims.OpTags.DONT_DCE}) and not any( + o is not None for o in sbsym.flat_proxy_outs + ): # is not None to avoid cast to bool + continue _remove_noop_subsymbols(sbsym) nsbsyms.append(sbsym) diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index a1a3f0ca3f..bec0d7f0d4 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -527,3 +527,7 @@ def test_hf_llama(): res2 = jm(past_key_values=res["past_key_values"], **args2) expected2 = model(past_key_values=res["past_key_values"], **args2) assert_close(res2, expected2, rtol=1e-1, atol=1e-1) + + top_level_symbol_names = {bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols} + # changes this to fewer as needed, the goal is to not have too many fusions + assert len([s for s in top_level_symbol_names if s.startswith("nvFusion")]) == 7