Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise when in place operations occur on leafs requiring grad #1458

Open
wants to merge 19 commits into
base: main
Choose a base branch
from

Conversation

beverlylytle
Copy link
Collaborator

Before submitting
  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #1284

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@beverlylytle beverlylytle marked this pull request as ready for review November 21, 2024 11:21
Copy link
Collaborator

@kshitij12345 kshitij12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix looks good. We should add a small test to verify that this error raised when expected. Thanks @beverlylytle

thunder/tests/test_inplace_functionalization.py Outdated Show resolved Hide resolved
@@ -2190,6 +2182,9 @@ def is_float_type(self, input):


def _copy__impl(copy_from, copy_to):
cd = get_compile_data()
if cd is not None and cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad:
raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if Symbol copy_ in thunder/torch/__init__.py is more appropriate location for the check.

@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a and b are proxies and it it not clear to me if a proxy knows that it is a leaf.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They do not. It's only a PyTorch concept that's available at runtime inside _copy__impl.

Copy link
Collaborator

@kshitij12345 kshitij12345 Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, previously I missed that the fix was in copy_impl. And since, it is happening at runtime, I am wondering if compile_data is actually available.

Quick test shows (see below) that it wouldn't be. So, we probably need a way to check if this copy was called under no_grad in users code (as PyTorch supports inplace of leaf tensors under no_grad, see comment).

Snippet to check if compile_data is available -

import torch
import thunder
from thunder.extend import OperatorExecutor
from thunder.core.compile_data import get_compile_data
from thunder.core.proxies import TensorProxy

ex = OperatorExecutor("ex")

def clone_impl(x):
    cd = get_compile_data()
    print(cd)  # None
    return x

clone = ex.register_operator("clone", meta=lambda x: TensorProxy(like=x), fn=clone_impl)

def fn(x):
    return clone(x)

x = torch.ones(3)

jfn = thunder.jit(fn)

jfn(x)
exec_trace = thunder.last_traces(jfn)[-1]
# print(exec_trace)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, compile_data was not available, but now it should be with the added context manager in thunder/init.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is still incorrect because as discussed in #1486, the value of compile_data.is_grad_enabled here would be that of last updated state which can lead to incorrectness when used outside of tracing context.

We can see the discrepancy here.

import torch
import thunder

x = torch.randn(3, 3, requires_grad=True)

@torch.no_grad
def fn(x):
  return x.add_(1)

fn(x)  # This works

thunder.jit(fn)(x)  # This raises error

So, whether the copy is in no_grad region needs to be captured during the tracing time.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, this is why I created the other issue. This PR fixes the leaf/grad issue when there is no annotation. When there is an annotation, another approach is required. This other approach may or may not involve using compile data in _copy__impl.

As far as I understand, compile data is the medium for passing around data such as whether grad is enabled. But as the other issue points out, compile data reflects the end state of a function call and not the "live" state, at least at the time it reaches _copy__impl. So I'm left with the questions "are there other mechanisms for passing around whether grad is enabled?" "where else in the execution is it simultaneously knowable that a (1) leaf tensor (2) requiring grad is being (3) copied when (4) grad is enabled?" "is it feasible/desirable to make the compile data more dynamic?" "is there a way to context-manage the tensors so that their requires_grad flags are set to False when the interpreter sees torch._C._set_grad_enabled(False), and then later restored, thereby obviating the need for the compile data for this check?" Do you have suggestions for a fix that addresses both issues? Or can we close out this issue and move the discussion to the more involved issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to tackle - leaf tensor requiring grad being copied into when grad is enabled, I think similar to a previous commit,
we can update prims.copy to take a argument is_grad_enabled. With this, ltorch.copy will query cd.is_grad_enabled and call prims.copy by also passing this argument.

@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)

With these changes, the copy_impl's signature will also change to accept is_grad_enabled and it will be called at runtime with a tensor which we can query if it is a leaf and also whether grad was enabled or not when calling that particular copy. Wdyt @beverlylytle?

Though, I am curious if there is another approach to this - cc: @IvanYashchuk

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see what the CI thinks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with modifying thunder.torch.copy to query cd.is_grad_enabled and passing that to prims.copy.

Copy link
Collaborator

@kshitij12345 kshitij12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me, I just have a couple of questions. Thank you @beverlylytle

@@ -2085,6 +2087,7 @@ def copy_(
*,
fd: FusionDefinition,
lc_to_nv_map: dict,
grad_enabled: bool,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the behaviour for nvfuser? I think that we ignore this argument. Should we raise a warning instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the argument is ignored, and nvfuser does not fail.

thunder/tests/test_inplace_functionalization.py Outdated Show resolved Hide resolved
@@ -1983,7 +1983,8 @@ def copysign_(a, b, /):

@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)
cd = get_compile_data()
return prims.copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False)
Copy link
Collaborator

