Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise when in place operations occur on leafs requiring grad #1458

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions thunder/core/functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING

from thunder.core.compile_data import get_compile_data
import thunder.core.prims as prims
from thunder.core.proxies import variableify, TensorProxy, unvariableify, ProxyInterface
from thunder.core.pytree import tree_flatten, tree_unflatten
Expand Down Expand Up @@ -499,8 +500,12 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl
copy_from_for_new_copy = reshaped_copy_from
else:
copy_from_for_new_copy = copy_from
new_copy_return = prims.copy_.meta(copy_from_for_new_copy, new_copy_to)
new_copy_bsym = prims.copy_.bind(copy_from_for_new_copy, new_copy_to, output=new_copy_return)
cd = get_compile_data()
grad_enabled = cd.is_grad_enabled if cd is not None else False
new_copy_return = prims.copy_.meta(copy_from_for_new_copy, new_copy_to, grad_enabled=grad_enabled)
new_copy_bsym = prims.copy_.bind(
copy_from_for_new_copy, new_copy_to, grad_enabled=grad_enabled, output=new_copy_return
)
copy_bsyms.append(new_copy_bsym)
else:
var_copy_to = variableify(copy_to)
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -4030,6 +4030,8 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_
def copy__meta(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
grad_enabled: bool,
):
utils.check_type(copy_from, TensorProxy)
utils.check_type(copy_to, TensorProxy)
Expand Down
2 changes: 1 addition & 1 deletion thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,7 +1625,7 @@ def zeros_like(x):
prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)),
prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)),
prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)),
prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()),
prims.PrimIDs.COPY_: lambda x, y, grad_enabled: (prims.copy_(x, y, grad_enabled=grad_enabled), tuple()),
prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()),
}

