From d646dd11bace7b06c04071f331c6eb14899facbd Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 11 Nov 2024 21:46:59 +0100 Subject: [PATCH 01/11] support no_grad in thunder.jit --- thunder/__init__.py | 4 +++ thunder/core/transform_common.py | 51 +++++++++++++++++++++++++++++++- thunder/core/transforms.py | 8 +++++ thunder/torch/__init__.py | 14 +++++++-- 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index c09c1cc9b5..797b63da88 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -35,6 +35,7 @@ wrap_return_value_together_with_argments, unwrap_return_value, remove_context_manager_prims_from_trace, + tag_no_grad_symbols_pass, ) from thunder.core.functionalization import ( check_inplace_to_views, @@ -538,6 +539,9 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = wrap_return_value_together_with_argments(computation_trc) computation_traces.append(computation_trc) + computation_trc = tag_no_grad_symbols_pass(computation_trc) + computation_traces.append(computation_trc) + computation_trc = remove_context_manager_prims_from_trace(computation_trc) computation_traces.append(computation_trc) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index e21ca4b28a..139e72e6f3 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -10,7 +10,7 @@ import thunder import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify +from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify, ProxyTag from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx @@ -492,3 +492,52 @@ def is_context_manager_prim(bsym): new_trace.bound_symbols = filtered_bsyms new_trace.set_provenance(TraceProvenance("Remove context manager prims")) return new_trace + + +def tag_no_grad_symbols_pass(trace: Trace) -> Trace: + """ + This function iterates over trace and marks the BoundSymbols + in `no_grad` regions such that VJP pass will treat them as constant. + """ + is_no_grad_region = False + + # NOTE - This will also copy name from original trace. + new_trace = from_trace(trace) + new_bsyms = [] + + for bsym in trace.bound_symbols: + # case - torch._C._set_grad_enabled(False) + if bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and not bsym.args[0]: + is_no_grad_region = True + continue + # case - torch._C._set_grad_enabled(True) + elif bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and bsym.args[0]: + is_no_grad_region = False + continue + + if is_no_grad_region: + # Mark the TensorProxy output of the `bsym` + # with `ProxyTag.DETACHED_AUTOGRAD_GRAPH` so that + # vjp will treat this as constant. + + def create_detached_output(t): + if isinstance(t, TensorProxy): + # NOTE - We need `tracectx` as creating/replacing name for proxy + # tries a look-up in current trace. + with tracectx(new_trace): + # Remove the name so that we can re-use it. + # Otherwise, we get a proxy with new name. 
+ new_trace.names.remove(t.name) + return t.replace(requires_grad=False, tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) + + return t + + new_output = tree_map(create_detached_output, bsym.output) + # Create a copy of the `bsym` with `new_output` + bsym = bsym.from_bsym(output=new_output) + + new_bsyms.append(bsym) + + new_trace.bound_symbols = new_bsyms + new_trace.set_provenance(TraceProvenance("no_grad detach graph for vjp pass")) + return new_trace diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 2bf91372fe..94596b732e 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -33,6 +33,7 @@ variableify, unvariableify, FutureTensorProxy, + ProxyTag, ) from thunder.core.compile_data import get_compile_data, get_compile_option from thunder.core.langctxs import langctx, Languages @@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool: bool: True if the symbol is constant, False otherwise. """ are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args) + # `no_grad_detach_graph_pass` tags output of BoundSymbols in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`. + # These are treated as constant for VJP. + # NOTE - `any(()) is False` + output_disconnected_from_graph = any( + ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy) + ) return ( are_all_args_non_differentiable or symbol.are_all_args_constant or symbol.sym.id in nondifferentiable_vjp_symbols + or output_disconnected_from_graph ) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 7047485ff6..81da76c113 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -47,6 +47,7 @@ ListProxy, DictProxy, numberproxy, + ProxyTag, ) from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten from thunder.core.symbol import Symbol @@ -5226,11 +5227,18 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device: register_function(torch.device, torch_device) -def _set_grad_enabled_with_warning(enabled: bool) -> None: - warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit") +# Tag to use on Proxies created in `no_grad` regions. +# VJP transform will treat BOundSymbol's whose output has these tags +# as constant. +ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH") -register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning) +# This is just a marker Symbol. `tag_no_grad_symbols_pass` pass uses these symbols +# to find the `no_grad` regions and mark the BoundSymbols within them as constant +# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag. +@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,)) +def _set_grad_enabled_with_warning(enabled: bool) -> None: + pass def _unwrap_if_dead(tensor): From 1d1d201a5a43d48f73cfa929feec80a35686a6a5 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 12 Nov 2024 11:33:05 +0100 Subject: [PATCH 02/11] Update thunder/core/transform_common.py Co-authored-by: Ivan Yashchuk --- thunder/core/transform_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 139e72e6f3..3e4d074934 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -528,7 +528,7 @@ def create_detached_output(t): # Remove the name so that we can re-use it. 
# Otherwise, we get a proxy with new name. new_trace.names.remove(t.name) - return t.replace(requires_grad=False, tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) + return t.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) return t From 1b8cc7518e395e695b18f6ca67e201b083fb6f7b Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Tue, 12 Nov 2024 11:37:25 +0100 Subject: [PATCH 03/11] update existing test --- thunder/tests/test_core.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 4330b5236a..e9688af6ae 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2099,7 +2099,7 @@ def func(x): compiled = executor.make_callable(func) out = compiled(x) assert out is x - initial_trace_with_dce = thunder.last_traces(compiled)[3] + initial_trace_with_dce = thunder.last_traces(compiled)[4] assert "Constructed by Dead Code Elimination" in str(initial_trace_with_dce) assert len(initial_trace_with_dce.bound_symbols) == 2 assert initial_trace_with_dce.bound_symbols[0].sym.id == prims.PrimIDs.UNPACK_TRIVIAL @@ -2485,8 +2485,7 @@ def foo1(x): return x + 1 x = torch.randn(3, 3, requires_grad=True) - with pytest.warns(UserWarning, match="have no effect under thunder.jit"): - thunder.jit(foo1)(x).sum().backward() + thunder.jit(foo1)(x).sum().backward() assert x.grad is not None @@ -2495,12 +2494,9 @@ def foo2(x): return x + 1 x = torch.randn(3, 3, requires_grad=True) - with pytest.warns(UserWarning, match="have no effect under thunder.jit"): - thunder.jit(foo2)(x).sum().backward() + thunder.jit(foo2)(x).sum().backward() - # `torch.no_grad` has no effect on thunder's autodiff which determines whether to compute grad based on `requires_grad=True`. - # Thus when backward is called it computes grad for the input. - assert x.grad is not None + assert x.grad is None def test_serialize_trace(): From 91859f2eb6780775eafc6c2d8dabe802aae7637f Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 14:29:17 +0100 Subject: [PATCH 04/11] update impl to use compiledata --- thunder/__init__.py | 7 ++--- thunder/common.py | 2 ++ thunder/core/symbol.py | 12 +++++++- thunder/core/transform_common.py | 49 -------------------------------- thunder/torch/__init__.py | 2 +- 5 files changed, 17 insertions(+), 55 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 797b63da88..1bc31def1b 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -35,7 +35,6 @@ wrap_return_value_together_with_argments, unwrap_return_value, remove_context_manager_prims_from_trace, - tag_no_grad_symbols_pass, ) from thunder.core.functionalization import ( check_inplace_to_views, @@ -443,6 +442,9 @@ def get_computation_and_inputs(*args, **kwargs): # which seems to break the consistency of cache_info, leading to a failure in cache_info check. 
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs) + cache_info["is_grad_enabled"] = pytorch.is_grad_enabled() + cd.is_grad_enabled = pytorch.is_grad_enabled() + # TODO RC1 Add module and function checks to prologue (make it a compile option) # Checks cache @@ -539,9 +541,6 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = wrap_return_value_together_with_argments(computation_trc) computation_traces.append(computation_trc) - computation_trc = tag_no_grad_symbols_pass(computation_trc) - computation_traces.append(computation_trc) - computation_trc = remove_context_manager_prims_from_trace(computation_trc) computation_traces.append(computation_trc) diff --git a/thunder/common.py b/thunder/common.py index bc5f370156..0ab880d79c 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -221,6 +221,8 @@ def __init__( # State for pytorch autocast context managers. self.autocast_stack: AutocastStack = AutocastStack() + self.is_grad_enabled: bool = True + # # Gathers additional metadata # diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index d8827eedbd..61e162489a 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -21,7 +21,8 @@ from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map import thunder.core.dtypes as dtypes import thunder.core.devices as devices -from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy +from thunder.core.proxies import Proxy, TensorProxy, NumberProxy, variableify, CollectionProxy, ProxyTag +from thunder.core.compile_data import get_compile_data from thunder.core.trace import ( get_tracectx, @@ -320,6 +321,15 @@ def __call__(self, *args, **kwargs): result = self.meta(*args, **kwargs) trace.pop_scope() + if not get_compile_data().is_grad_enabled: + + def tag_tensorproxy_output_as_detached(proxy): + if isinstance(proxy, TensorProxy): + return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) + return proxy + + result = tree_map(tag_tensorproxy_output_as_detached, result) + bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols) symbols_list = trace.peek_scope() diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 3e4d074934..b95a193b3e 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -492,52 +492,3 @@ def is_context_manager_prim(bsym): new_trace.bound_symbols = filtered_bsyms new_trace.set_provenance(TraceProvenance("Remove context manager prims")) return new_trace - - -def tag_no_grad_symbols_pass(trace: Trace) -> Trace: - """ - This function iterates over trace and marks the BoundSymbols - in `no_grad` regions such that VJP pass will treat them as constant. - """ - is_no_grad_region = False - - # NOTE - This will also copy name from original trace. - new_trace = from_trace(trace) - new_bsyms = [] - - for bsym in trace.bound_symbols: - # case - torch._C._set_grad_enabled(False) - if bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and not bsym.args[0]: - is_no_grad_region = True - continue - # case - torch._C._set_grad_enabled(True) - elif bsym.sym.id == thunder.torch._set_grad_enabled_with_warning.id and bsym.args[0]: - is_no_grad_region = False - continue - - if is_no_grad_region: - # Mark the TensorProxy output of the `bsym` - # with `ProxyTag.DETACHED_AUTOGRAD_GRAPH` so that - # vjp will treat this as constant. 
- - def create_detached_output(t): - if isinstance(t, TensorProxy): - # NOTE - We need `tracectx` as creating/replacing name for proxy - # tries a look-up in current trace. - with tracectx(new_trace): - # Remove the name so that we can re-use it. - # Otherwise, we get a proxy with new name. - new_trace.names.remove(t.name) - return t.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) - - return t - - new_output = tree_map(create_detached_output, bsym.output) - # Create a copy of the `bsym` with `new_output` - bsym = bsym.from_bsym(output=new_output) - - new_bsyms.append(bsym) - - new_trace.bound_symbols = new_bsyms - new_trace.set_provenance(TraceProvenance("no_grad detach graph for vjp pass")) - return new_trace diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 81da76c113..97213ea367 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5238,7 +5238,7 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device: # for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag. @torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,)) def _set_grad_enabled_with_warning(enabled: bool) -> None: - pass + get_compile_data().is_grad_enabled = enabled def _unwrap_if_dead(tensor): From dee92b24bbcd557de21ae45b49cd13739c1dc02a Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 14:54:16 +0100 Subject: [PATCH 05/11] update test --- thunder/tests/test_core.py | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index e9688af6ae..5fefde7a8c 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2480,24 +2480,58 @@ def foo_error(args): def test_grad_ctx(): + # Test `enable_grad` on a function works correctly @torch.enable_grad() def foo1(x): return x + 1 x = torch.randn(3, 3, requires_grad=True) thunder.jit(foo1)(x).sum().backward() - assert x.grad is not None + # Test `no_grad` on a function works correctly @torch.no_grad() def foo2(x): return x + 1 x = torch.randn(3, 3, requires_grad=True) thunder.jit(foo2)(x).sum().backward() - assert x.grad is None + # Test `no_grad` ctx correctly disable gradient computation + def foo3(x): + with torch.no_grad(): + y = x * 3 + return x * 2 + y + + x = torch.randn(3, 3, requires_grad=True) + with torch.no_grad(): + x_ref = x.clone() + x_ref.requires_grad_(True) + + foo3(x_ref).sum().backward() + thunder.jit(foo3)(x).sum().backward() + # Verify the gradients match + torch.testing.assert_close(x.grad, x_ref.grad) + + # Test nested `no_grad` and `enable_grad` + def foo4(x): + with torch.enable_grad(): + with torch.no_grad(): + y = x * 3 + z = x * 4 + return x * 2 + y + z + + x = torch.randn(3, 3, requires_grad=True) + with torch.no_grad(): + x_ref = x.clone() + x_ref.requires_grad_(True) + + foo4(x_ref).sum().backward() + thunder.jit(foo4)(x).sum().backward() + # Verify the gradients match + torch.testing.assert_close(x.grad, x_ref.grad) + def test_serialize_trace(): import dill as pickle From a8588ce6e44b926ba175bb9175cf69ff0e2e01b1 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 14:58:05 +0100 Subject: [PATCH 06/11] update --- thunder/core/transform_common.py | 2 +- thunder/torch/__init__.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index b95a193b3e..e21ca4b28a 100644 --- a/thunder/core/transform_common.py 
+++ b/thunder/core/transform_common.py @@ -10,7 +10,7 @@ import thunder import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify, ProxyTag +from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 97213ea367..4f497ecaf9 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5238,6 +5238,12 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device: # for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag. @torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,)) def _set_grad_enabled_with_warning(enabled: bool) -> None: + cd = get_compile_data() + if cd is None: + warnings.warn( + "torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect, use thunder.jit for correct behaviour" + ) + return get_compile_data().is_grad_enabled = enabled From 3f7a25d98b44274e5b79b7359920b0e35dbd45c3 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 15:07:57 +0100 Subject: [PATCH 07/11] update --- thunder/core/symbol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 61e162489a..40d8c400d2 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -321,7 +321,8 @@ def __call__(self, *args, **kwargs): result = self.meta(*args, **kwargs) trace.pop_scope() - if not get_compile_data().is_grad_enabled: + cd = get_compile_data() + if cd is not None and not cd.is_grad_enabled: def tag_tensorproxy_output_as_detached(proxy): if isinstance(proxy, TensorProxy): From a082ba325267b5b828bd96148044642a6751b540 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 15:11:41 +0100 Subject: [PATCH 08/11] add comment --- thunder/__init__.py | 2 ++ thunder/common.py | 2 ++ thunder/core/symbol.py | 4 +++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 1bc31def1b..0c958f5872 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -442,6 +442,8 @@ def get_computation_and_inputs(*args, **kwargs): # which seems to break the consistency of cache_info, leading to a failure in cache_info check. cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs) + # Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform + # to treat certain Symbols as constant. cache_info["is_grad_enabled"] = pytorch.is_grad_enabled() cd.is_grad_enabled = pytorch.is_grad_enabled() diff --git a/thunder/common.py b/thunder/common.py index 0ab880d79c..674cab65d8 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -221,6 +221,8 @@ def __init__( # State for pytorch autocast context managers. 
self.autocast_stack: AutocastStack = AutocastStack() + # State to query whether grad is enabled or disabled using + # torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled self.is_grad_enabled: bool = True # diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index 40d8c400d2..ead4044b5c 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -323,7 +323,9 @@ def __call__(self, *args, **kwargs): cd = get_compile_data() if cd is not None and not cd.is_grad_enabled: - + # If grad is disabled using `torch.no_grad` or `torch._C._set_grad_enabled(False)`, + # tag the results with `DETACHED_AUTOGRAD_GRAPH` which makes this Symbol a constant for + # vjp transform (applied later). def tag_tensorproxy_output_as_detached(proxy): if isinstance(proxy, TensorProxy): return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) From b71214cb5172a1305f1d0e030768974812ce8fcc Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 15:12:40 +0100 Subject: [PATCH 09/11] update comment --- thunder/core/transforms.py | 2 +- thunder/torch/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 94596b732e..9ebfd66cac 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2486,7 +2486,7 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool: bool: True if the symbol is constant, False otherwise. """ are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args) - # `no_grad_detach_graph_pass` tags output of BoundSymbols in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`. + # Symbol's tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`. # These are treated as constant for VJP. # NOTE - `any(()) is False` output_disconnected_from_graph = any( diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 4f497ecaf9..c15fd8b359 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5228,7 +5228,7 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device: # Tag to use on Proxies created in `no_grad` regions. -# VJP transform will treat BOundSymbol's whose output has these tags +# VJP transform will treat BoundSymbol's whose output has these tags # as constant. ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH") From 43468420f812ff7a9aad033a2a668c2a466ca375 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 15:29:25 +0100 Subject: [PATCH 10/11] fix --- thunder/core/symbol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index ead4044b5c..e5688a013a 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -328,6 +328,8 @@ def __call__(self, *args, **kwargs): # vjp transform (applied later). def tag_tensorproxy_output_as_detached(proxy): if isinstance(proxy, TensorProxy): + # We need to remove name from trace, otherwise replace will return a proxy with new name. 
+ trace.names.remove(proxy.name) return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) return proxy From 649216b81242303ba9533330b8c5356670b9e19d Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Wed, 13 Nov 2024 16:11:57 +0100 Subject: [PATCH 11/11] more test and comment --- thunder/tests/test_core.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 5fefde7a8c..29c491be08 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2480,6 +2480,10 @@ def foo_error(args): def test_grad_ctx(): + # NOTE - This test would start failing if tags on Proxies are dropped + # as the computation under `no_grad` won't be treated as constant + # and grad won't match with PyTorch eager. + # Test `enable_grad` on a function works correctly @torch.enable_grad() def foo1(x): @@ -2532,6 +2536,26 @@ def foo4(x): # Verify the gradients match torch.testing.assert_close(x.grad, x_ref.grad) + def foo5(x): + return x * 2 + + x = torch.randn(3, 3, requires_grad=True) + with torch.no_grad(): + x_ref = x.clone() + x_ref.requires_grad_(True) + + jfoo = thunder.jit(foo5) + with torch.no_grad(): + o = jfoo(x) + assert o.grad_fn is None + assert thunder.cache_misses(jfoo) == 1 # First compilation + + # Running it out of `torch.no_grad`, should lead to recompile. + foo5(x_ref).sum().backward() + jfoo(x).sum().backward() + torch.testing.assert_close(x.grad, x_ref.grad) + assert thunder.cache_misses(jfoo) == 2 + def test_serialize_trace(): import dill as pickle
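
The end-to-end behaviour this series targets can be illustrated with a small usage sketch. It mirrors the `test_grad_ctx` cases added in PATCH 05/11 and PATCH 11/11; the concrete gradient values below are inferred from those tests rather than taken verbatim from the patches:

    import torch
    import thunder

    def fn(x):
        # Proxies produced inside this region are tagged DETACHED_AUTOGRAD_GRAPH,
        # so the VJP transform treats the corresponding BoundSymbols as constant.
        with torch.no_grad():
            y = x * 3
        return x * 2 + y   # only the x * 2 term should contribute to the gradient

    x = torch.randn(3, 3, requires_grad=True)
    jfn = thunder.jit(fn)
    jfn(x).sum().backward()
    # Expected: x.grad is all 2s, matching PyTorch eager, because
    # is_constant_for_vjp() now treats outputs carrying the
    # DETACHED_AUTOGRAD_GRAPH tag as constant.

Because `is_grad_enabled` is recorded both in `cache_info` and on the compile data, calling `jfn` under an outer `torch.no_grad()` afterwards should yield an output without a `grad_fn` and register a cache miss, as exercised by `foo5` in the updated test.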