From 9e1d1c32677363092429e0b2dacf965d792e4bef Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 17 Dec 2024 12:22:07 +0100 Subject: [PATCH] switches --- thunder/executors/nvfuserex_impl.py | 5 ++++- thunder/executors/torch_autograd.py | 7 ++++++- thunder/executors/torch_compile.py | 6 +++++- thunder/tests/test_nvfuser_remat.py | 2 +- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 4bd78b72b5..5695b0d038 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -931,7 +931,10 @@ def _can_fuse_node(n: Node): # Some of the operations might be better placed with its consumers (for # example residual connection in transformer block). This pass moves # them to the consumer. - if self._use_rematerialization: + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "use rematerialization of parameters" + ) + if use_rematerialization and self._use_rematerialization: fusedtrace = rematerialize(fusedtrace) fusedtrace = remove_redundant_casts(fusedtrace) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index ac8d20cb24..373b6eb3c1 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -4,6 +4,7 @@ import torch import thunder.core.utils as utils +from thunder.core.compile_data import get_compile_option from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify from thunder.core.pytree import tree_flatten @@ -270,7 +271,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat ) bw_traces.append(bw_extrace) - # fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "use rematerialization of parameters" + ) + if use_rematerialization: + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index e8ba3f32f2..3d21b8057e 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -194,7 +194,11 @@ def _can_fuse_node(n: Node): fusedtrace.bound_symbols = fused_bsyms - # fusedtrace = rematerialize(fusedtrace) + use_rematerialization: None | bool = get_compile_option( + "use_rematerialization", "use rematerialization of parameters" + ) + if use_rematerialization: + fusedtrace = rematerialize(fusedtrace) fusedtrace = dce(fusedtrace) fusedtrace = update_fusion_call_ctx(fusedtrace) diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index 7021efcb69..b7de2fba7a 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -514,7 +514,7 @@ def forward(self, x): # At the time of writing, linear and matmul are not fused into nvFuser # regions by default therefore, we should enable them separately - jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True) + jmodel = thunder.jit(model, nv_enable_linear=True, nv_enable_matmul=True, use_rematerialization=True) jmodel(inp) def assert_subsymbol_count(trace: TraceCtx, /, num_linears: int, num_matmuls: int):