Expand Down
3 changes: 3 additions & 0 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,6 +2093,8 @@ def var_mean(
def _copy__check(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
grad_enabled: bool,
) -> bool:
return are_supported_tensors(copy_from, copy_to)

Expand All @@ -2101,6 +2103,7 @@ def copy_(
copy_from: TensorProxy,
copy_to: TensorProxy,
*,
grad_enabled: bool,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:
Expand Down
24 changes: 10 additions & 14 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,24 @@
from __future__ import annotations
import operator
import importlib
from dataclasses import replace
from contextlib import ContextDecorator
from functools import wraps, partial
from inspect import signature
from itertools import groupby
from functools import partial, wraps
from numbers import Number
from typing import TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Hashable, Sequence
from collections.abc import Sequence
from types import ModuleType
from enum import Enum, auto

import torch
import math
from looseversion import LooseVersion

from thunder.core.compile_data import get_compile_data
import thunder.core.dtypes as dtypes
from thunder.core.dtypes import to_torch_dtype, to_dtype
import thunder.core.devices as devices
from thunder.core.devices import to_torch_device, to_device
import thunder.core.prims as prims
from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol, BoundSymbol
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, pytype
from thunder.core.symbol import Symbol
from thunder.distributed.prims import DistributedReduceOps
import thunder.distributed.prims as dist_prims
import thunder.core.utils as utils
Expand Down Expand Up @@ -2202,12 +2194,16 @@ def is_float_type(self, input):
einops._backends._type2backend[TensorProxy] = EinopsThunderBackend()


def _copy__impl(copy_from, copy_to):
def _copy__impl(copy_from, copy_to, grad_enabled):
if grad_enabled and copy_to.is_leaf and copy_to.requires_grad:
raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if Symbol copy_ in thunder/torch/__init__.py is more appropriate location for the check.

@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a and b are proxies and it it not clear to me if a proxy knows that it is a leaf.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They do not. It's only a PyTorch concept that's available at runtime inside _copy__impl.

Copy link
Collaborator

@kshitij12345 kshitij12345 Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, previously I missed that the fix was in copy_impl. And since, it is happening at runtime, I am wondering if compile_data is actually available.

Quick test shows (see below) that it wouldn't be. So, we probably need a way to check if this copy was called under no_grad in users code (as PyTorch supports inplace of leaf tensors under no_grad, see comment).

Snippet to check if compile_data is available -

import torch
import thunder
from thunder.extend import OperatorExecutor
from thunder.core.compile_data import get_compile_data
from thunder.core.proxies import TensorProxy

ex = OperatorExecutor("ex")

def clone_impl(x):
    cd = get_compile_data()
    print(cd)  # None
    return x

clone = ex.register_operator("clone", meta=lambda x: TensorProxy(like=x), fn=clone_impl)

def fn(x):
    return clone(x)

x = torch.ones(3)

jfn = thunder.jit(fn)

jfn(x)
exec_trace = thunder.last_traces(jfn)[-1]
# print(exec_trace)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, compile_data was not available, but now it should be with the added context manager in thunder/init.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is still incorrect because as discussed in #1486, the value of compile_data.is_grad_enabled here would be that of last updated state which can lead to incorrectness when used outside of tracing context.

We can see the discrepancy here.

import torch
import thunder

x = torch.randn(3, 3, requires_grad=True)

@torch.no_grad
def fn(x):
  return x.add_(1)

fn(x)  # This works

thunder.jit(fn)(x)  # This raises error

So, whether the copy is in no_grad region needs to be captured during the tracing time.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, this is why I created the other issue. This PR fixes the leaf/grad issue when there is no annotation. When there is an annotation, another approach is required. This other approach may or may not involve using compile data in _copy__impl.

As far as I understand, compile data is the medium for passing around data such as whether grad is enabled. But as the other issue points out, compile data reflects the end state of a function call and not the "live" state, at least at the time it reaches _copy__impl. So I'm left with the questions "are there other mechanisms for passing around whether grad is enabled?" "where else in the execution is it simultaneously knowable that a (1) leaf tensor (2) requiring grad is being (3) copied when (4) grad is enabled?" "is it feasible/desirable to make the compile data more dynamic?" "is there a way to context-manage the tensors so that their requires_grad flags are set to False when the interpreter sees torch._C._set_grad_enabled(False), and then later restored, thereby obviating the need for the compile data for this check?" Do you have suggestions for a fix that addresses both issues? Or can we close out this issue and move the discussion to the more involved issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to tackle - leaf tensor requiring grad being copied into when grad is enabled, I think similar to a previous commit,
we can update prims.copy to take a argument is_grad_enabled. With this, ltorch.copy will query cd.is_grad_enabled and call prims.copy by also passing this argument.

@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)

With these changes, the copy_impl's signature will also change to accept is_grad_enabled and it will be called at runtime with a tensor which we can query if it is a leaf and also whether grad was enabled or not when calling that particular copy. Wdyt @beverlylytle?

Though, I am curious if there is another approach to this - cc: @IvanYashchuk

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see what the CI thinks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with modifying thunder.torch.copy to query cd.is_grad_enabled and passing that to prims.copy.

copy_to.copy_(copy_from)
return copy_to


copy_ = ex.register_operator("copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl)
copy_ = ex.register_operator(
"copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl, module=torch.Tensor
)
_register_implementation(prims.copy_, copy_, checker=_always_executable)


Expand Down
1 change: 1 addition & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _init_group(self, group, params, grads):
params.append(p)
grads.append(p.grad)

@torch.no_grad
def step(self):
for group in self.param_groups:
params = []
Expand Down
33 changes: 23 additions & 10 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import thunder
import thunder.core.dtypes as datatypes
import thunder.torch as ttorch
from thunder.tests.framework import instantiate, nvFuserExecutor
from thunder.tests.framework import instantiate, nvFuserExecutor, TorchExecutor


@instantiate(dtypes=datatypes.all_dtypes - datatypes.float_8bit_dtypes)
Expand All @@ -20,7 +20,7 @@ def torch_foo(x, y):
def foo(x, y):
z = x + y
# NOTE: nvfuserex doesn't support `return z`, i.e. the copy_from argument
o = thunder.core.prims.copy_(z, x)
o = thunder.core.prims.copy_(z, x, grad_enabled=True)
return o

traced_nvfuser_foo = executor.make_callable(foo)
Expand Down Expand Up @@ -49,7 +49,7 @@ def torch_foo(x, y):
def foo(x, y):
z = x * y
z = z * x
o = thunder.core.prims.copy_(z, x)
o = thunder.core.prims.copy_(z, x, grad_enabled=True)
p = y * y
return p

Expand Down Expand Up @@ -120,25 +120,25 @@ def forward(self, x):
def test_inplace_copy_sanity_check(executor, device, dtype):
def func0(x, y):
z = x * y
x = thunder.core.prims.copy_(z, x)
x = thunder.core.prims.copy_(z, x, grad_enabled=True)
return x + y

def func1(x, y):
z = x * y
o1 = thunder.core.prims.copy_(z, x)
o2 = thunder.core.prims.copy_(y, x)
o1 = thunder.core.prims.copy_(z, x, grad_enabled=True)
o2 = thunder.core.prims.copy_(y, x, grad_enabled=True)
return x, o1, o2

def func2(x, y):
z = x * y
o1 = thunder.core.prims.copy_(z, x)
o2 = thunder.core.prims.copy_(x, y)
o1 = thunder.core.prims.copy_(z, x, grad_enabled=True)
o2 = thunder.core.prims.copy_(x, y, grad_enabled=True)
return y, o1, o2

def func3(x, y):
z = x * y
o1 = thunder.core.prims.copy_(z, x)
o2 = thunder.core.prims.copy_(o1, y)
o1 = thunder.core.prims.copy_(z, x, grad_enabled=True)
o2 = thunder.core.prims.copy_(o1, y, grad_enabled=True)
return y, o2

for foo in (func0, func1, func2, func3):
Expand Down Expand Up @@ -178,3 +178,16 @@ def func(T0):
assert_close(a_ref, a)
for o, o_ref in zip(o_thunder, o_eager):
assert_close(o, o_ref)


@instantiate(executors=(TorchExecutor,), dtypes=datatypes.float_math_dtypes)
def test_inplace_copy_of_leaf_requiring_grad_fails(executor, device, dtype):
def fn(x):
x.copy_(x)

jitted_fn = executor.make_callable(fn)

tdtype = ttorch.to_torch_dtype(dtype)
a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=True)
with pytest.raises(RuntimeError):
jitted_fn(a)
4 changes: 2 additions & 2 deletions thunder/tests/test_inplace_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,11 @@ def f(xs, ys, z):
def test_inplace_to_tensors_with_grad(executor, device, _):
@torch.no_grad
beverlylytle marked this conversation as resolved.
Show resolved Hide resolved
def add_y(x, y):
x.add_(y, alpha=0.1)
return x.add_(y, alpha=0.1)

@torch.no_grad
def add_grad(x, y):
x.add_(x.grad, alpha=0.1)
return x.add_(x.grad, alpha=0.1)

for f in (add_y, add_grad):
jitted_f = executor.make_callable(f)
Expand Down
Loading
Loading