From 2bd6416f311d72ae9192192e3a8094813f0ddb9c Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 18 Dec 2024 09:47:19 +0100 Subject: [PATCH] filter proxies in remat, improve tests --- thunder/core/rematerialization.py | 2 +- thunder/core/transform_common.py | 1 - thunder/tests/test_grad.py | 7 +++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index f24ee1fad2..3817770ce5 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -192,7 +192,7 @@ def apply_rematerialization_for_consumer( _, leaves = bsym_list_to_dag(list(new_subsymbols)) new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP) proxy_order = order_proxies(new_subsymbols) - new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name])) + new_consumer_args = tuple(sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name])) new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols) return new_consumer diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index bfe4123dc6..3eb4eb6d52 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -441,7 +441,6 @@ def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]: """ counter = 0 proxy_order: dict[str, int] = {} # names to order - def process_bound_symbols(bound_symbols): nonlocal counter for bsym in bound_symbols: diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index fca022771e..9d8c6aef0e 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1813,8 +1813,8 @@ def forward(self, *args): # the rematerialization pass moved all(?) recomputation to the front, # making the peak mem about 46MB. # With checkpointing as coded in the model and recomputation where the - # values are used, we get about 12MB, so we put the barrier at 16MB - assert mem_max - mem_base < 16 * 2**20 + # values are used, we get about 12-20MB, so we put the barrier at 24MB + assert mem_max - mem_base < 24 * 2**20 def test_inconsistent_output_length_grad_transform(): @@ -1906,6 +1906,9 @@ def forward(x): @pytest.mark.parametrize("device", ("cuda", "cpu")) def test_backward_recomputation_decomposed_ops(device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + def fn(a): return torch.nn.functional.gelu(a)