diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index ee9b8f4464..f24ee1fad2 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -361,7 +361,6 @@ def add_edges(var): if not required_producer_vars: # If there are no required producer variables, we need to make sure that # the source node is added to the graph. - print("#### hello") add_edge("source", "source", capacity=float("inf")) for var in required_producer_vars: @@ -375,8 +374,6 @@ def add_edges(var): g = nx.DiGraph() g.add_edges_from(edges) - print("#####", dict(g.nodes), g.edges) - try: _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") except Exception: diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index b7de2fba7a..690c4f93bc 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -67,7 +67,7 @@ def test_find_producer_symbols(executor, device, _): # We will try to find a subgraph for rematerializing __c and __d t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -117,7 +117,7 @@ def test_find_producer_symbols(executor, device, _): def test_apply_rematerialization_producer(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -151,7 +151,7 @@ def test_apply_rematerialization_producer(executor, device, _): def test_apply_rematerialization_consumer(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -217,7 +217,7 @@ def foo(t0): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(foo, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -256,7 +256,7 @@ def func(t0): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) initial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(initial_trace.python_callable()) + compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -368,7 +368,9 @@ def func( from thunder.executors.torch_compile import torch_compile_cat_ex try: - compiled_func = thunder.jit(func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor)) + compiled_func = thunder.jit( + func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor), use_rematerialization=True + ) _ = compiled_func( t0, t1, @@ -390,7 +392,7 @@ def func( def test_find_cut(executor, device, _): t0 = make_tensor(2, 2, dtype=torch.float32, device=device) intial_trace = thunder.trace()(func, t0) - compiled_func = thunder.jit(intial_trace.python_callable()) + compiled_func = thunder.jit(intial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -417,7 +419,7 @@ def test_find_cut_dropout(executor, device, _): with patch("thunder.core.rematerialization.replace_uniform", new=replace_uniform_mock): intial_trace = thunder.trace()(func_with_dropout, t0) - compiled_func = thunder.jit(intial_trace.python_callable()) + compiled_func = thunder.jit(intial_trace.python_callable(), use_rematerialization=True) _ = compiled_func(t0) traces = thunder.last_traces(compiled_func) trace = traces[-1] @@ -458,10 +460,12 @@ def func(t0): # Result with rematerialization and without rematerialization should match initial_trace = thunder.trace()(func, t0) - result_with_remat = thunder.jit(initial_trace.python_callable())(t0) + result_with_remat = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)(t0) assert not isinstance(result_with_remat, Exception) - result_without_remat = disable_rematerialization_in_nvfuser_fusion(thunder.jit(initial_trace.python_callable()))(t0) + result_without_remat = disable_rematerialization_in_nvfuser_fusion( + thunder.jit(initial_trace.python_callable(), use_rematerialization=True) + )(t0) torch.testing.assert_close(result_with_remat, result_without_remat) @@ -473,7 +477,7 @@ def test_rematerialization_name_collision(): def forward(x): return x.softmax(dim=1, dtype=torch.float) - jforward = thunder.jit(forward) + jforward = thunder.jit(forward, use_rematerialization=True) x = torch.randn([32768, 8], dtype=torch.bfloat16, device="cuda", requires_grad=True)