diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index 79f4cf4601..9bb5afe96e 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -330,7 +330,7 @@ def test_find_cut_dropout(executor, device, _): ext_producer_outputs = find_external_producer_outputs(utils.consumers(trace), (), producer, consumer) cut = find_cut(ext_producer_outputs, producer, consumer) assert cut[0] == producer.args[0].name - # Note cut[1]/producer.output[0] is the boolean mask for dropout. It should + # Note cut[1]/producer.output[1] is the boolean mask for dropout. It should # be chosen over the float32 mask. See this issue: "The Recomputation # Algorithm on Dropout choses a float32 mask to save" producer_output_names = tuple(o.name for o in producer.output)