diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index f24ee1fad2..1a5241f237 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer( _, leaves = bsym_list_to_dag(list(new_subsymbols)) new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP) proxy_order = order_proxies(new_subsymbols) - new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name])) + new_consumer_args = tuple( + sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name]) + ) new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols) return new_consumer diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index fca022771e..9d8c6aef0e 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1813,8 +1813,8 @@ def forward(self, *args): # the rematerialization pass moved all(?) recomputation to the front, # making the peak mem about 46MB. # With checkpointing as coded in the model and recomputation where the - # values are used, we get about 12MB, so we put the barrier at 16MB - assert mem_max - mem_base < 16 * 2**20 + # values are used, we get about 12-20MB, so we put the barrier at 24MB + assert mem_max - mem_base < 24 * 2**20 def test_inconsistent_output_length_grad_transform(): @@ -1906,6 +1906,9 @@ def forward(x): @pytest.mark.parametrize("device", ("cuda", "cpu")) def test_backward_recomputation_decomposed_ops(device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + def fn(a): return torch.nn.functional.gelu(a)