Skip to content

Commit

Permalink
filter proxies in remat, improve tests
Browse files — browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 18, 2024
1 parent 13725d5 commit 2bd6416
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def apply_rematerialization_for_consumer(
_, leaves = bsym_list_to_dag(list(new_subsymbols))
new_subsymbols = toposort_bsym_dag(leaves, TOPOSORT_ORDER.BOTTOM_UP)
proxy_order = order_proxies(new_subsymbols)
new_consumer_args = tuple(sorted(new_consumer_args, key=lambda x: proxy_order[x.name]))
new_consumer_args = tuple(sorted((a for a in new_consumer_args if a.name in proxy_order), key=lambda x: proxy_order[x.name]))
new_consumer = replace(consumer, args=new_consumer_args, subsymbols=new_subsymbols)
return new_consumer

Expand Down
1 change: 0 additions & 1 deletion thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ def order_proxies(bsyms: Sequence[BoundSymbol]) -> dict[str, int]:
"""
counter = 0
proxy_order: dict[str, int] = {} # names to order

def process_bound_symbols(bound_symbols):
nonlocal counter
for bsym in bound_symbols:
Expand Down
7 changes: 5 additions & 2 deletions thunder/tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,8 +1813,8 @@ def forward(self, *args):
# the rematerialization pass moved all(?) recomputation to the front,
# making the peak mem about 46MB.
# With checkpointing as coded in the model and recomputation where the
# values are used, we get about 12MB, so we put the barrier at 16MB
assert mem_max - mem_base < 16 * 2**20
# values are used, we get about 12-20MB, so we put the barrier at 24MB
assert mem_max - mem_base < 24 * 2**20


def test_inconsistent_output_length_grad_transform():
Expand Down Expand Up @@ -1906,6 +1906,9 @@ def forward(x):

@pytest.mark.parametrize("device", ("cuda", "cpu"))
def test_backward_recomputation_decomposed_ops(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA is not available")

def fn(a):
return torch.nn.functional.gelu(a)

Expand Down

0 comments on commit 2bd6416

Please sign in to comment.