Skip to content

Commit

Permalink
Change the test to work around #549.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Aug 21, 2024
1 parent 6a2d258 commit a96ebbb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion thunder/tests/test_nvfuser_remat.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_find_cut_dropout(executor, device, _):
ext_producer_outputs = find_external_producer_outputs(utils.consumers(trace), (), producer, consumer)
cut = find_cut(ext_producer_outputs, producer, consumer)
assert cut[0] == producer.args[0].name
# Note cut[1]/producer.output[0] is the boolean mask for dropout. It should
# Note cut[1]/producer.output[1] is the boolean mask for dropout. It should
# be chosen over the float32 mask. See this issue: "The Recomputation
# Algorithm on Dropout choses a float32 mask to save"
producer_output_names = tuple(o.name for o in producer.output)
Expand Down

0 comments on commit a96ebbb

Please sign in to comment.