Skip to content

Commit

Permalink
filter proxies in remat, improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 18, 2024
1 parent 13725d5 commit af68c02
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 3 additions & 1 deletion thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def apply_rematerialization_for_consumer(
_, leaves = bsym_list_to_dag(list(new_subsymbols))
new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP)
proxy_order = order_proxies(new_subsymbols)
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
new_consumer_args = tuple(
sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name])
)
new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols)
return new_consumer

Expand Down
7 changes: 5 additions & 2 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,8 +1813,8 @@ def forward(self, *args):
# the rematerialization pass moved all(?) recomputation to the front,
# making the peak mem about 46MB.
# With checkpointing as coded in the model and recomputation where the
# values are used, we get about 12MB, so we put the barrier at 16MB
assert mem_max - mem_base < 16 * 2**20
# values are used, we get about 12-20MB, so we put the barrier at 24MB
assert mem_max - mem_base < 24 * 2**20


def test_inconsistent_output_length_grad_transform():
Expand Down Expand Up @@ -1906,6 +1906,9 @@ def forward(x):

@pytest.mark.parametrize("device", ("cuda", "cpu"))
def test_backward_recomputation_decomposed_ops(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

def fn(a):
return torch.nn.functional.gelu(a)

Expand Down

0 comments on commit af68c02

Please sign in to comment.