diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 9d8c6aef0e..a163821f52 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1805,16 +1805,19 @@ def forward(self, *args): inps = torch.randn(512, 1024, requires_grad=True) jm = thunder.jit(m, executors=()) # no rematerialization + res = jm(inps) + res.sum().backward() + + torch.cuda.reset_peak_memory_stats() mem_base = torch.cuda.memory_allocated() - torch.cuda.reset_accumulated_memory_stats() res = jm(inps) res.sum().backward() mem_max = torch.cuda.max_memory_allocated() - # the rematerialization pass moved all(?) recomputation to the front, - # making the peak mem about 46MB. + # without chewckpointing the peak mem about 43MB. # With checkpointing as coded in the model and recomputation where the - # values are used, we get about 12-20MB, so we put the barrier at 24MB - assert mem_max - mem_base < 24 * 2**20 + # values are used, we get a little over 10MB, so we put the barrier at 16MB + mb_used = (mem_max - mem_base) / 2**20 + assert mb_used < 16 def test_inconsistent_output_length_grad_transform():