Commit
fix cuda mem accounting
t-vi committed Dec 18, 2024
1 parent 1159797 commit 6f5ab69
Showing 1 changed file with 8 additions and 5 deletions.
thunder/tests/test_grad.py
@@ -1805,16 +1805,19 @@ def forward(self, *args):
     inps = torch.randn(512, 1024, requires_grad=True)
 
     jm = thunder.jit(m, executors=()) # no rematerialization
+    res = jm(inps)
+    res.sum().backward()
+
     torch.cuda.reset_peak_memory_stats()
     mem_base = torch.cuda.memory_allocated()
+    torch.cuda.reset_accumulated_memory_stats()
     res = jm(inps)
     res.sum().backward()
     mem_max = torch.cuda.max_memory_allocated()
-    # the rematerialization pass moved all(?) recomputation to the front,
-    # making the peak mem about 46MB.
+    # without checkpointing the peak mem is about 43MB.
     # With checkpointing as coded in the model and recomputation where the
-    # values are used, we get about 12-20MB, so we put the barrier at 24MB
-    assert mem_max - mem_base < 24 * 2**20
+    # values are used, we get a little over 10MB, so we put the barrier at 16MB
+    mb_used = (mem_max - mem_base) / 2**20
+    assert mb_used < 16
 
 
 def test_inconsistent_output_length_grad_transform():

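For context, the accounting pattern the fixed test adopts can be lifted into a standalone sketch: warm up once so one-time allocations land before the measurement window opens, then reset the peak counter, record a baseline, run again, and report the delta. This is a minimal illustration only; the helper name measure_peak_cuda_mb and its signature are hypothetical, not part of the repository:

    import torch

    def measure_peak_cuda_mb(fn, *args):
        # Warmup run: one-time allocations (gradient buffers, cuBLAS/cuDNN
        # workspaces, caching-allocator blocks) happen here, before the
        # measurement window opens.
        fn(*args)
        torch.cuda.synchronize()

        # Open a fresh window: clear the peak counter, record the baseline,
        # and clear the accumulated counters, as the commit does.
        torch.cuda.reset_peak_memory_stats()
        mem_base = torch.cuda.memory_allocated()
        torch.cuda.reset_accumulated_memory_stats()

        # Measured run.
        fn(*args)
        torch.cuda.synchronize()
        mem_max = torch.cuda.max_memory_allocated()

        # Peak usage above the pre-run baseline, in MiB.
        return (mem_max - mem_base) / 2**20

Measured this way, the assertion covers only the transient memory of a steady-state iteration rather than first-run setup costs, which is what lets the commit tighten the barrier from 24MB to 16MB.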