@kshitij12345 kshitij12345 Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if cd is None (probably happens for thunder.trace with default arguments), should we assume that we are running with grad_enabled with a warning? I think that it is likely case. Wdyt?

cc: @IvanYashchuk

Copy link
Collaborator Author

@beverlylytle beverlylytle Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have the background to have an opinion on this. I defer to you @kshitij12345 and @IvanYashchuk.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cd is our controlled way of specifying and querying the state of PyTorch. If it's None, I don't think we should do anything special. It's the responsibility of the outside system to set up a correct cd object. is_grad_enabled is a sensible default because if nothing else is specified, we should assume we are executing a program as given in the "inference" mode with no additional side transformations.

thunder/core/prims.py Outdated Show resolved Hide resolved
thunder/executors/nvfuserex_impl.py Outdated Show resolved Hide resolved
@@ -2190,6 +2182,9 @@ def is_float_type(self, input):


def _copy__impl(copy_from, copy_to):
cd = get_compile_data()
if cd is not None and cd.is_grad_enabled and copy_to.is_leaf and copy_to.requires_grad:
raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with modifying thunder.torch.copy to query cd.is_grad_enabled and passing that to prims.copy.

thunder/tests/test_inplace_functionalization.py Outdated Show resolved Hide resolved
thunder/tests/test_inplace_functionalization.py Outdated Show resolved Hide resolved
thunder/executors/nvfuserex_impl.py Outdated Show resolved Hide resolved
Comment on lines +225 to +234
def _copy_(a, b, /):
cd = get_compile_data()
return prims.copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False)


@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return _copy_(a, b)


Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider the following snippet:

import thunder
import torch


x = torch.rand((2,3), dtype=torch.float32, device='cuda')

def f(x):
    return x.to(torch.float64).sin_()

Here a new tensor of type float64 is being created and an inplace operation is occurring on it. Explicitly:

def f(x):
    y = x.to(torch.float64)
    return y.sin_()

One might expect that the following would be a less efficient, but still more or less equivalent version of the above:

def g(x):
    y = x.to(torch.float64)
    z = y.sin()
    return y.copy_(z)

However, for

jf = thunder.jit(f); jg = thunder.jit(g)

jf(x) executes successfully while jg(x) results in

An error occurred while defining nvFuser FusionDefinition None.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
```python
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.5.1+cu124
# cuda version: 12.4
# nvfuser version: 0.2.23+gitd53be45
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_incomplete_fusion(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[2, 3], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.cast(T0, dtype=DataType.Double)
    T2 = fd.ops.sin(T1)
    T3 = fd.ops.set(T2)
    fd.add_output(T3, T1)
    fd.add_output(T1)

with FusionDefinition() as fd:
    nvfuser_fusion_idNone(fd)
```<adding extra characters for md>
Traceback (most recent call last):
  File "/home/blytle/miniforge3/envs/thdrs/lib/python3.10/site-packages/nvfuser/__init__.py", line 105, in __exit__
    self._finalize_definition()
RuntimeError:  INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/python_frontend/fusion_state.cpp":141, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. 
Detected exception while building Fusion Ir. The failing RecordFunctor is: fd.add_output(T3, T1)
NvFuser error message is:  INTERNAL ASSERT FAILED at "/workspace/Fuser/csrc/fusion.cpp":784, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. alias source can only be a fusion input
Exception raised from aliasOutputToInput at /workspace/Fuser/csrc/fusion.cpp:784 (most recent call first):
….

nvFuser does not like inplace operations performed on tensors which are not input to the fusion definition. Happily, they are usually functionalized away in during tracing. The functionalization pass makes assumptions about what an inplace op looks like. Using a version of copy_ annotated by torchsymbol within the other inplace ops, like sin_ , breaks those assumptions and leads to many tests failing with errors as above. Hence, the split between _copy_ and copy_.

As a side note, I was also surprised to discover that

def f(x):
    x.sin_()
    return x

def g(x):
    z = torch.sin(x)
    x.copy_(z)
    return x


x = torch.rand((2, 2), device='cuda', dtype=torch.float64)

jf = thunder.jit(f)
jf(x)                  # this is fine
jg = thunder.jit(g)
jg(x)                  # fails with an AssertionError on "assert return_bsym.sym.id == prims.PrimIDs.RETURN"

@@ -2241,7 +2246,7 @@ def true_divide(a: NumberLike | TensorLike, b: NumberLike | TensorLike, /) -> Nu

@torchsymbol(torch.Tensor.true_divide_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def true_divide_(a: TensorLike, b: NumberLike | TensorLike, /) -> TensorLike:
return prims.copy_(true_divide(a, b))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found the lack of a second argument here odd.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[inplace] Silently incorrect gradient when leaf variable is used in an inplace operation
5 participants