Skip to content

Commit

Permalink
switches
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi committed Dec 17, 2024
1 parent 2a08222 commit 9e1d1c3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
5 changes: 4 additions & 1 deletion thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,10 @@ def _can_fuse_node(n: Node):
# Some of the operations might be better placed with its consumers (for
# example residual connection in transformer block). This pass moves
# them to the consumer.
if self._use_rematerialization:
use_rematerialization: None | bool = get_compile_option(
"use_rematerialization", "use rematerialization of parameters"
)
if use_rematerialization and self._use_rematerialization:
fusedtrace = rematerialize(fusedtrace)

fusedtrace = remove_redundant_casts(fusedtrace)
Expand Down
7 changes: 6 additions & 1 deletion thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

import thunder.core.utils as utils
from thunder.core.compile_data import get_compile_option
from thunder.core.prims import PrimIDs
from thunder.core.proxies import TensorProxy, variableify
from thunder.core.pytree import tree_flatten
Expand Down Expand Up @@ -270,7 +271,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
)
bw_traces.append(bw_extrace)

# fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
use_rematerialization: None | bool = get_compile_option(
"use_rematerialization", "use rematerialization of parameters"
)
if use_rematerialization:
fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace)
fw_traces.append(fw_extrace)
bw_traces.append(bw_extrace)

Expand Down
6 changes: 5 additions & 1 deletion thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ def _can_fuse_node(n: Node):

fusedtrace.bound_symbols = fused_bsyms

# fusedtrace = rematerialize(fusedtrace)
use_rematerialization: None | bool = get_compile_option(
"use_rematerialization", "use rematerialization of parameters"
)
if use_rematerialization:
fusedtrace = rematerialize(fusedtrace)
fusedtrace = dce(fusedtrace)
fusedtrace = update_fusion_call_ctx(fusedtrace)

Expand Down
2 changes: 1 addition & 1 deletion thunder/tests/test_nvfuser_remat.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def forward(self, x):

# At the time of writing, linear and matmul are not fused into nvFuser
# regions by default therefore, we should enable them separately
jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True)
jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True, use_rematerialization=True)
jmodel(inp)

def assert_subsymbol_count(trace: TraceCtx, /, num_linears: int, num_matmuls: int):
Expand Down

0 comments on commit 9e1d1c3

Please sign in to comment.