Skip to content

Commit

Permalink
switches
Browse files — browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 17, 2024
1 parent 9e1d1c3 commit 32d8ccb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
3 changes: 0 additions & 3 deletions thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,6 @@ def add_edges(var):
if not required_producer_vars:
# If there are no required producer variables, we need to make sure that
# the source node is added to the graph.
print("#### hello")
add_edge("source", "source", capacity=float("inf"))

for var in required_producer_vars:
Expand All @@ -375,8 +374,6 @@ def add_edges(var):
g = nx.DiGraph()
g.add_edges_from(edges)

print("#####", dict(g.nodes), g.edges)

try:
_, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink")
except Exception:
Expand Down
26 changes: 15 additions & 11 deletions thunder/tests/test_nvfuser_remat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_find_producer_symbols(executor, device, _):
# We will try to find a subgraph for rematerializing __c and __d
t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_find_producer_symbols(executor, device, _):
def test_apply_rematerialization_producer(executor, device, _):
t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_apply_rematerialization_producer(executor, device, _):
def test_apply_rematerialization_consumer(executor, device, _):
t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand Down Expand Up @@ -217,7 +217,7 @@ def foo(t0):

t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(foo, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand Down Expand Up @@ -256,7 +256,7 @@ def func(t0):

t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand Down Expand Up @@ -368,7 +368,9 @@ def func(
from thunder.executors.torch_compile import torch_compile_cat_ex

try:
compiled_func = thunder.jit(func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor))
compiled_func = thunder.jit(
func, executors=(torch_compile_cat_ex, thunder.nvfuser_executor), use_rematerialization=True
)
_ = compiled_func(
t0,
t1,
Expand All @@ -390,7 +392,7 @@ def func(
def test_find_cut(executor, device, _):
t0 = make_tensor(2, 2, dtype=torch.float32, device=device)
initial_trace = thunder.trace()(func, t0)
compiled_func = thunder.jit(initial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand All @@ -417,7 +419,7 @@ def test_find_cut_dropout(executor, device, _):

with patch("thunder.core.rematerialization.replace_uniform", new=replace_uniform_mock):
initial_trace = thunder.trace()(func_with_dropout, t0)
compiled_func = thunder.jit(intial_trace.python_callable())
compiled_func = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
_ = compiled_func(t0)
traces = thunder.last_traces(compiled_func)
trace = traces[-1]
Expand Down Expand Up @@ -458,10 +460,12 @@ def func(t0):

# Result with rematerialization and without rematerialization should match
initial_trace = thunder.trace()(func, t0)
result_with_remat = thunder.jit(initial_trace.python_callable())(t0)
result_with_remat = thunder.jit(initial_trace.python_callable(), use_rematerialization=True)(t0)
assert not isinstance(result_with_remat, Exception)

result_without_remat = disable_rematerialization_in_nvfuser_fusion(thunder.jit(initial_trace.python_callable()))(t0)
result_without_remat = disable_rematerialization_in_nvfuser_fusion(
thunder.jit(initial_trace.python_callable(), use_rematerialization=True)
)(t0)

torch.testing.assert_close(result_with_remat, result_without_remat)

Expand All @@ -473,7 +477,7 @@ def test_rematerialization_name_collision():
def forward(x):
return x.softmax(dim=1, dtype=torch.float)

jforward = thunder.jit(forward)
jforward = thunder.jit(forward, use_rematerialization=True)

x = torch.randn([32768, 8], dtype=torch.bfloat16, device="cuda", requires_grad=True)

Expand Down

0 comments on commit 32d8ccb

Please sign in to comment.