support no_grad in thunder.jit #1423

Merged (13 commits, Nov 18, 2024)
5 changes: 5 additions & 0 deletions thunder/__init__.py
@@ -442,6 +442,11 @@ def get_computation_and_inputs(*args, **kwargs):
        # which seems to break the consistency of cache_info, leading to a failure in cache_info check.
        cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

        # Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
        # to treat certain Symbols as constant.
        cache_info["is_grad_enabled"] = pytorch.is_grad_enabled()
Collaborator: Is this cache_info entry needed?

Collaborator: Well if you call with grad enabled and then without, you would want to have a cache miss?

Collaborator: If we think we don't need it, let's remove it in a follow-up.

Collaborator (Author): It is for the following cases:

Example 1

jfn = thunder.jit(fn)
with torch.no_grad():
    jfn(x)   # This will be compiled with no_grad

jfn(x)   # We want this to be recompiled.

Example 2

jfn = thunder.jit(fn)

jfn(x)

with torch.no_grad():
    jfn(x)   # We want this to be recompiled
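
For illustration, here is a minimal, purely hypothetical cache sketch (not thunder's actual caching code) of why the grad mode has to be part of the cache key: if `torch.is_grad_enabled()` were left out of `key`, the second call below would silently reuse the entry that was specialized for `no_grad`.

import torch

# Toy cache keyed on shape plus the current grad mode.
_cache = {}

def toy_jit(fn):
    def wrapper(x):
        key = (tuple(x.shape), torch.is_grad_enabled())
        if key not in _cache:
            # Grad mode changed between calls -> cache miss -> "recompile".
            print(f"compiling for is_grad_enabled={torch.is_grad_enabled()}")
            _cache[key] = fn
        return _cache[key](x)
    return wrapper

jfn = toy_jit(lambda x: x * 2)
x = torch.randn(3)
with torch.no_grad():
    jfn(x)  # compiles for is_grad_enabled=False
jfn(x)      # grad mode changed -> compiles again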

Collaborator:

> Well if you call with grad enabled and then without, you would want to have a cache miss?

Of course.

How does this work if the content of cache_info["is_grad_enabled"] is not checked anywhere in this pull request?

Collaborator (Author): In jit_ext.py, we read the values from cache_info and add corresponding checks in the prologue:

cache_info = thunder._get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
if cache_info:
    cache_info_p = Proxy(name="cache_info")
    bsym = prims.unpack_cache_info.bind(cache_info_p, output=cache_info_p)
    prologue_trace.bound_symbols.append(bsym)
    for k, v in cache_info.items():
        p = proxy(v, name=f"cache_info_{k}", history=None)
        bsym = prims.unpack_getitem.bind(cache_info_p, k, output=p)
        prologue_trace.bound_symbols.append(bsym)
        if isinstance(v, str):
            clang.check_string_value(p, v)
        elif isinstance(v, (int, bool, float)):
            clang.check_number_type_and_value(p, v)
        elif isinstance(v, (torch.dtype, torch.device)):
            clang.check_literal_like(p, v)
        else:
            raise NotImplementedError(f"cache info of type {type(v).__name__}")
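
Conceptually, the prologue check generated for cache_info["is_grad_enabled"] boils down to something like the following hand-written sketch (the real check is emitted through `clang.check_number_type_and_value` into the prologue trace, not written by hand):

import torch

def check_is_grad_enabled(expected: bool) -> None:
    # Compare the runtime grad mode against the value the trace was compiled for;
    # a mismatch rejects the cached computation and triggers recompilation.
    actual = torch.is_grad_enabled()
    if actual != expected:
        raise RuntimeError(
            f"cache miss: compiled with is_grad_enabled={expected}, called with {actual}"
        )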

Collaborator: Thank you!

        cd.is_grad_enabled = pytorch.is_grad_enabled()

        # TODO RC1 Add module and function checks to prologue (make it a compile option)

        # Checks cache
4 changes: 4 additions & 0 deletions thunder/common.py
@@ -221,6 +221,10 @@ def __init__(
        # State for pytorch autocast context managers.
        self.autocast_stack: AutocastStack = AutocastStack()

        # State to query whether grad is enabled or disabled using
        # torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled
        self.is_grad_enabled: bool = True

        #
        # Gathers additional metadata
        #
17 changes: 16 additions & 1 deletion thunder/core/symbol.py
@@ -21,7 +21,8 @@
from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy
from thunder.core.proxies import Proxy, TensorProxy, NumberProxy, variableify, CollectionProxy, ProxyTag
from thunder.core.compile_data import get_compile_data

from thunder.core.trace import (
    get_tracectx,
@@ -320,6 +321,20 @@ def __call__(self, *args, **kwargs):
            result = self.meta(*args, **kwargs)
        trace.pop_scope()

        cd = get_compile_data()
        if cd is not None and not cd.is_grad_enabled:
            # If grad is disabled using `torch.no_grad` or `torch._C._set_grad_enabled(False)`,
            # tag the results with `DETACHED_AUTOGRAD_GRAPH`, which makes this Symbol a constant for
            # the vjp transform (applied later).
            def tag_tensorproxy_output_as_detached(proxy):
                if isinstance(proxy, TensorProxy):
                    # We need to remove the name from the trace, otherwise `replace` will return a proxy with a new name.
                    trace.names.remove(proxy.name)
                    return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,))
                return proxy

            result = tree_map(tag_tensorproxy_output_as_detached, result)

        bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
        symbols_list = trace.peek_scope()
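
As a rough stand-in for the tagging step above (using a fake proxy class instead of thunder's `TensorProxy`, so this is only a sketch of the idea): only tensor-like leaves of the symbol's output receive the `DETACHED_AUTOGRAD_GRAPH` tag, while every other leaf is returned unchanged.

DETACHED_AUTOGRAD_GRAPH = "DETACHED_AUTOGRAD_GRAPH"

class FakeTensorProxy:
    # Minimal stand-in for TensorProxy: a name plus an immutable tuple of tags.
    def __init__(self, name, tags=()):
        self.name, self.tags = name, tuple(tags)
    def replace(self, tags):
        return FakeTensorProxy(self.name, tags)

def tag_if_tensor(leaf):
    if isinstance(leaf, FakeTensorProxy):
        return leaf.replace(tags=(DETACHED_AUTOGRAD_GRAPH,))
    return leaf

# The real code walks `result` with tree_map; a flat tuple is enough here.
result = (FakeTensorProxy("t0"), 3.14)
tagged = tuple(tag_if_tensor(o) for o in result)
assert DETACHED_AUTOGRAD_GRAPH in tagged[0].tags
assert tagged[1] == 3.14  # non-tensor leaves pass through untouched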

8 changes: 8 additions & 0 deletions thunder/core/transforms.py
@@ -33,6 +33,7 @@
    variableify,
    unvariableify,
    FutureTensorProxy,
    ProxyTag,
)
from thunder.core.compile_data import get_compile_data, get_compile_option
from thunder.core.langctxs import langctx, Languages
@@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
        bool: True if the symbol is constant, False otherwise.
    """
    are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args)
    # Symbols tag their outputs in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`.
    # These are treated as constant for VJP.
    # NOTE - `any(()) is False`
    output_disconnected_from_graph = any(
        ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy)
    )
    return (
        are_all_args_non_differentiable
        or symbol.are_all_args_constant
        or symbol.sym.id in nondifferentiable_vjp_symbols
        or output_disconnected_from_graph
    )
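
To make the `any(()) is False` note concrete, here is a self-contained toy (plain objects instead of thunder proxies): a symbol whose flat outputs contain no tensor-like values is never marked as detached on this basis, because the filtered generator is empty.

DETACHED_AUTOGRAD_GRAPH = "DETACHED_AUTOGRAD_GRAPH"

class FakeOut:
    def __init__(self, tags=()):
        self.tags = tuple(tags)

def output_disconnected_from_graph(flat_outs):
    # Mirrors the check above: inspect only tensor-like outputs; an empty
    # generator makes any(...) evaluate to False.
    return any(DETACHED_AUTOGRAD_GRAPH in o.tags for o in flat_outs if isinstance(o, FakeOut))

assert output_disconnected_from_graph([FakeOut(tags=(DETACHED_AUTOGRAD_GRAPH,))]) is True
assert output_disconnected_from_graph([FakeOut()]) is False
assert output_disconnected_from_graph([42]) is False  # any(()) is False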


48 changes: 39 additions & 9 deletions thunder/tests/test_core.py
@@ -2099,7 +2099,7 @@ def func(x):
    compiled = executor.make_callable(func)
    out = compiled(x)
    assert out is x
    initial_trace_with_dce = thunder.last_traces(compiled)[3]
    initial_trace_with_dce = thunder.last_traces(compiled)[4]
    assert "Constructed by Dead Code Elimination" in str(initial_trace_with_dce)
    assert len(initial_trace_with_dce.bound_symbols) == 2
    assert initial_trace_with_dce.bound_symbols[0].sym.id == prims.PrimIDs.UNPACK_TRIVIAL
@@ -2480,27 +2480,57 @@ def foo_error(args):


def test_grad_ctx():
    # Test `enable_grad` on a function works correctly
    @torch.enable_grad()
    def foo1(x):
        return x + 1

    x = torch.randn(3, 3, requires_grad=True)
    with pytest.warns(UserWarning, match="have no effect under thunder.jit"):
        thunder.jit(foo1)(x).sum().backward()

    thunder.jit(foo1)(x).sum().backward()
    assert x.grad is not None

    # Test `no_grad` on a function works correctly
    @torch.no_grad()
    def foo2(x):
        return x + 1

    x = torch.randn(3, 3, requires_grad=True)
    with pytest.warns(UserWarning, match="have no effect under thunder.jit"):
        thunder.jit(foo2)(x).sum().backward()
    thunder.jit(foo2)(x).sum().backward()
    assert x.grad is None

    # `torch.no_grad` has no effect on thunder's autodiff which determines whether to compute grad based on `requires_grad=True`.
    # Thus when backward is called it computes grad for the input.
    assert x.grad is not None
    # Test `no_grad` ctx correctly disables gradient computation
    def foo3(x):
        with torch.no_grad():
            y = x * 3
        return x * 2 + y

    x = torch.randn(3, 3, requires_grad=True)
    with torch.no_grad():
        x_ref = x.clone()
    x_ref.requires_grad_(True)

    foo3(x_ref).sum().backward()
    thunder.jit(foo3)(x).sum().backward()
    # Verify the gradients match
    torch.testing.assert_close(x.grad, x_ref.grad)

    # Test nested `no_grad` and `enable_grad`
    def foo4(x):
        with torch.enable_grad():
            with torch.no_grad():
                y = x * 3
            z = x * 4
        return x * 2 + y + z

    x = torch.randn(3, 3, requires_grad=True)
    with torch.no_grad():
        x_ref = x.clone()
    x_ref.requires_grad_(True)

    foo4(x_ref).sum().backward()
    thunder.jit(foo4)(x).sum().backward()
    # Verify the gradients match
    torch.testing.assert_close(x.grad, x_ref.grad)


def test_serialize_trace():
20 changes: 17 additions & 3 deletions thunder/torch/__init__.py
@@ -47,6 +47,7 @@
    ListProxy,
    DictProxy,
    numberproxy,
    ProxyTag,
)
from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol
@@ -5226,11 +5227,24 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device:
register_function(torch.device, torch_device)


def _set_grad_enabled_with_warning(enabled: bool) -> None:
    warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit")
# Tag to use on Proxies created in `no_grad` regions.
# The VJP transform will treat BoundSymbols whose output has these tags
# as constant.
ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH")


register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning)
# This is just a marker Symbol. The `tag_no_grad_symbols_pass` pass uses these symbols
# to find the `no_grad` regions and mark the BoundSymbols within them as constant
# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag.
@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,))
def _set_grad_enabled_with_warning(enabled: bool) -> None:
    cd = get_compile_data()
    if cd is None:
        warnings.warn(
            "torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect, use thunder.jit for correct behaviour"
        )
        return
    get_compile_data().is_grad_enabled = enabled


def _unwrap_if_dead(tensor):
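
For background on why registering a symbol for `torch._C._set_grad_enabled` is enough to delimit `no_grad` regions: in current PyTorch, `torch.no_grad()` and `torch.enable_grad()` flip the global grad mode through that function on enter and restore it on exit, so the marker symbol appears exactly at the region boundaries. A quick sanity check in plain PyTorch (no thunder involved):

import torch

assert torch.is_grad_enabled()          # grad mode is on by default

with torch.no_grad():
    assert not torch.is_grad_enabled()  # entering no_grad disables it
    with torch.enable_grad():
        assert torch.is_grad_enabled()  # nested enable_grad re-enables it
    assert not torch.is_grad_enabled()  # exiting restores the no_grad mode

assert torch.is_grad_enabled()          # the previous mode is restored on exit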