diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py index 86859e6428..a9cb944010 100644 --- a/thunder/core/functionalization.py +++ b/thunder/core/functionalization.py @@ -2,6 +2,7 @@ from collections import defaultdict from typing import TYPE_CHECKING +from thunder.core.compile_data import get_compile_data import thunder.core.prims as prims from thunder.core.proxies import variableify, TensorProxy, unvariableify, ProxyInterface from thunder.core.pytree import tree_flatten, tree_unflatten @@ -499,8 +500,12 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl copy_from_for_new_copy = reshaped_copy_from else: copy_from_for_new_copy = copy_from - new_copy_return = prims.copy_.meta(copy_from_for_new_copy, new_copy_to) - new_copy_bsym = prims.copy_.bind(copy_from_for_new_copy, new_copy_to, output=new_copy_return) + cd = get_compile_data() + grad_enabled = cd.is_grad_enabled if cd is not None else False + new_copy_return = prims.copy_.meta(copy_from_for_new_copy, new_copy_to, grad_enabled=grad_enabled) + new_copy_bsym = prims.copy_.bind( + copy_from_for_new_copy, new_copy_to, grad_enabled=grad_enabled, output=new_copy_return + ) copy_bsyms.append(new_copy_bsym) else: var_copy_to = variableify(copy_to) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index c17a28296c..9393ad2496 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -4030,6 +4030,8 @@ def embedding_backward_meta(grad, indices, num_weights, padding_idx, scale_grad_ def copy__meta( copy_from: TensorProxy, copy_to: TensorProxy, + *, + grad_enabled: bool, ): utils.check_type(copy_from, TensorProxy) utils.check_type(copy_to, TensorProxy) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 3cb34a623e..566d340fac 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1625,7 +1625,7 @@ def zeros_like(x): prims.PrimIDs.LOG2: lambda x: (prims.log2(x), (x,)), prims.PrimIDs.ZETA: lambda x, y: (prims.zeta(x, y), (x, y)), prims.PrimIDs.FMOD: lambda x, y: (prims.fmod(x, y), (x, y)), - prims.PrimIDs.COPY_: lambda x, y: (prims.copy_(x, y), tuple()), + prims.PrimIDs.COPY_: lambda x, y, grad_enabled: (prims.copy_(x, y, grad_enabled=grad_enabled), tuple()), prims.PrimIDs.CLONE: lambda x: (prims.clone(x), tuple()), } diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 4bd78b72b5..3d82140a66 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -2093,6 +2093,8 @@ def var_mean( def _copy__check( copy_from: TensorProxy, copy_to: TensorProxy, + *, + grad_enabled: bool, ) -> bool: return are_supported_tensors(copy_from, copy_to) @@ -2101,6 +2103,7 @@ def copy_( copy_from: TensorProxy, copy_to: TensorProxy, *, + grad_enabled: bool, fd: FusionDefinition, lc_to_nv_map: dict, ) -> Any: diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 6de644204d..10ae6e0059 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1,32 +1,24 @@ from __future__ import annotations import operator import importlib -from dataclasses import replace -from contextlib import ContextDecorator -from functools import wraps, partial -from inspect import signature -from itertools import groupby +from functools import partial, wraps from numbers import Number from typing import TYPE_CHECKING from collections.abc import Callable from collections.abc import Hashable, Sequence from collections.abc import Sequence from types import 
ModuleType -from enum import Enum, auto import torch -import math -from looseversion import LooseVersion +from thunder.core.compile_data import get_compile_data import thunder.core.dtypes as dtypes from thunder.core.dtypes import to_torch_dtype, to_dtype import thunder.core.devices as devices from thunder.core.devices import to_torch_device, to_device import thunder.core.prims as prims -from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace -from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype -from thunder.core.pytree import tree_flatten, tree_unflatten -from thunder.core.symbol import Symbol, BoundSymbol +from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, pytype +from thunder.core.symbol import Symbol from thunder.distributed.prims import DistributedReduceOps import thunder.distributed.prims as dist_prims import thunder.core.utils as utils @@ -2202,12 +2194,16 @@ def is_float_type(self, input): einops._backends._type2backend[TensorProxy] = EinopsThunderBackend() -def _copy__impl(copy_from, copy_to): +def _copy__impl(copy_from, copy_to, grad_enabled): + if grad_enabled and copy_to.is_leaf and copy_to.requires_grad: + raise RuntimeError("a leaf Variable that requires grad is being used in an in-place operation.") copy_to.copy_(copy_from) return copy_to -copy_ = ex.register_operator("copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl) +copy_ = ex.register_operator( + "copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl, module=torch.Tensor +) _register_implementation(prims.copy_, copy_, checker=_always_executable) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index bf6e9bd7d3..1462ccde0d 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -331,6 +331,7 @@ def _init_group(self, group, params, grads): params.append(p) grads.append(p.grad) + @torch.no_grad def step(self): for group in self.param_groups: params = [] diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 952df2faf0..e324b5fb03 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -7,7 +7,7 @@ import thunder import thunder.core.dtypes as datatypes import thunder.torch as ttorch -from thunder.tests.framework import instantiate, nvFuserExecutor +from thunder.tests.framework import instantiate, nvFuserExecutor, TorchExecutor @instantiate(dtypes=datatypes.all_dtypes - datatypes.float_8bit_dtypes) @@ -20,7 +20,7 @@ def torch_foo(x, y): def foo(x, y): z = x + y # NOTE: nvfuserex doesn't support `return z`, i.e. 
the copy_from argument - o = thunder.core.prims.copy_(z, x) + o = thunder.core.prims.copy_(z, x, grad_enabled=True) return o traced_nvfuser_foo = executor.make_callable(foo) @@ -49,7 +49,7 @@ def torch_foo(x, y): def foo(x, y): z = x * y z = z * x - o = thunder.core.prims.copy_(z, x) + o = thunder.core.prims.copy_(z, x, grad_enabled=True) p = y * y return p @@ -120,25 +120,25 @@ def forward(self, x): def test_inplace_copy_sanity_check(executor, device, dtype): def func0(x, y): z = x * y - x = thunder.core.prims.copy_(z, x) + x = thunder.core.prims.copy_(z, x, grad_enabled=True) return x + y def func1(x, y): z = x * y - o1 = thunder.core.prims.copy_(z, x) - o2 = thunder.core.prims.copy_(y, x) + o1 = thunder.core.prims.copy_(z, x, grad_enabled=True) + o2 = thunder.core.prims.copy_(y, x, grad_enabled=True) return x, o1, o2 def func2(x, y): z = x * y - o1 = thunder.core.prims.copy_(z, x) - o2 = thunder.core.prims.copy_(x, y) + o1 = thunder.core.prims.copy_(z, x, grad_enabled=True) + o2 = thunder.core.prims.copy_(x, y, grad_enabled=True) return y, o1, o2 def func3(x, y): z = x * y - o1 = thunder.core.prims.copy_(z, x) - o2 = thunder.core.prims.copy_(o1, y) + o1 = thunder.core.prims.copy_(z, x, grad_enabled=True) + o2 = thunder.core.prims.copy_(o1, y, grad_enabled=True) return y, o2 for foo in (func0, func1, func2, func3): @@ -178,3 +178,16 @@ def func(T0): assert_close(a_ref, a) for o, o_ref in zip(o_thunder, o_eager): assert_close(o, o_ref) + + +@instantiate(executors=(TorchExecutor,), dtypes=datatypes.float_math_dtypes) +def test_inplace_copy_of_leaf_requiring_grad_fails(executor, device, dtype): + def fn(x): + x.copy_(x) + + jitted_fn = executor.make_callable(fn) + + tdtype = ttorch.to_torch_dtype(dtype) + a = make_tensor((4, 4), device=device, dtype=tdtype, requires_grad=True) + with pytest.raises(RuntimeError): + jitted_fn(a) diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 6f88f1f8eb..9826bc27de 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -478,11 +478,11 @@ def f(xs, ys, z): def test_inplace_to_tensors_with_grad(executor, device, _): @torch.no_grad def add_y(x, y): - x.add_(y, alpha=0.1) + return x.add_(y, alpha=0.1) @torch.no_grad def add_grad(x, y): - x.add_(x.grad, alpha=0.1) + return x.add_(x.grad, alpha=0.1) for f in (add_y, add_grad): jitted_f = executor.make_callable(f) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 38d0353c1f..cf3fc02736 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -222,6 +222,16 @@ def register_function(torchfn, thunderfn_impl): _torch_to_thunder_function_map[torchfn] = thunderfn_impl +def _copy_(a, b, /): + cd = get_compile_data() + return prims.copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False) + + +@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,)) +def copy_(a, b, /): + return _copy_(a, b) + + # # Tensor properties # @@ -978,7 +988,7 @@ def setitem(inp, idx, val): @torchsymbol(torch.Tensor.__setitem__, id="setitem_", is_method=True, tags=(prims.OpTags.IN_PLACE,)) def setitem_(inp, idx, val): - prims.copy_(setitem(inp, idx, val), inp) + _copy_(inp, setitem(inp, idx, val)) @torchsymbol(torch.Tensor.__getitem__, id="torch.Tensor.__getitem__", method_name="getitem") @@ -1358,7 +1368,7 @@ def abs(a: NumberLike | TensorLike, /) -> Number | TensorLike: @torchsymbol(torch.abs_, is_method=True, 
tags=(prims.OpTags.IN_PLACE,)) def abs_(a: NumberLike | TensorLike, /) -> Number | TensorLike: - return prims.copy_(abs(a), a) + return _copy_(a, abs(a)) @torchsymbol(torch.acos, is_method=True) @@ -1368,7 +1378,7 @@ def acos(a: NumberLike | TensorLike, /) -> Number | TensorLike: @torchsymbol(torch.acos_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def acos_(a: TensorLike, /) -> TensorLike: - return prims.copy_(acos(a), a) + return _copy_(a, acos(a)) @torchsymbol(torch.acosh, is_method=True) @@ -1378,7 +1388,7 @@ def acosh(a): @torchsymbol(torch.acosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def acosh_(a): - return prims.copy_(acosh(a), a) + return _copy_(a, acosh(a)) @torchsymbol(torch.asin, is_method=True) @@ -1388,7 +1398,7 @@ def asin(a): @torchsymbol(torch.asin_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def asin_(a): - return prims.copy_(asin(a), a) + return _copy_(a, asin(a)) @torchsymbol(torch.asinh, is_method=True) @@ -1398,7 +1408,7 @@ def asinh(a): @torchsymbol(torch.asinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def asinh_(a): - return prims.copy_(asinh(a), a) + return _copy_(a, asinh(a)) @torchsymbol(torch.atan, is_method=True) @@ -1408,7 +1418,7 @@ def atan(a): @torchsymbol(torch.atan_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def atan_(a): - return prims.copy_(atan(a), a) + return _copy_(a, atan(a)) @torchsymbol(torch.atanh, is_method=True) @@ -1418,7 +1428,7 @@ def atanh(a): @torchsymbol(torch.atanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def atanh_(a): - return prims.copy_(atanh(a), a) + return _copy_(a, atanh(a)) @torchsymbol(torch.bitwise_not, is_method=True) @@ -1428,7 +1438,7 @@ def bitwise_not(a): @torchsymbol(torch.Tensor.bitwise_not_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_not_(a): - return prims.copy_(bitwise_not(a), a) + return _copy_(a, bitwise_not(a)) @torchsymbol(torch.ceil, is_method=True) @@ -1438,7 +1448,7 @@ def ceil(a): @torchsymbol(torch.ceil_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def ceil_(a): - return prims.copy_(ceil(a), a) + return _copy_(a, ceil(a)) @torchsymbol(torch.cos, is_method=True) @@ -1448,7 +1458,7 @@ def cos(a): @torchsymbol(torch.cos_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def cos_(a): - return prims.copy_(cos(a), a) + return _copy_(a, cos(a)) @torchsymbol(torch.cosh, is_method=True) @@ -1458,7 +1468,7 @@ def cosh(a): @torchsymbol(torch.cosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def cosh_(a): - return prims.copy_(cosh(a), a) + return _copy_(a, cosh(a)) @torchsymbol(torch.digamma, torch.special.digamma, is_method=True) @@ -1468,7 +1478,7 @@ def digamma(a): @torchsymbol(torch.Tensor.digamma_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def digamma_(a): - return prims.copy_(digamma(a), a) + return _copy_(a, digamma(a)) @torchsymbol(torch.erf, is_method=True) @@ -1478,7 +1488,7 @@ def erf(a): @torchsymbol(torch.erf_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def erf_(a): - return prims.copy_(erf(a), a) + return _copy_(a, erf(a)) @torchsymbol(torch.erfc, is_method=True) @@ -1488,7 +1498,7 @@ def erfc(a): @torchsymbol(torch.erfc_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def erfc_(a): - return prims.copy_(erfc(a), a) + return _copy_(a, erfc(a)) @torchsymbol(torch.erfinv, is_method=True) @@ -1498,7 +1508,7 @@ def erfinv(a): @torchsymbol(torch.Tensor.erfinv_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def erfinv_(a): - return prims.copy_(erfinv(a), a) + return _copy_(a, erfinv(a)) @torchsymbol(torch.exp, is_method=True) @@ -1508,7 +1518,7 @@ 
def exp(a): @torchsymbol(torch.exp_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def exp_(a): - return prims.copy_(exp(a), a) + return _copy_(a, exp(a)) @torchsymbol(torch.exp2, is_method=True) @@ -1518,7 +1528,7 @@ def exp2(a): @torchsymbol(torch.exp2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def exp2_(a): - return prims.copy_(exp2(a), a) + return _copy_(a, exp2(a)) # fake out of place variant @@ -1552,7 +1562,7 @@ def exponential(a: Tensor, rate: float = 1, *, generator: None | torch.Generator @torchsymbol(torch.Tensor.exponential_, id="exponential_", is_method=True, tags=(prims.OpTags.IN_PLACE,)) def exponential_(a: Tensor, rate: float = 1, *, generator: None | torch.Generator = None) -> Tensor: - return prims.copy_(exponential(a, rate=rate, generator=generator), a) + return _copy_(a, exponential(a, rate=rate, generator=generator)) @torchsymbol(torch.expm1, is_method=True) @@ -1562,7 +1572,7 @@ def expm1(a): @torchsymbol(torch.expm1_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def expm1_(a): - return prims.copy_(expm1(a), a) + return _copy_(a, expm1(a)) @torchsymbol(torch.floor, is_method=True) @@ -1572,7 +1582,7 @@ def floor(a): @torchsymbol(torch.floor_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def floor_(a): - return prims.copy_(floor(a), a) + return _copy_(a, floor(a)) @torchsymbol(torch.isfinite, is_method=True) @@ -1587,7 +1597,7 @@ def lgamma(a): @torchsymbol(torch.Tensor.lgamma_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def lgamma_(a): - return prims.copy_(lgamma(a), a) + return _copy_(a, lgamma(a)) @torchsymbol(torch.log, is_method=True) @@ -1597,7 +1607,7 @@ def log(a): @torchsymbol(torch.log_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log_(a): - return prims.copy_(log(a), a) + return _copy_(a, log(a)) @torchsymbol(torch.log10, is_method=True) @@ -1607,7 +1617,7 @@ def log10(a): @torchsymbol(torch.log10_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log10_(a): - return prims.copy_(log10(a), a) + return _copy_(a, log10(a)) @torchsymbol(torch.log1p, is_method=True) @@ -1617,7 +1627,7 @@ def log1p(a): @torchsymbol(torch.log1p_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log1p_(a): - return prims.copy_(log1p(a), a) + return _copy_(a, log1p(a)) @torchsymbol(torch.log2, is_method=True) @@ -1627,7 +1637,7 @@ def log2(a): @torchsymbol(torch.log2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log2_(a): - return prims.copy_(log2(a), a) + return _copy_(a, log2(a)) # TODO Move to special @@ -1643,7 +1653,7 @@ def neg(a): @torchsymbol(torch.neg_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def neg_(a): - return prims.copy_(neg(a), a) + return _copy_(a, neg(a)) @torchsymbol(torch.reciprocal, is_method=True) @@ -1653,7 +1663,7 @@ def reciprocal(a): @torchsymbol(torch.reciprocal_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def reciprocal_(a): - return prims.copy_(reciprocal(a), a) + return _copy_(a, reciprocal(a)) @torchsymbol(torch.round, is_method=True) @@ -1663,7 +1673,7 @@ def round(a): @torchsymbol(torch.round_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def round_(a): - return prims.copy_(round(a), a) + return _copy_(a, round(a)) @torchsymbol(torch.rsqrt, is_method=True) @@ -1673,7 +1683,7 @@ def rsqrt(a): @torchsymbol(torch.rsqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def rsqrt_(a): - return prims.copy_(rsqrt(a), a) + return _copy_(a, rsqrt(a)) # TODO Complain about complex numbers like PyTorch does? 
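
Across the thunder/torch/__init__.py hunks above and below, every in-place variant is rewritten from prims.copy_(fn(a), a) to _copy_(a, fn(a)), so the grad-mode flag is read from compile data in exactly one place. The following is a minimal standalone sketch of that call pattern only; get_compile_data and the copy primitive are illustrative stand-ins here, not the real thunder internals.

    # Standalone sketch of the rewrite pattern; _CompileData, get_compile_data
    # and prims_copy_ are stand-ins for the real thunder objects.
    import math

    class _CompileData:
        is_grad_enabled = True

    def get_compile_data():
        # stand-in: thunder's version can return None outside of compilation
        return _CompileData()

    def prims_copy_(copy_from, copy_to, *, grad_enabled):
        # stand-in for prims.copy_; the real torch-executor impl additionally
        # rejects writes into grad-requiring leaf tensors when grad_enabled is True
        return copy_to

    def _copy_(a, b, /):
        # destination first, source second; the grad flag comes from compile data
        cd = get_compile_data()
        return prims_copy_(b, a, grad_enabled=cd.is_grad_enabled if cd is not None else False)

    def exp(a):
        return math.exp(a)

    def exp_(a):
        # shape shared by all the in-place rewrites: compute the out-of-place
        # result, then funnel it through a single _copy_ back into `a`
        return _copy_(a, exp(a))
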
@@ -1685,7 +1695,7 @@ def sign(a): @torchsymbol(torch.Tensor.sign_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sign_(a): - return prims.copy_(sign(a), a) + return _copy_(a, sign(a)) @torchsymbol(torch.signbit, is_method=True) @@ -1700,7 +1710,7 @@ def sin(a): @torchsymbol(torch.sin_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sin_(a): - return prims.copy_(sin(a), a) + return _copy_(a, sin(a)) @torchsymbol(torch.sinh, is_method=True) @@ -1710,7 +1720,7 @@ def sinh(a): @torchsymbol(torch.sinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sinh_(a): - return prims.copy_(sinh(a), a) + return _copy_(a, sinh(a)) @torchsymbol(torch.sqrt, is_method=True) @@ -1720,7 +1730,7 @@ def sqrt(a): @torchsymbol(torch.sqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sqrt_(a): - return prims.copy_(sqrt(a), a) + return _copy_(a, sqrt(a)) @torchsymbol(torch.tan, is_method=True) @@ -1730,7 +1740,7 @@ def tan(a): @torchsymbol(torch.tan_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def tan_(a): - return prims.copy_(tan(a), a) + return _copy_(a, tan(a)) @torchsymbol(torch.tanh, is_method=True) @@ -1740,7 +1750,7 @@ def tanh(a): @torchsymbol(torch.tanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def tanh_(a): - return prims.copy_(tanh(a), a) + return _copy_(a, tanh(a)) @torchsymbol(torch.trunc, is_method=True) @@ -1750,7 +1760,7 @@ def trunc(a): @torchsymbol(torch.trunc_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def trunc_(a): - return prims.copy_(trunc(a), a) + return _copy_(a, trunc(a)) @torchsymbol(torch.real, is_method=False) @@ -1769,7 +1779,7 @@ def celu(a: TensorLike, /, alpha: float = 1.0, inplace: bool = False) -> TensorL negative_domain_value = alpha * expm1(a / alpha) out = where(a > 0, a, negative_domain_value) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1781,7 +1791,7 @@ def elu(a: TensorProxy, /, alpha: float = 1.0, inplace: bool = False) -> TensorL negative_domain_value = alpha * expm1(a) out = where(a > 0, a, negative_domain_value) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1806,7 +1816,7 @@ def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool = False) -> TensorLike: out = where(a > 0, a, a * negative_slope) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1832,7 +1842,7 @@ def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: out = where(a > 0, a, 0) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1844,7 +1854,7 @@ def relu_( a: TensorLike, /, ) -> TensorLike: - return prims.copy_(relu(a, False), a) + return _copy_(a, relu(a, False)) # The default value of `inplace` is False, so no need to tweak args/kwargs @@ -1856,7 +1866,7 @@ def relu_( def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike: out = clamp(a, 0, 6) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1881,7 +1891,7 @@ def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike: ) out = a * relu6(a + 3) / 6 if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1898,7 +1908,7 @@ def selu(a: TensorProxy, /, inplace: bool = False) -> TensorLike: out = scale * where(a > 0, a, rhs) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1909,7 +1919,7 @@ def selu(a: TensorProxy, /, inplace: bool = False) -> TensorLike: def silu(a: TensorLike, /, inplace: 
bool = False) -> TensorLike: out = clang.silu(a) if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -1946,7 +1956,7 @@ def add_( *, alpha: Number | TensorLike = 1, ) -> TensorLike: - return prims.copy_(add(a, b, alpha=alpha), a) + return _copy_(a, add(a, b, alpha=alpha)) @torchsymbol(torch.atan2, is_method=True) @@ -1956,7 +1966,7 @@ def atan2(a, b, /): @torchsymbol(torch.Tensor.atan2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def atan2_(a, b, /): - return prims.copy_(atan2(a, b), a) + return _copy_(a, atan2(a, b)) @torchsymbol(torch.bitwise_and, is_method=True) @@ -1966,7 +1976,7 @@ def bitwise_and(a, b, /): @torchsymbol(torch.Tensor.bitwise_and_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_and_(a, b, /): - return prims.copy_(bitwise_and(a, b), a) + return _copy_(a, bitwise_and(a, b)) @torchsymbol(torch.bitwise_or, is_method=True) @@ -1976,7 +1986,7 @@ def bitwise_or(a, b, /): @torchsymbol(torch.Tensor.bitwise_or_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_or_(a, b, /): - return prims.copy_(bitwise_or(a, b), a) + return _copy_(a, bitwise_or(a, b)) @torchsymbol(torch.bitwise_xor, is_method=True) @@ -1986,7 +1996,7 @@ def bitwise_xor(a, b, /): @torchsymbol(torch.Tensor.bitwise_xor_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_xor_(a, b, /): - return prims.copy_(bitwise_xor(a, b), a) + return _copy_(a, bitwise_xor(a, b)) @torchsymbol(torch.copysign, is_method=True) @@ -1994,14 +2004,9 @@ def copysign(a, b, /): return clang.copysign(a, b) -@torchsymbol(torch.Tensor.copysign_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +@torchsymbol(torch.Tensor.copysign_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,)) def copysign_(a, b, /): - return prims.copy_(copysign(a, b), a) - - -@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,)) -def copy_(a, b, /): - return prims.copy_(b, a) + return _copy_(a, copysign(a, b)) # TODO Implement div @@ -2034,7 +2039,7 @@ def div_( *, rounding_mode: None | str = None, ) -> TensorLike: - return prims.copy_(div(a, b), a) + return _copy_(a, div(a, b)) @torchsymbol(torch.eq, is_method=True) @@ -2044,7 +2049,7 @@ def eq(a, b, /): @torchsymbol(torch.Tensor.eq_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def eq_(a, b, /): - return prims.copy_(eq(a, b), a) + return _copy_(a, eq(a, b)) @torchsymbol(torch.floor_divide, is_method=True) @@ -2054,7 +2059,7 @@ def floor_divide(a, b, /): @torchsymbol(torch.Tensor.floor_divide_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def floor_divide_(a, b, /): - return prims.copy_(floor_divide(a, b), a) + return _copy_(a, floor_divide(a, b)) @torchsymbol(torch.fmod, is_method=True) @@ -2064,7 +2069,7 @@ def fmod(a, b, /): @torchsymbol(torch.Tensor.fmod_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def fmod_(a, b, /): - return prims.copy_(fmod(a, b), a) + return _copy_(a, fmod(a, b)) @torchsymbol(torch.ge, is_method=True) @@ -2074,7 +2079,7 @@ def ge(a, b, /): @torchsymbol(torch.Tensor.ge_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def ge_(a, b, /): - return prims.copy_(ge(a, b), a) + return _copy_(a, ge(a, b)) @torchsymbol(torch.gt, is_method=True) @@ -2084,7 +2089,7 @@ def gt(a, b, /): @torchsymbol(torch.Tensor.gt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def gt_(a, b, /): - return prims.copy_(gt(a, b), a) + return _copy_(a, gt(a, b)) @torchsymbol(torch.logical_and, is_method=True) @@ -2094,7 +2099,7 @@ def logical_and(a, b, /): @torchsymbol(torch.Tensor.logical_and_, is_method=True, 
tags=(prims.OpTags.IN_PLACE,)) def logical_and_(a, b, /): - return prims.copy_(logical_and(a, b), a) + return _copy_(a, logical_and(a, b)) @torchsymbol(torch.logical_not, is_method=True) @@ -2104,7 +2109,7 @@ def logical_not(a: TensorLike, /) -> TensorLike: @torchsymbol(torch.Tensor.logical_not_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def logical_not_(a: TensorLike, /) -> TensorLike: - return prims.copy_(logical_not(a), a) + return _copy_(a, logical_not(a)) @torchsymbol(torch.le, is_method=True) @@ -2114,7 +2119,7 @@ def le(a, b, /): @torchsymbol(torch.Tensor.le_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def le_(a, b, /): - return prims.copy_(le(a, b), a) + return _copy_(a, le(a, b)) @torchsymbol(torch.lt, is_method=True) @@ -2124,7 +2129,7 @@ def lt(a, b, /): @torchsymbol(torch.Tensor.lt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def lt_(a, b, /): - return prims.copy_(lt(a, b), a) + return _copy_(a, lt(a, b)) @torchsymbol(torch.maximum, is_method=True) @@ -2145,7 +2150,7 @@ def mod(a, b): def mod_(a, b): - return prims.copy_(mod(a, b), a) + return _copy_(a, mod(a, b)) @torchsymbol(torch.mul, is_method=True) @@ -2155,7 +2160,7 @@ def mul(a, b, /): @torchsymbol(torch.Tensor.mul_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def mul_(a, b, /): - return prims.copy_(mul(a, b), a) + return _copy_(a, mul(a, b)) @torchsymbol(torch.ne, is_method=True) @@ -2165,7 +2170,7 @@ def ne(a, b, /): @torchsymbol(torch.Tensor.ne_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def ne_(a, b, /): - return prims.copy_(ne(a, b), a) + return _copy_(a, ne(a, b)) @torchsymbol(torch.nextafter, is_method=True) @@ -2175,7 +2180,7 @@ def nextafter(a, b, /): @torchsymbol(torch.Tensor.nextafter_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def nextafter_(a, b, /): - return prims.copy_(nextafter(a, b), a) + return _copy_(a, nextafter(a, b)) # TODO Extend to tensor x tensor @@ -2198,7 +2203,7 @@ def polygamma(n: int, a: TensorLike, /) -> TensorLike: @torchsymbol(torch.Tensor.polygamma_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def polygamma_(n: int, a: TensorLike, /) -> TensorLike: - return prims.copy_(polygamma(n, a), a) + return _copy_(a, polygamma(n, a)) @torchsymbol(torch.pow, is_method=True) @@ -2208,7 +2213,7 @@ def pow(a, b, /): @torchsymbol(torch.Tensor.pow_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def pow_(a, b, /): - return prims.copy_(pow(a, b), a) + return _copy_(a, pow(a, b)) @torchsymbol(torch.remainder, is_method=True) @@ -2218,7 +2223,7 @@ def remainder(a, b, /): @torchsymbol(torch.Tensor.remainder_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def remainder_(a, b, /): - return prims.copy_(remainder(a, b), a) + return _copy_(a, remainder(a, b)) @torchsymbol(torch.sub, is_method=True) @@ -2231,7 +2236,7 @@ def sub(a, b, /, *, alpha: NumberLike | TensorLike = 1): @torchsymbol(torch.Tensor.sub_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sub_(a, b, /, *, alpha: NumberLike | TensorLike = 1): - return prims.copy_(sub(a, b, alpha=alpha), a) + return _copy_(a, sub(a, b, alpha=alpha)) @torchsymbol(torch.true_divide, is_method=True) @@ -2241,7 +2246,7 @@ def true_divide(a: NumberLike | TensorLike, b: NumberLike | TensorLike, /) -> Nu @torchsymbol(torch.Tensor.true_divide_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def true_divide_(a: TensorLike, b: NumberLike | TensorLike, /) -> TensorLike: - return prims.copy_(true_divide(a, b)) + return _copy_(a, true_divide(a, b)) @torchsymbol(torch.special.zeta) @@ -2280,7 +2285,7 @@ def addcmul(a: TensorLike, b: TensorLike, 
c: TensorLike, /, *, value: None | Num @torchsymbol(torch.Tensor.addcmul_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def addcmul_(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Number = None) -> TensorLike: - return prims.copy_(addcmul(a, b, c, value=value), a) + return _copy_(a, addcmul(a, b, c, value=value)) @torchsymbol(torch.addcdiv, is_method=True) @@ -2290,7 +2295,7 @@ def addcdiv(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Num @torchsymbol(torch.Tensor.addcdiv_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def addcdiv_(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Number = None) -> TensorLike: - return prims.copy_(addcdiv(a, b, c, value=value), a) + return _copy_(a, addcdiv(a, b, c, value=value)) @torchsymbol(torch.lerp, is_method=True) @@ -2300,7 +2305,7 @@ def lerp(start: TensorLike, end: TensorLike, weight: Number | TensorLike) -> Ten @torchsymbol(torch.Tensor.lerp_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def lerp_(start: TensorLike, end: TensorLike, weight: Number | TensorLike) -> TensorLike: - return prims.copy_(lerp(start, end, weight), start) + return _copy_(start, lerp(start, end, weight)) # @@ -2347,7 +2352,7 @@ def clamp( def clamp_( a: TensorLike, /, min: None | Number | TensorLike = None, max: None | Number | TensorLike = None ) -> TensorLike: - return prims.copy_(clamp(a, min, max), a) + return _copy_(a, clamp(a, min, max)) def _mask_tensor(a, mask, fill_value): @@ -2380,7 +2385,7 @@ def masked_fill(a: TensorLike, /, mask: TensorLike, value: NumberLike | TensorLi @torchsymbol(torch.Tensor.masked_fill_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def masked_fill_(a: TensorLike, /, mask: TensorLike, value: NumberLike | TensorLike) -> TensorLike: - return prims.copy_(masked_fill(a, mask, value), a) + return _copy_(a, masked_fill(a, mask, value)) # NOTE The key to understanding tril is that it generates a mask @@ -2407,7 +2412,7 @@ def tril(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = Non @torchsymbol(torch.Tensor.tril_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def tril_(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike: - return prims.copy_(tril(a, diagonal, fill_value=fill_value), a) + return _copy_(a, tril(a, diagonal, fill_value=fill_value)) @torchsymbol(torch.where, is_method=True) @@ -2837,7 +2842,7 @@ def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> Tensor @torchsymbol(torch.Tensor.cumsum_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def cumsum_(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: - return prims.copy_(cumsum(a, dim, dtype=dtype), a) + return _copy_(a, cumsum(a, dim, dtype=dtype)) @torchsymbol(torch.var, is_method=True) @@ -2919,7 +2924,7 @@ def index_add(a: TensorLike, /, dim: int, index: TensorLike, source: TensorLike) @torchsymbol(torch.Tensor.index_add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def index_add_(a: TensorLike, /, dim: int, index: TensorLike, source: TensorLike) -> TensorLike: - return prims.copy_(index_add(a, dim, index, source), a) + return _copy_(a, index_add(a, dim, index, source)) @torchsymbol(torch.index_copy, is_method=True) @@ -2929,7 +2934,7 @@ def index_copy(a: TensorLike, /, dim: int, index: TensorLike, source: TensorLike @torchsymbol(torch.Tensor.index_copy_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def index_copy_(a: TensorLike, /, dim: int, index: TensorLike, source: TensorLike) -> TensorLike: - return 
prims.copy_(index_copy(a, dim, index, source), a) + return _copy_(a, index_copy(a, dim, index, source)) @torchsymbol(torch.index_select, is_method=True) @@ -2993,7 +2998,7 @@ def scatter_( if src is None: src = value - return prims.copy_(clang.scatter(a, index, src, dim), a) + return _copy_(a, clang.scatter(a, index, src, dim)) # NOTE PyTorch's scatter_add has a parameter named 'src', not 'source' @@ -3004,7 +3009,7 @@ def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) @torchsymbol(torch.Tensor.scatter_add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def scatter_add_(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike: - return prims.copy_(scatter_add(a, dim, index, src), a) + return _copy_(a, scatter_add(a, dim, index, src)) @torchsymbol(torch.take_along_dim, is_method=True) @@ -3027,7 +3032,7 @@ def index_put_( values: TensorLike, accumulate: bool = False, ) -> TensorLike: - return prims.copy_(index_put(a, indices, values, accumulate), a) + return _copy_(a, index_put(a, indices, values, accumulate)) # @@ -3773,14 +3778,14 @@ def _native_batch_norm( new_running_mean = (1 - momentum) * running_mean + momentum * mean if not utils.are_same_dtypes(new_running_mean, running_mean): new_running_mean = to(new_running_mean, running_mean.dtype) - prims.copy_(new_running_mean, running_mean) + _copy_(running_mean, new_running_mean) if running_var is not None: n = a.numel() / a.shape[1] unbiased_var = biased_var * (n / (n - 1)) new_running_var = (1 - momentum) * running_var + momentum * unbiased_var if not utils.are_same_dtypes(new_running_var, running_var): new_running_var = to(new_running_var, running_var.dtype) - prims.copy_(new_running_var, running_var) + _copy_(running_var, new_running_var) else: running_var_acc = to(running_var, computation_dtype) rstd = rsqrt(running_var_acc + eps) @@ -4661,7 +4666,7 @@ def dropout(a: TensorProxy, /, p: NumberLike = 0.5, training: bool = True, inpla out = a * dropout_mask * scale if inplace: - return prims.copy_(out, a) + return _copy_(a, out) return out @@ -5472,7 +5477,7 @@ def all_gather_( out = dist_prims.wait(out_or_work) else: out = out_or_work - return prims.copy_(out.view(output_tensor.shape), output_tensor) + return _copy_(output_tensor, out.view(output_tensor.shape)) # NOTE torch.distributed.all_reduce is an inplace operation (although the underlying NCCL # call does not need to be inplace). This, however, is modeled as an out-of-place functional @@ -5526,7 +5531,7 @@ def all_reduce_( group = group if group is not None else torch.distributed.new_group() out = dist_prims.all_reduce(a, op, group, async_op, skip_clone=True) - return prims.copy_(out, a) + return _copy_(a, out) @torchsymbol( is_method=False, @@ -5580,7 +5585,7 @@ def reduce_scatter_( out = dist_prims.wait(out_or_work) else: out = out_or_work - return prims.copy_(out.view(output.shape), output) + return _copy_(output, out.view(output.shape)) @torchsymbol( torch.ops._c10d_functional.wait_tensor,
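
For reference, the leaf check added to _copy__impl mirrors eager PyTorch, which already rejects an in-place write into a grad-requiring leaf while grad mode is on and allows it under torch.no_grad(). The snippet below is plain eager PyTorch, independent of thunder, showing the behavior the new grad_enabled flag is meant to reproduce.

    import torch

    leaf = torch.ones(4, requires_grad=True)   # leaf tensor that requires grad
    src = torch.zeros(4)

    try:
        leaf.copy_(src)                        # grad mode on: eager PyTorch raises
    except RuntimeError as err:
        # "a leaf Variable that requires grad is being used in an in-place operation."
        print(err)

    with torch.no_grad():
        leaf.copy_(src)                        # allowed; corresponds to grad_enabled=False

    print(leaf)                                # tensor([0., 0., 0., 0.], requires_grad=True)
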