From 0998d81ecfaeaf9d33c925b0c3d3dccc935bb8e0 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 26 Dec 2024 01:42:26 +0900 Subject: [PATCH 1/3] core aten ops Signed-off-by: Masaki Kozuki --- thunder/torch/__init__.py | 1013 ++++++++++++++++++++++++++++++++++--- 1 file changed, 955 insertions(+), 58 deletions(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 38d0353c1f..34245270b8 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -617,6 +617,21 @@ def full( return clang.full(shape, fill_value, device=device, dtype=dtype) +@torchsymbol(torch.ops.aten.full, torch.ops.aten.full.default, id="torch.ops.aten.full") +def core_aten_full( + size: Sequence[int], + fill_value: NumberLike, + *, + dtype: None | dtypeLike = None, + layout: None | torhc.layout = None, + device: None | DeviceLike = None, + pin_memory: None | bool = None, +): + device = to_device(maybe_get_default_device(device)) + dtype = _infer_full_dtype(fill_value, dtype) + return clang.full(size, fill_value, device=device, dtype=dtype) + + @torchsymbol(torch.full_like) def full_like( a: TensorLike, /, fill_value: NumberLike, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None @@ -768,6 +783,34 @@ def randn( return prims.randn(shape, device=device, dtype=dtype) +@torchsymbol(torch.ops.aten.randn, torch.ops.aten.randn.default, id="torch.ops.aten.randn") +def core_aten_randn( + size: Sequence[int], + *, + dtype: dtypeLike | None = None, + layout: torch.layout | None = None, + device: DeviceLike | None = None, + pin_memory: bool | None = None, +) -> TensorProxy: + if layout is None: + layout = torch.strided + if pin_memory is None: + pin_memory = False + utils.check( + not requires_grad, lambda: "requires_grad=True is not yet supported within thunder.jit", NotImplementedError + ) + utils.check(layout == torch.strided, lambda: "Only torch.strided layout is supported", NotImplementedError) + utils.check(not pin_memory, lambda: "pin_memory=True is not supported within thunder.jit", NotImplementedError) + # NOTE: Currently, we don't model randomness + utils.check(generator is None, lambda: "generator is not None which is currently unsupported", NotImplementedError) + utils.check(out is None, lambda: "out is not None which is currently unsupported", NotImplementedError) + + device = to_device(maybe_get_default_device(device)) + dtype = to_dtype(maybe_get_default_dtype(dtype)) + shape = tuple(utils.extract_shape_from_varargs(shape)) + return prims.randn(shape, device=device, dtype=dtype) + + @torchsymbol(torch.randn_like) def randn_like( a, @@ -824,8 +867,7 @@ def zeros_like(a: TensorLike, /, *, device: DeviceLike | None = None, dtype: dty return full_like(a, 0, device=device, dtype=dtype) -@torchsymbol(torch.empty) -def empty( +def _empty_impl( *size: int, device: None | DeviceLike = None, dtype: None | dtypeLike = None, @@ -855,6 +897,43 @@ def empty( return clang.empty(size, device=device, dtype=dtype) +@torchsymbol(torch.empty) +def empty( + *size: int, + device: None | DeviceLike = None, + dtype: None | dtypeLike = None, + out: None | TensorLike = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.contiguous_format, +) -> TensorLike: + return _empty_impl( + *size, + device=device, + dtype=dtype, + out=out, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + +@torchsymbol(torch.ops.aten.empty.memory_format, id="torch.ops.aten.empty.memory_format") +def core_aten_empry_memory_format( + *size: int, + dtype: None | dtypeLike = None, + layout: torch.layout = torch.strided, + device: None | DeviceLike = None, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.contiguous_format, +) -> TensorLike: + return _empty_impl( + *size, device=device, dtype=dtype, out=out, layout=layout, pin_memory=pin_memory, memory_format=memory_format + ) + + # # Shape operations # @@ -866,6 +945,11 @@ def cat(tensors: Sequence[TensorLike], dim: int = 0) -> TensorLike: return clang.cat(tensors, dim) +@torchsymbol(torch.ops.aten.cat, torch.ops.aten.cat.default, id="torch.ops.aten.cat") +def core_aten_cat(tensors: Sequence[TensorLike], dim: int = 0) -> TensorLike: + return clang.cat(tensors, dim) + + @torchsymbol(torch.chunk, is_method=True) def chunk(a: TensorLike, chunks: int, dim: int = 0) -> Sequence[TensorLike]: utils.check(a.ndim > 0, lambda: f"chunk: a ({a.ndim=}) must be at least 1-dimensional") @@ -936,11 +1020,26 @@ def diagonal(a: TensorLike, /, offset: int = 0, dim1: int = 0, dim2: int = 1) -> return clang.diagonal(a, offset, dim1, dim2) +@torchsymbol(torch.ops.aten.diagonal, torch.ops.aten.diagonal.default, id="torch.ops.aten.diagonal") +def core_aten_diagonal(a: TensorLike, /, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TensorLike: + return clang.diagonal(a, offset, dim1, dim2) + + @torchsymbol(torch.Tensor.expand, is_method=True) def expand(a: TensorLike, /, *shape: int) -> TensorLike: return clang.expand(a, *shape) +@torchsymbol(torch.ops.aten.expand, torch.ops.aten.expand.default, id="torch.ops.aten.expand") +def core_aten_expand(a: TensorLike, size: Sequence[int], *, implicit: bool = False): + utils.check( + not implicit, + lambda: f"`torch.ops.aten.expand` with {implicit=} is not supported", + exception_type=NotImplementedError, + ) + return clang.expand(a, *size) + + @torchsymbol(torch.Tensor.expand_as, is_method=True) def expand_as(a: TensorLike, b: TensorLike, /) -> TensorLike: return expand(a, b.size()) @@ -951,8 +1050,7 @@ def flatten(a: TensorLike, /, start_dim: int = 0, end_dim: int = -1) -> TensorLi return clang.flatten(a, start_dim, end_dim) -@torchsymbol(torch.flip, is_method=True) -def flip(a: TensorLike, /, *dims: int) -> TensorLike: +def _flip_impl(a: TensorLike, /, *dims: int) -> TensorLike: dims = utils.extract_shape_from_varargs(dims) # PyTorch supports 0-dim inputs with len(dims) <= 1 @@ -970,6 +1068,16 @@ def flip(a: TensorLike, /, *dims: int) -> TensorLike: return clang.flip(a, dims) +@torchsymbol(torch.flip, is_method=True) +def flip(a: TensorLike, /, *dims: int) -> TensorLike: + return _flip_impl(a, *dims) + + +@torchsymbol(torch.ops.aten.flip, torch.ops.aten.flip.default, id="torch.ops.aten.flip") +def core_aten_flip(a: TensorLike, dims: Sequence[int]) -> TensorLike: + return _flip_impl(a, *dims) + + # fake out of place variant @torchsymbol(id="setitem") def setitem(inp, idx, val): @@ -1052,8 +1160,13 @@ def permute(a: TensorLike, /, *dims: int) -> TensorLike: return clang.transpose(a, dims) -@torchsymbol(torch.Tensor.repeat, is_method=True) -def repeat(a: TensorLike, /, *repeats: int) -> TensorLike: +@torchsymbol(torch.ops.aten.permute, torch.ops.aten.permute.default, id="torch.ops.aten.permute") +def core_aten_permute(a: TensorLike, dims: Sequence[int]) -> TensorLike: + dims = utils.extract_shape_from_varargs(dims) + return clang.transpose(a, dims) + + +def _repeat_impl(a: TensorLike, /, *repeats: int) -> TensorLike: repeats = utils.extract_shape_from_varargs(repeats) utils.check_valid_shape(repeats) utils.check(a.ndim <= len(repeats), f"Expected {a.ndim=} <= {len(repeats)=}") @@ -1073,6 +1186,16 @@ def repeat(a: TensorLike, /, *repeats: int) -> TensorLike: return reshape(a, out_shape) +@torchsymbol(torch.Tensor.repeat, is_method=True) +def repeat(a: TensorLike, /, *repeats: int) -> TensorLike: + return _repeat_impl(a, *repeats) + + +@torchsymbol(torch.ops.aten.repeat, torch.ops.aten.repeat.default, id="torch.ops.aten.repeat") +def core_aten_repeat(a: TensorProxy, repeats: Sequence[int]) -> TensorProxy: + return _repeat_impl(a, *repeats) + + @torchsymbol(torch.reshape, is_method=True) def reshape(a: TensorLike, /, *shape: int) -> TensorLike: shape = utils.extract_shape_from_varargs(shape) @@ -1091,8 +1214,7 @@ def unflatten(a: TensorLike, /, dim: int, sizes=Sequence[int]) -> TensorLike: return a.view(a.shape[:dim] + tuple(sizes) + a.shape[dim + 1 :]) -@torchsymbol(torch.select, is_method=True) -def select(a: TensorLike, /, dim: int, index: int): +def _select_impl(a: TensorLike, /, dim: int, index: int): # dim check utils.check( a.ndim != 0, @@ -1115,6 +1237,16 @@ def select(a: TensorLike, /, dim: int, index: int): return squeeze(a_sliced, dim) +@torchsymbol(torch.select, is_method=True) +def select(a: TensorLike, /, dim: int, index: int): + return _select_impl(a, dim, index) + + +@torchsymbol(torch.ops.aten.select.int, id="torch.ops.aten.select.int") +def core_aten_select(a: TensorProxy, dim: int, index: int): + return _select_impl(a, dim, index) + + # TODO consider revising this to just call _split_indices # Splits a tensor along a split dimension dim into n tensors # If input is divisible by n then every tensor will have the same length along the split dimension @@ -1230,8 +1362,7 @@ def stack(tensors: Sequence[TensorLike], /, dim: int = 0) -> TensorLike: # See https://pytorch.org/docs/master/generated/torch.squeeze.html -@torchsymbol(torch.squeeze, is_method=True) -def squeeze(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorLike: +def _squeeze_impl(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorLike: # Converts dim to a tuple of numbers dims = dim if dim is None: @@ -1253,6 +1384,21 @@ def squeeze(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorL return clang.squeeze(a, dims) +@torchsymbol(torch.squeeze, is_method=True) +def squeeze(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorLike: + return _squeeze_impl(a, dim) + + +@torchsymbol(torch.ops.aten.squeeze.dim, id="torch.ops.aten.squeeze.dim") +def core_aten_squeeze_dim(a: TensorProxy, dim: int) -> TensorLike: + return _squeeze_impl(a, dim) + + +@torchsymbol(torch.ops.aten.squeeze.dims, id="torch.ops.aten.squeeze.dims") +def core_aten_squeeze_dims(a: TensorProxy, dim: Sequence[int]) -> TensorLike: + return _squeeze_impl(a, dim) + + @torchsymbol(torch.t, is_method=True) def t(a: TensorLike, /) -> TensorLike: utils.check( @@ -1263,6 +1409,18 @@ def t(a: TensorLike, /) -> TensorLike: return prims.transpose(a, (1, 0)) if a.ndim == 2 else a +@torchsymbol(torch.ops.aten.t.default, id="torch.ops.aten.t.default") +def core_aten_t(a: TensorProxy) -> TensorProxy: + utils.check( + a.ndim <= 2, + lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D", + RuntimeError, + ) + if a.ndim != 2: + return a + return transpose(a, 0, 1) + + @run_once def warn_ndim_not_2(): warnings.warn( @@ -1303,8 +1461,7 @@ def tensor_split(a: TensorLike, /, indices_or_sections, dim=0): return _split_indices(a, indices_or_sections, dim) -@torchsymbol(torch.transpose, is_method=True) -def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: +def _transpose_impl(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: dim0, dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) permutation = list(range(0, a.ndim)) @@ -1313,6 +1470,16 @@ def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: return clang.transpose(a, permutation) +@torchsymbol(torch.transpose, is_method=True) +def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: + return _transpose_impl(a, dim0, dim1) + + +@torchsymbol(torch.ops.aten.transpose.int, id="torch.ops.aten.transpose.int") +def core_aten_transpose(a: TensorProxy, dim0: int, dim1: int) -> TensorProxy: + return _transpose_impl(a, dim0, dim1) + + @torchsymbol(torch.unbind, is_method=True) def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]: utils.check( @@ -1332,6 +1499,11 @@ def unsqueeze(a: TensorLike, /, dim: int) -> TensorLike: return clang.unsqueeze(a, dim) +@torchsymbol(torch.ops.aten.unsqueeze, torch.ops.aten.unsqueeze.default, id="torch.ops.aten.unsqueeze") +def core_aten_unsqueeze(a: TensorProxy, dim: int) -> TensorProxy: + return clang.unsqueeze(a, dim) + + # TODO Review view functionalization # TODO Add type annotations @torchsymbol(torch.Tensor.view, is_method=True) @@ -1340,6 +1512,11 @@ def view(a: TensorLike, /, *shape) -> TensorLike: return reshape(a, shape) +@torchsymbol(torch.ops.aten.view, torch.ops.aten.view.default, id="torch.ops.aten.view") +def core_aten_view(a: TensorProxy, size: Sequence[int]) -> TensorProxy: + return reshape(a, size) + + @torchsymbol(torch.Tensor.view_as, is_method=True) def view_as(a: TensorLike, b: TensorLike, /) -> TensorLike: return view(a, b.size()) @@ -1386,6 +1563,11 @@ def asin(a): return clang.asin(a) +@torchsymbol(torch.ops.aten.asin, torch.ops.aten.asin.default, id="torch.ops.aten.asin") +def core_aten_asin(a): + return clang.asin(a) + + @torchsymbol(torch.asin_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def asin_(a): return prims.copy_(asin(a), a) @@ -1396,6 +1578,11 @@ def asinh(a): return clang.asinh(a) +@torchsymbol(torch.ops.aten.asinh, torch.ops.aten.asinh.default, id="torch.ops.aten.asinh") +def core_aten_asinh(a): + return clang.asinh(a) + + @torchsymbol(torch.asinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def asinh_(a): return prims.copy_(asinh(a), a) @@ -1406,6 +1593,11 @@ def atan(a): return clang.atan(a) +@torchsymbol(torch.ops.aten.atan, torch.ops.aten.atan.default, id="torch.ops.aten.atan") +def core_aten_atan(a): + return clang.atan(a) + + @torchsymbol(torch.atan_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def atan_(a): return prims.copy_(atan(a), a) @@ -1416,6 +1608,11 @@ def atanh(a): return clang.atanh(a) +@torchsymbol(torch.ops.aten.atanh, torch.ops.aten.atanh.default, id="torch.ops.aten.atanh") +def core_aten_atanh(a): + return clang.atanh(a) + + @torchsymbol(torch.atanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def atanh_(a): return prims.copy_(atanh(a), a) @@ -1426,6 +1623,11 @@ def bitwise_not(a): return clang.bitwise_not(a) +@torchsymbol(torch.ops.aten.bitwise_not, torch.ops.aten.bitwise_not.default, id="torch.ops.aten.bitwise_not") +def core_aten_bitwise_not(a): + return clang.bitwise_not(a) + + @torchsymbol(torch.Tensor.bitwise_not_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_not_(a): return prims.copy_(bitwise_not(a), a) @@ -1446,6 +1648,11 @@ def cos(a): return clang.cos(a) +@torchsymbol(torch.ops.aten.cos, torch.ops.aten.cos.default, id="torch.ops.aten.cos") +def core_aten_cos(a): + return clang.cos(a) + + @torchsymbol(torch.cos_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def cos_(a): return prims.copy_(cos(a), a) @@ -1456,6 +1663,11 @@ def cosh(a): return clang.cosh(a) +@torchsymbol(torch.ops.aten.cosh, torch.ops.aten.cosh.default, id="torch.ops.aten.cosh") +def core_aten_cosh(a): + return clang.cosh(a) + + @torchsymbol(torch.cosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def cosh_(a): return prims.copy_(cosh(a), a) @@ -1476,6 +1688,11 @@ def erf(a): return clang.erf(a) +@torchsymbol(torch.ops.aten.erf, torch.ops.aten.erf.default, id="torch.ops.aten.erf") +def core_aten_erf(a): + return clang.erf(a) + + @torchsymbol(torch.erf_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def erf_(a): return prims.copy_(erf(a), a) @@ -1506,6 +1723,11 @@ def exp(a): return clang.exp(a) +@torchsymbol(torch.ops.aten.exp, torch.ops.aten.exp.default, id="torch.ops.aten.exp") +def core_aten_exp(a): + return clang.exp(a) + + @torchsymbol(torch.exp_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def exp_(a): return prims.copy_(exp(a), a) @@ -1560,6 +1782,11 @@ def expm1(a): return clang.expm1(a) +@torchsymbol(torch.ops.aten.expm1, torch.ops.aten.expm1.default, id="torch.ops.aten.expm1") +def core_aten_expm1(a): + return clang.expm1(a) + + @torchsymbol(torch.expm1_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def expm1_(a): return prims.copy_(expm1(a), a) @@ -1570,6 +1797,11 @@ def floor(a): return clang.floor(a) +@torchsymbol(torch.ops.aten.floor, torch.ops.aten.floor.default, id="torch.ops.aten.floor") +def core_aten_floor(a): + return clang.floor(a) + + @torchsymbol(torch.floor_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def floor_(a): return prims.copy_(floor(a), a) @@ -1595,6 +1827,11 @@ def log(a): return clang.log(a) +@torchsymbol(torch.ops.aten.log, torch.ops.aten.log.default, id="torch.ops.aten.log") +def core_aten_log(a): + return clang.log(a) + + @torchsymbol(torch.log_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log_(a): return prims.copy_(log(a), a) @@ -1605,6 +1842,11 @@ def log10(a): return clang.log10(a) +@torchsymbol(torch.ops.aten.log10, torch.ops.aten.log10.default, id="torch.ops.aten.log10") +def core_aten_log10(a): + return clang.log10(a) + + @torchsymbol(torch.log10_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log10_(a): return prims.copy_(log10(a), a) @@ -1615,6 +1857,11 @@ def log1p(a): return clang.log1p(a) +@torchsymbol(torch.ops.aten.log1p, torch.ops.aten.log1p.default, id="torch.ops.aten.log1p") +def core_aten_log1p(a): + return clang.log1p(a) + + @torchsymbol(torch.log1p_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log1p_(a): return prims.copy_(log1p(a), a) @@ -1625,6 +1872,11 @@ def log2(a): return clang.log2(a) +@torchsymbol(torch.ops.aten.log2, torch.ops.aten.log2.default, id="torch.ops.aten.log2") +def core_aten_log2(a): + return clang.log2(a) + + @torchsymbol(torch.log2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def log2_(a): return prims.copy_(log2(a), a) @@ -1641,6 +1893,11 @@ def neg(a): return clang.neg(a) +@torchsymbol(torch.ops.aten.neg, torch.ops.aten.neg.default, id="torch.ops.aten.neg") +def core_aten_neg(a): + return clang.neg(a) + + @torchsymbol(torch.neg_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def neg_(a): return prims.copy_(neg(a), a) @@ -1651,6 +1908,11 @@ def reciprocal(a): return clang.reciprocal(a) +@torchsymbol(torch.ops.aten.reciprocal, torch.ops.aten.reciprocal.default, id="torch.ops.aten.reciprocal") +def core_aten_reciprocal(a): + return clang.reciprocal(a) + + @torchsymbol(torch.reciprocal_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def reciprocal_(a): return prims.copy_(reciprocal(a), a) @@ -1661,6 +1923,11 @@ def round(a): return clang.round(a) +@torchsymbol(torch.ops.aten.round, torch.ops.aten.round.default, id="torch.ops.aten.round") +def core_aten_round(a): + return clang.round(a) + + @torchsymbol(torch.round_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def round_(a): return prims.copy_(round(a), a) @@ -1671,6 +1938,11 @@ def rsqrt(a): return clang.rsqrt(a) +@torchsymbol(torch.ops.aten.rsqrt, torch.ops.aten.rsqrt.default, id="torch.ops.aten.rsqrt") +def core_aten_rsqrt(a): + return clang.rsqrt(a) + + @torchsymbol(torch.rsqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def rsqrt_(a): return prims.copy_(rsqrt(a), a) @@ -1683,6 +1955,11 @@ def sign(a): return clang.sign(a) +@torchsymbol(torch.ops.aten.sign, torch.ops.aten.sign.default, id="torch.ops.aten.sign") +def core_aten_sign(a): + return clang.sign(a) + + @torchsymbol(torch.Tensor.sign_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sign_(a): return prims.copy_(sign(a), a) @@ -1698,6 +1975,11 @@ def sin(a): return clang.sin(a) +@torchsymbol(torch.ops.aten.sin, torch.ops.aten.sin.default, id="torch.ops.aten.sin") +def core_aten_sin(a): + return clang.sin(a) + + @torchsymbol(torch.sin_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sin_(a): return prims.copy_(sin(a), a) @@ -1708,6 +1990,11 @@ def sinh(a): return clang.sinh(a) +@torchsymbol(torch.ops.aten.sinh, torch.ops.aten.sinh.default, id="torch.ops.aten.sinh") +def core_aten_sinh(a): + return clang.sinh(a) + + @torchsymbol(torch.sinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sinh_(a): return prims.copy_(sinh(a), a) @@ -1718,6 +2005,11 @@ def sqrt(a): return clang.sqrt(a) +@torchsymbol(torch.ops.aten.sqrt, torch.ops.aten.sqrt.default, id="torch.ops.aten.sqrt") +def core_aten_sqrt(a): + return clang.sqrt(a) + + @torchsymbol(torch.sqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sqrt_(a): return prims.copy_(sqrt(a), a) @@ -1728,6 +2020,11 @@ def tan(a): return clang.tan(a) +@torchsymbol(torch.ops.aten.tan, torch.ops.aten.tan.default, id="torch.ops.aten.tan") +def core_aten_tan(a): + return clang.tan(a) + + @torchsymbol(torch.tan_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def tan_(a): return prims.copy_(tan(a), a) @@ -1738,6 +2035,11 @@ def tanh(a): return clang.tanh(a) +@torchsymbol(torch.ops.aten.tanh, torch.ops.aten.tanh.default, id="torch.ops.aten.tanh") +def core_aten_tanh(a): + return clang.tanh(a) + + @torchsymbol(torch.tanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def tanh_(a): return prims.copy_(tanh(a), a) @@ -1748,6 +2050,11 @@ def trunc(a): return clang.trunc(a) +@torchsymbol(torch.ops.aten.trunc, torch.ops.aten.trunc.default, id="torch.ops.aten.trunc") +def core_aten_trunc(a): + return clang.trunc(a) + + @torchsymbol(torch.trunc_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def trunc_(a): return prims.copy_(trunc(a), a) @@ -1788,8 +2095,7 @@ def elu(a: TensorProxy, /, alpha: float = 1.0, inplace: bool = False) -> TensorL _inplace_to_out_of_place[elu] = elu, 2 -@torchsymbol(torch.nn.functional.gelu, is_method=False) -def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: +def _gelu_impl(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: if approximate == "none": # gelu(a) = a * Phi(a), where Phi is the cdf for the Normal Gaussian. # We use the error function to compute Phi. @@ -1802,6 +2108,16 @@ def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: raise ValueError(f"gelu does not support the approximate={approximate} argument") +@torchsymbol(torch.nn.functional.gelu, is_method=False) +def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: + return _gelu_impl(a, approximate=approximate) + + +@torchsymbol(torch.ops.aten.gelu, torch.ops.aten.gelu.default, id="torch.ops.aten.gelu") +def core_aten_gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: + return _gelu_impl(a, approximate=approximate) + + @torchsymbol(torch.nn.functional.leaky_relu, is_method=False) def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool = False) -> TensorLike: out = where(a > 0, a, a * negative_slope) @@ -1813,6 +2129,12 @@ def leaky_relu(a: TensorProxy, /, negative_slope: float = 0.01, inplace: bool = _inplace_to_out_of_place[leaky_relu] = leaky_relu, 2 +@torchsymbol(torch.ops.aten.leaky_relu, torch.ops.aten.leaky_relu.default, id="torch.ops.aten.leaky_relu") +def core_aten_leaky_relu(a: TensorProxy, negative_slope: float = 0.01) -> TensorLike: + out = where(a > 0, a, a * negative_slope) + return out + + @torchsymbol(torch.nn.functional.logsigmoid, is_method=False) def logsigmoid(a: TensorProxy, /) -> TensorLike: return where(a > 0, -log1p(exp(-a)), a - log1p(exp(a))) @@ -1829,7 +2151,6 @@ def log_sigmoid_backward(g: TensorProxy, a: TensorProxy, buffer: TensorProxy) -> # TODO Should this use clamp? -- Would that propagate NaNs properly? @torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True) def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: - out = where(a > 0, a, 0) if inplace: return prims.copy_(out, a) @@ -1839,6 +2160,11 @@ def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: _inplace_to_out_of_place[relu] = relu, 1 +@torchsymbol(torch.ops.aten.relu, torch.ops.aten.relu.default, id="torch.ops.aten.relu") +def core_aten_relu(a: TensorLike) -> TensorLike: + return where(a > 0, a, 0) + + @torchsymbol(torch.relu_, torch.nn.functional.relu_, id="torch.relu_", is_method=True) def relu_( a: TensorLike, @@ -1923,11 +2249,10 @@ def tanhshrink(a: TensorLike, /) -> TensorLike: _inplace_to_out_of_place[tanhshrink] = tanhshrink, -1 + # # Elementwise binary operations # - - @torchsymbol(torch.add, is_method=True) def add( a: NumberLike | TensorLike, b: NumberLike | TensorLike, /, *, alpha: Number | TensorLike = 1 @@ -1938,6 +2263,14 @@ def add( return clang.add(a, b) +@torchsymbol(torch.ops.aten.add.Scalar, torch.ops.aten.add.Tensor, id="torch.ops.aten.add") +def core_aten_add(a: TensorLike, b: NumberLike | TensorLike, *, alpha: Number | TensorLike = 1) -> TensorLike: + if isinstance(alpha, TensorProxy) or alpha != 1: + b = b * alpha + + return clang.add(a, b) + + @torchsymbol(torch.Tensor.add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def add_( a: TensorLike, @@ -1954,6 +2287,11 @@ def atan2(a, b, /): return clang.atan2(a, b) +@torchsymbol(torch.ops.aten.atan2, torch.ops.aten.atan2.default, id="torch.ops.aten.atan2") +def core_aten_atan2(a, b): + return clang.atan2(a, b) + + @torchsymbol(torch.Tensor.atan2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def atan2_(a, b, /): return prims.copy_(atan2(a, b), a) @@ -1964,6 +2302,11 @@ def bitwise_and(a, b, /): return clang.bitwise_and(a, b) +@torchsymbol(torch.ops.aten.bitwise_and.Scalar, torch.ops.aten.bitwise_and.Tensor, id="torch.ops.aten.bitwise_and") +def core_aten_bitwise_and(a, b, /): + return clang.bitwise_and(a, b) + + @torchsymbol(torch.Tensor.bitwise_and_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_and_(a, b, /): return prims.copy_(bitwise_and(a, b), a) @@ -1974,6 +2317,11 @@ def bitwise_or(a, b, /): return clang.bitwise_or(a, b) +@torchsymbol(torch.ops.aten.bitwise_or.Scalar, torch.ops.aten.bitwise_or.Tensor, id="torch.ops.aten.bitwise_or") +def core_aten_bitwise_or(a, b, /): + return clang.bitwise_or(a, b) + + @torchsymbol(torch.Tensor.bitwise_or_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_or_(a, b, /): return prims.copy_(bitwise_or(a, b), a) @@ -1984,6 +2332,11 @@ def bitwise_xor(a, b, /): return clang.bitwise_xor(a, b) +@torchsymbol(torch.ops.aten.bitwise_xor.Scalar, torch.ops.aten.bitwise_xor.Tensor, id="torch.ops.aten.bitwise_xor") +def core_aten_bitwise_xor(a, b, /): + return clang.bitwise_xor(a, b) + + @torchsymbol(torch.Tensor.bitwise_xor_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def bitwise_xor_(a, b, /): return prims.copy_(bitwise_xor(a, b), a) @@ -2004,9 +2357,7 @@ def copy_(a, b, /): return prims.copy_(b, a) -# TODO Implement div -@torchsymbol(torch.div, is_method=True) -def div( +def _div_impl( a: Number | TensorLike, b: Number | TensorLike, /, @@ -2026,6 +2377,28 @@ def div( raise ValueError(f"div does not support the rounding_mode={rounding_mode} argument") +# TODO Implement div +@torchsymbol(torch.div, is_method=True) +def div( + a: Number | TensorLike, + b: Number | TensorLike, + *, + rounding_mode: None | str = None, + out: None | TensorLike = None, +) -> Number | TensorLike: + return _div_impl(a, b, rounding_mode=rounding_mode, out=out) + + +@torchsymbol(torch.ops.aten.div.Scalar, torch.ops.aten.div.Tensor, id="torch.ops.aten.div") +def core_aten_div(a: TensorLike, b: Number | TensorLike) -> TensorLike: + return _div_impl(a, b) + + +@torchsymbol(torch.ops.aten.div.Scalar_mode, torch.ops.aten.div.Tensor_mode, id="torch.ops.aten.div_mode") +def core_aten_div_mode(a: TensorLike, b: Number | TensorLike, *, rounding_mode: None | str = None) -> TensorLike: + return _div_impl(a, b, rounding_mode=rounding_mode) + + @torchsymbol(torch.Tensor.div_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def div_( a: TensorLike, @@ -2042,6 +2415,11 @@ def eq(a, b, /): return clang.eq(a, b) +@torchsymbol(torch.ops.aten.eq.Scalar, torch.ops.aten.eq.Tensor, id="torch.ops.aten.eq") +def core_aten_eq(a, b): + return clang.eq(a, b) + + @torchsymbol(torch.Tensor.eq_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def eq_(a, b, /): return prims.copy_(eq(a, b), a) @@ -2062,6 +2440,11 @@ def fmod(a, b, /): return clang.fmod(a, b) +@torchsymbol(torch.ops.aten.fmod.Scalar, torch.ops.aten.fmod.Tensor, id="torch.ops.aten.fmod") +def core_aten_fmod(a, b): + return clang.fmod(a, b) + + @torchsymbol(torch.Tensor.fmod_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def fmod_(a, b, /): return prims.copy_(fmod(a, b), a) @@ -2072,6 +2455,11 @@ def ge(a, b, /): return clang.ge(a, b) +@torchsymbol(torch.ops.aten.ge.Scalar, torch.ops.aten.ge.Tensor, id="torch.ops.aten.ge") +def core_aten_ge(a, b): + return clang.ge(a, b) + + @torchsymbol(torch.Tensor.ge_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def ge_(a, b, /): return prims.copy_(ge(a, b), a) @@ -2082,6 +2470,11 @@ def gt(a, b, /): return clang.gt(a, b) +@torchsymbol(torch.ops.aten.gt.Scalar, torch.ops.aten.gt.Tensor, id="torch.ops.aten.gt") +def core_aten_gt(a, b): + return clang.gt(a, b) + + @torchsymbol(torch.Tensor.gt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def gt_(a, b, /): return prims.copy_(gt(a, b), a) @@ -2092,6 +2485,11 @@ def logical_and(a, b, /): return clang.logical_and(a, b) +@torchsymbol(torch.ops.aten.logical_and, torch.ops.aten.logical_and.default, id="torch.ops.aten.logical_and") +def core_aten_logical_and(a, b): + return clang.logical_and(a, b) + + @torchsymbol(torch.Tensor.logical_and_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def logical_and_(a, b, /): return prims.copy_(logical_and(a, b), a) @@ -2102,6 +2500,11 @@ def logical_not(a: TensorLike, /) -> TensorLike: return clang.logical_not(a) +@torchsymbol(torch.ops.aten.logical_not, torch.ops.aten.logical_not.default, id="torch.ops.aten.logical_not") +def core_aten_logical_not(a: TensorLike, /) -> TensorLike: + return clang.logical_not(a) + + @torchsymbol(torch.Tensor.logical_not_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def logical_not_(a: TensorLike, /) -> TensorLike: return prims.copy_(logical_not(a), a) @@ -2112,6 +2515,11 @@ def le(a, b, /): return clang.le(a, b) +@torchsymbol(torch.ops.aten.le.Scalar, torch.ops.aten.le.Tensor, id="torch.ops.aten.le") +def core_aten_le(a, b): + return clang.le(a, b) + + @torchsymbol(torch.Tensor.le_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def le_(a, b, /): return prims.copy_(le(a, b), a) @@ -2122,6 +2530,11 @@ def lt(a, b, /): return clang.lt(a, b) +@torchsymbol(torch.ops.aten.lt.Scalar, torch.ops.aten.lt.Tensor, id="torch.ops.aten.lt") +def core_aten_lt(a, b): + return clang.lt(a, b) + + @torchsymbol(torch.Tensor.lt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def lt_(a, b, /): return prims.copy_(lt(a, b), a) @@ -2132,11 +2545,21 @@ def maximum(a: TensorProxy, b: TensorProxy) -> TensorProxy: return clang.maximum(a, b) +@torchsymbol(torch.ops.aten.maximum, torch.ops.aten.maximum.default, id="torch.ops.aten.maximum") +def core_aten_maximum(a: TensorProxy, b: TensorProxy) -> TensorProxy: + return clang.maximum(a, b) + + @torchsymbol(torch.minimum, is_method=True) def minimum(a: TensorProxy, b: TensorProxy) -> TensorProxy: return clang.minimum(a, b) +@torchsymbol(torch.ops.aten.minimum, torch.ops.aten.minimum.default, id="torch.ops.aten.minimum") +def core_aten_minimum(a: TensorProxy, b: TensorProxy) -> TensorProxy: + return clang.minimum(a, b) + + # NOTE This is just an alias for proxies to find operation defined for the modulus # operator # TODO Review this alias @@ -2153,6 +2576,11 @@ def mul(a, b, /): return clang.mul(a, b) +@torchsymbol(torch.ops.aten.mul.Scalar, torch.ops.aten.mul.Tensor, id="torch.ops.aten.mul") +def core_aten_mul(a, b): + return clang.mul(a, b) + + @torchsymbol(torch.Tensor.mul_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def mul_(a, b, /): return prims.copy_(mul(a, b), a) @@ -2163,6 +2591,11 @@ def ne(a, b, /): return clang.ne(a, b) +@torchsymbol(torch.ops.aten.ne.Scalar, torch.ops.aten.ne.Tensor, id="torch.ops.aten.ne") +def core_aten_ne(a, b): + return clang.ne(a, b) + + @torchsymbol(torch.Tensor.ne_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def ne_(a, b, /): return prims.copy_(ne(a, b), a) @@ -2206,6 +2639,16 @@ def pow(a, b, /): return clang.pow(a, b) +@torchsymbol( + torch.ops.aten.pow.Scalar, + torch.ops.aten.pow.Tensor_Scalar, + torch.ops.aten.pow.Tensor_Tensor, + id="torch.ops.aten.pow", +) +def core_aten_pow(a, b): + return clang.pow(a, b) + + @torchsymbol(torch.Tensor.pow_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def pow_(a, b, /): return prims.copy_(pow(a, b), a) @@ -2216,6 +2659,11 @@ def remainder(a, b, /): return clang.remainder(a, b) +@torchsymbol(torch.ops.aten.remainder.Scalar, torch.ops.aten.remainder.Tensor, id="torch.ops.aten.remainder") +def core_aten_remainder(a, b): + return clang.remainder(a, b) + + @torchsymbol(torch.Tensor.remainder_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def remainder_(a, b, /): return prims.copy_(remainder(a, b), a) @@ -2229,6 +2677,14 @@ def sub(a, b, /, *, alpha: NumberLike | TensorLike = 1): return clang.sub(a, b) +@torchsymbol(torch.ops.aten.sub.Scalar, torch.ops.aten.sub.Tensor, id="torch.ops.aten.sub") +def core_aten_sub(a, b, /, *, alpha: NumberLike | TensorLike = 1): + if isinstance(alpha, TensorProxy) or alpha != 1: + b = b * alpha + + return clang.sub(a, b) + + @torchsymbol(torch.Tensor.sub_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def sub_(a, b, /, *, alpha: NumberLike | TensorLike = 1): return prims.copy_(sub(a, b, alpha=alpha), a) @@ -2308,8 +2764,7 @@ def lerp_(start: TensorLike, end: TensorLike, weight: Number | TensorLike) -> Te # -@torchsymbol(torch.clamp, is_method=True) -def clamp( +def _clamp_impl( a: TensorLike, /, min: None | Number | TensorLike = None, @@ -2343,6 +2798,25 @@ def clamp( return a +@torchsymbol(torch.clamp, is_method=True) +def clamp( + a: TensorLike, + /, + min: None | Number | TensorLike = None, + max: None | Number | TensorLike = None, +) -> TensorLike: + return _clamp_impl(a, min=min, max=max) + + +@torchsymbol(torch.ops.aten.clamp, torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor, id="torch.ops.aten.clamp") +def core_aten_clamp( + a: TensorLike, + min: None | Number | TensorLike = None, + max: None | Number | TensorLike = None, +) -> TensorLike: + return _clamp_impl(a, min=min, max=max) + + @torchsymbol(torch.clamp_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def clamp_( a: TensorLike, /, min: None | Number | TensorLike = None, max: None | Number | TensorLike = None @@ -2425,6 +2899,20 @@ def where( return clang.where(pred, a, b) +@torchsymbol(torch.ops.aten.where.self, id="torch.ops.aten.where") +def core_aten_where( + condition: TensorLike, + self: TensorLike, + other: TensorLike, +) -> TensorProxy: + utils.check( + isinstance(a, (Number, NumberProxy, TensorProxy)) and isinstance(b, (Number, NumberProxy, TensorProxy)), + lambda: f"torch.where() does not support only specifying a condition", + exception_type=NotImplementedError, + ) + return clang.where(pred, a, b) + + @torchsymbol(torch.nan_to_num, is_method=True) def nan_to_num( a: TensorLike, @@ -2655,15 +3143,18 @@ def amin(a, /, dim=None, keepdim: bool = False): # NOTE: Using name `torch_max` to avoid conflict with Python's `max` @overload -def torch_max(a: TensorLike, /) -> TensorLike: ... +def torch_max(a: TensorLike, /) -> TensorLike: + ... @overload -def torch_max(a: TensorLike, /, dim: NumberLike, keepdim: bool = False) -> tuple[TensorLike, TensorLike]: ... +def torch_max(a: TensorLike, /, dim: NumberLike, keepdim: bool = False) -> tuple[TensorLike, TensorLike]: + ... @overload -def torch_max(a: TensorLike, b: TensorLike, /) -> TensorLike: ... +def torch_max(a: TensorLike, b: TensorLike, /) -> TensorLike: + ... @torchsymbol(torch.max, is_method=True, method_name="max", id="torch.max") @@ -2710,6 +3201,15 @@ def clone(a: TensorProxy, *, memory_format=torch.preserve_format) -> TensorProxy return prims.clone(a) +@torchsymbol(torch.ops.aten.clone, torch.ops.aten.clone.default, id="torch.ops.aten.clone") +def core_aten_clone(a: TensorProxy, *, memory_format: None | torch.memory_format = None) -> TensorProxy: + if memory_format is None: + memory_format = torch.preserve_format + if memory_format is not torch.preserve_format: + raise NotImplementedError("only preserve_format is currently supported") + return prims.clone(a) + + # Because we do not use @torchsymbol, we need to manually register the # implementation. register_function(torch.clone, clone) @@ -2730,8 +3230,7 @@ def glu(a: TensorProxy, /, dim: int = -1) -> TensorProxy: return out -@torchsymbol(torch.mean, is_method=True) -def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy: +def _mean_impl(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy: dtype = dtype if dtype is not None else a.dtype utils.check( not utils.is_integer_dtype(dtype) and not utils.is_boolean_dtype(dtype), @@ -2740,25 +3239,103 @@ def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> T result = _reduction( a, - prims.sum, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, + ) + + dims = _reduction_dims(a.shape, dim) # type: ignore[arg-type] + nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) + result = result / nelem + result_dtype = a.dtype if dtype is None else dtype + result = to(result, result_dtype) + return result + + +@torchsymbol(torch.mean, is_method=True) +def mean(a: TensorProxy, /, dim=None, keepdim: bool = False, *, dtype=None) -> TensorProxy: + return _mean_impl(a, dim=dim, keepdim=keepdim, dtype=dtype) + + +@torchsymbol(torch.ops.aten.mean, torch.ops.aten.mean.default, id="torch.ops.aten.mean") +def core_aten_mean(a: TensorProxy, *, dtype=None) -> TensorProxy: + return _mean_impl(a, dtype=dtype) + + +@torchsymbol(torch.ops.aten.mean.dim, id="torch.ops.aten.mean.dim") +def core_aten_mean(a: TensorProxy, dim: int, keepdim: bool = False, *, dtype=None) -> TensorProxy: + return _mean_impl(a, dim=dim, keepdim=keepdim, dtype=dtype) + + +@torchsymbol(torch.prod, is_method=True) +def prod( + a: TensorProxy, /, dim: None | Sequence[int] = None, keepdim: bool = False, *, dtype: None | dtypeLike = None +) -> TensorProxy: + # Promotes all exact dtypes to int64 + if dtype is None: + if utils.is_exact_dtype(a.dtype): + dtype = dtypes.int64 + else: + dtype = a.dtype + + result = _reduction( + a, + prims.prod, + dims=dim, + keepdims=keepdim, + dtype=dtype, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + return result + + +@torchsymbol(torch.ops.aten.prod, torch.ops.aten.prod.default, id="torch.ops.aten.prod") +def core_aten_prod(a: TensorProxy, *, dtype: None | dtypeLike = None) -> TensorProxy: + # Promotes all exact dtypes to int64 + if dtype is None: + if utils.is_exact_dtype(a.dtype): + dtype = dtypes.int64 + else: + dtype = a.dtype + + result = _reduction( + a, + prims.prod, + dtype=dtype, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + return result + + +@torchsymbol(torch.ops.aten.prod.dim_int, id="torch.ops.aten.prod.dim_int") +def core_aten_prod(a: TensorProxy, dim: int, keepdim: bool = False, *, dtype: None | dtypeLike = None) -> TensorProxy: + # Promotes all exact dtypes to int64 + if dtype is None: + if utils.is_exact_dtype(a.dtype): + dtype = dtypes.int64 + else: + dtype = a.dtype + + result = _reduction( + a, + prims.prod, dims=dim, keepdims=keepdim, dtype=dtype, - output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, ) - dims = _reduction_dims(a.shape, dim) # type: ignore[arg-type] - nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) - result = result / nelem - result_dtype = a.dtype if dtype is None else dtype - result = to(result, result_dtype) return result -@torchsymbol(torch.prod, is_method=True) -def prod( - a: TensorProxy, /, dim: None | Sequence[int] = None, keepdim: bool = False, *, dtype: None | dtypeLike = None -) -> TensorProxy: +@torchsymbol(torch.sum, is_method=True) +def sum( + a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, dtype: None | dtypeLike = None +) -> TensorLike: # Promotes all exact dtypes to int64 if dtype is None: if utils.is_exact_dtype(a.dtype): @@ -2768,7 +3345,7 @@ def prod( result = _reduction( a, - prims.prod, + prims.sum, dims=dim, keepdims=keepdim, dtype=dtype, @@ -2778,10 +3355,8 @@ def prod( return result -@torchsymbol(torch.sum, is_method=True) -def sum( - a: TensorLike, /, dim: None | int | Sequence[int] = None, keepdim: bool = False, *, dtype: None | dtypeLike = None -) -> TensorLike: +@torchsymbol(torch.ops.aten.sum.dim_IntList, id="torch.ops.aten.sum.dim_IntList") +def core_aten_sum(a: TensorProxy, dim: int, keepdim: bool = False, *, dtype: dtypeLike | None = None) -> TensorProxy: # Promotes all exact dtypes to int64 if dtype is None: if utils.is_exact_dtype(a.dtype): @@ -2824,9 +3399,7 @@ def _sum_grad( register_grad(sum, _sum_grad) -# NOTE This decomposition can not be efficiently fused, so make it primitive -@torchsymbol(torch.cumsum, is_method=True, is_prim=True) -def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: +def _cumsum_impl(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: # check the input dimension utils.canonicalize_dim(a.ndim, dim) if dtype is None: @@ -2835,6 +3408,17 @@ def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> Tensor return TensorProxy(like=a, dtype=to_dtype(dtype)) +# NOTE This decomposition can not be efficiently fused, so make it primitive +@torchsymbol(torch.cumsum, is_method=True, is_prim=True) +def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: + return _cumsum_impl(a, dim, dtype=dtype) + + +@torchsymbol(torch.ops.aten.cumsum, torch.ops.aten.cumsum.default, is_prim=True, id="torch.ops.aten.cumsum") +def core_aten_cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: + return _cumsum_impl(a, dim, dtype=dtype) + + @torchsymbol(torch.Tensor.cumsum_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def cumsum_(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: return prims.copy_(cumsum(a, dim, dtype=dtype), a) @@ -2861,6 +3445,45 @@ def var( return result +@torchsymbol(torch.ops.aten.var.correction, id="torch.ops.aten.var.correction") +def core_aten_var_correction( + a: TensorProxy, + dim: int | None = None, + *, + correction: NumberLike | None = None, + keepdim: bool = False, +) -> TensorProxy: + result = _reduction( + a, + partial(prims.var, correction=correction if correction is not None else -1), + dims=dim, + keepdims=keepdim, + dtype=None, + has_identity=True, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ) + return result + + +@torchsymbol(torch.ops.aten.var.dim, id="torch.ops.aten.var.dim") +def core_aten_var_dim( + a: TensorProxy, + dim: int | None = None, + unbiased: bool = True, + keepdim: bool = False, +) -> TensorProxy: + result = _reduction( + a, + partial(prims.var, correction=1 if unbiased else 0), + dims=dim, + keepdims=keepdim, + dtype=None, + has_identity=True, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ) + return result + + @torchsymbol(torch.var_mean, tags=(prims.OpTags.REDUCTION_OP,)) def var_mean( a: TensorProxy, @@ -2899,6 +3522,13 @@ def topk( return clang.topk(a, k, dim, largest, sorted, out=out) +@torchsymbol(torch.ops.aten.topk, torch.ops.aten.topk.default, id="torch.ops.aten.topk") +def core_aten_topk( + a: TensorLike, /, k: int, dim: int = -1, largest: bool = True, sorted: bool = True +) -> (TensorLike, TensorLike): + return clang.topk(a, k, dim, largest, sorted) + + @torchsymbol(torch.sort, is_method=True) def sort( a: TensorLike, /, dim: None | int = None, descending: bool = False, stable: bool = False, *, out=None @@ -2906,6 +3536,11 @@ def sort( return clang.sort(a, dim, descending, stable, out=out) +@torchsymbol(torch.ops.aten.sort, torch.ops.aten.sort.default, id="torch.ops.aten.sort") +def core_aten_sort(a: TensorLike, /, dim: int = -1, descending: bool = False) -> (TensorLike, TensorLike): + return clang.sort(a, dim, descending) + + # # Scatter and gather-related operations # @@ -2937,15 +3572,25 @@ def index_select(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike: return clang.take(a, index, dim) +@torchsymbol(torch.ops.aten.index_select, torch.ops.aten.index_select.default, id="torch.ops.aten.index_select") +def core_aten_index_select(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike: + return clang.take(a, index, dim) + + @torchsymbol(torch.gather, is_method=True) def gather(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike: return clang.gather(a, indices=index, dim=dim) +@torchsymbol(torch.ops.aten.gather, torch.ops.aten.gather.default, id="torch.ops.aten.gather") +def core_aten_gather(a: TensorLike, /, dim: int, index: TensorLike, *, sparse_grad: bool = False) -> TensorLike: + utils.check(not sparse_grad, lambda: f"{sparse_grad=} is not supported", exception_type=NotImplementedError) + return clang.gather(a, indices=index, dim=dim) + + # NOTE: PyTorch uses `src` for torch.Tensor arguments and `value` for scalars # when referencing the source of the values -@torchsymbol(torch.scatter, is_method=True) -def scatter( +def _scatter_impl( a: TensorLike, /, dim: int, @@ -2970,6 +3615,40 @@ def scatter( return clang.scatter(a, index, value, dim) +@torchsymbol(torch.scatter, is_method=True) +def scatter( + a: TensorLike, + /, + dim: int, + index: TensorLike, + src: TensorLike | None = None, + *, + value: None | Number = None, + reduce: None | str = None, +) -> TensorLike: + return _scatter_impl(a, dim, index, src, value=value, reduce=reduce) + + +@torchsymbol(torch.ops.aten.scatter.src, id="torch.ops.aten.scatter.src") +def core_aten_scatter_tensor( + a: TensorProxy, + dim: int, + index: TensorProxy, + src: TensorProxy, +) -> TensorLike: + return _scatter_impl(a, dim, index, src) + + +@torchsymbol(torch.ops.aten.scatter.value, id="torch.ops.aten.scatter.value") +def core_aten_scatter_value( + a: TensorProxy, + dim: int, + index: TensorProxy, + value: Number, +) -> TensorLIke: + return _scatter_impl(a, dim, index, value=value) + + @torchsymbol(torch.Tensor.scatter_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def scatter_( a: TensorLike, @@ -3002,6 +3681,11 @@ def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) return clang.scatter_add(a, indices=index, value=src, dim=dim) +@torchsymbol(torch.ops.aten.scatter_add, torch.ops.aten.scatter_add.default, id="torch.ops.aten.scatter_add") +def core_aten_scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike: + return clang.scatter_add(a, indices=index, value=src, dim=dim) + + @torchsymbol(torch.Tensor.scatter_add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def scatter_add_(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike: return prims.copy_(scatter_add(a, dim, index, src), a) @@ -3019,6 +3703,13 @@ def index_put( return clang.index_put(a, indices, values, accumulate) +@torchsymbol(torch.ops.aten.index_put, torch.ops.aten.index_put.default, id="torch.ops.aten.index_put") +def core_aten_index_put( + a: TensorLike, /, indices: Sequence[TensorLike], values: TensorLike, accumulate: bool = False +) -> TensorLike: + return clang.index_put(a, indices, values, accumulate) + + @torchsymbol(torch.index_put_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) def index_put_( a: TensorLike, @@ -3725,6 +4416,19 @@ def layer_norm( return _native_layer_norm(a, normalized_shape, weight, bias, eps)[0] +@torchsymbol( + torch.ops.aten.native_layer_norm, torch.ops.aten.native_layer_norm.default, id="torch.ops.aten.native_layer_norm" +) +def core_aten_layer_norm( + a: TensorLike, + normalized_shape: Sequence[int], + weight: TensorLike | None, + bias: TensorLike | None, + eps: NumberLike, +) -> tuple[TensorProxy, TensorProxy, TensorProxy]: + return _native_layer_norm(a, normalized_shape, weight, bias, eps) + + def rms_norm( a: TensorLike, /, @@ -3899,8 +4603,12 @@ def bmm(a: TensorLike, b: TensorLike, /) -> TensorLike: return matmul(a, b) -@torchsymbol(torch.convolution, is_method=False) -def convolution( +@torchsymbol(torch.ops.aten.bmm, torch.ops.aten.bmm.default, id="torch.ops.aten.bmm") +def core_aten_bmm(a: TensorLike, b: TensorLike, /) -> TensorLike: + return matmul(a, b) + + +def _convolution_impl( a: TensorLike, weight: TensorLike, bias: None | TensorLike, @@ -3930,6 +4638,36 @@ def convolution( ) +@torchsymbol(torch.convolution, is_method=False) +def convolution( + a: TensorLike, + weight: TensorLike, + bias: None | TensorLike, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: int, +) -> TensorLike: + return _convolution_impl(a, weight, bias, stride, padding, dilation, transposed, output_padding, groups) + + +@torchsymbol(torch.ops.aten.convolution, torch.ops.aten.convolution.default, id="torch.ops.aten.convolution") +def core_aten_convolution( + a: TensorLike, + weight: TensorLike, + bias: None | TensorLike, + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: int, +) -> TensorLike: + return _convolution_impl(a, weight, bias, stride, padding, dilation, transposed, output_padding, groups) + + # Helper functions that are useful for "window"-based ops # like convolution, pooling and similar. { @@ -4242,6 +4980,18 @@ def avg_pool1d( return _avg_pool_helper(1, a, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) +@torchsymbol(torch.ops.aten.avg_pool1d, torch.ops.aten.avg_pool1d.default, id="torch.ops.aten.avg_pool1d") +def core_aten_avg_pool1d( + a: TensorProxy, + kernel_size: int | Sequence[int], + stride: Sequence[int] = [], + padding: int = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, +) -> TensorProxy: + return _avg_pool_helper(1, a, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + @torchsymbol(torch.nn.functional.avg_pool2d, id="torch.nn.functional.avg_pool2d", is_method=False) def avg_pool2d( a: TensorProxy, @@ -4256,6 +5006,19 @@ def avg_pool2d( return _avg_pool_helper(2, a, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) +@torchsymbol(torch.ops.aten.avg_pool2d, torch.ops.aten.avg_pool2d.default, id="torch.ops.aten.avg_pool2d") +def core_aten_avg_pool2d( + a: TensorProxy, + kernel_size: int | Sequence[int], + stride: int | Sequence[int] | None = None, + padding: int | Sequence[int] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: NumberLike | None = None, +) -> TensorProxy: + return _avg_pool_helper(2, a, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + @torchsymbol(torch.nn.functional.avg_pool3d, id="torch.nn.functional.avg_pool3d", is_method=False) def avg_pool3d( a: TensorProxy, @@ -4270,6 +5033,20 @@ def avg_pool3d( return _avg_pool_helper(3, a, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) +@torchsymbol(torch.ops.aten.avg_pool3d, torch.ops.aten.avg_pool3d.default, id="torch.ops.aten.avg_pool3d") +def core_aten_avg_pool3d( + a: TensorProxy, + /, + kernel_size: int | Sequence[int], + stride: int | Sequence[int] | None = None, + padding: int | Sequence[int] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: NumberLike | None = None, +) -> TensorProxy: + return _avg_pool_helper(3, a, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + @torchsymbol( torch.nn.functional.adaptive_avg_pool2d, id="torch.nn.functional.adaptive_avg_pool2d", is_method=False, is_prim=True ) @@ -4360,6 +5137,22 @@ def max_pool2d( return _max_pool_helper(2, a, kernel_size, stride, padding, dilation, return_indices, ceil_mode) +@torchsymbol( + torch.ops.aten.max_pool2d_with_indices, + torch.ops.aten.max_pool2d_with_indices.default, + id="torch.ops.aten.max_pool2d_with_indices", +) +def core_aten_max_pool2d_with_indices( + a: TensorProxy, + kernel_size: int | Sequence[int], + stride: int | Sequence[int] | None = None, + padding: int | Sequence[int] = 0, + dilation: int | Sequence[int] = 1, + ceil_mode: bool = False, +) -> tuple[TensorProxy, TensorProxy]: + return _max_pool_helper(2, a, kernel_size, stride, padding, dilation, True, ceil_mode) + + @torchsymbol(torch.max_pool3d, torch.nn.functional.max_pool3d, id="torch.nn.functional.max_pool3d", is_method=False) def max_pool3d( a: TensorProxy, @@ -4374,6 +5167,22 @@ def max_pool3d( return _max_pool_helper(3, a, kernel_size, stride, padding, dilation, return_indices, ceil_mode) +@torchsymbol( + torch.ops.aten.max_pool3d_with_indices, + torch.ops.aten.max_pool3d_with_indices.default, + id="torch.ops.aten.max_pool3d_with_indices", +) +def core_aten_max_pool3d_with_indices( + a: TensorProxy, + kernel_size: int | Sequence[int], + stride: int | Sequence[int] | None = None, + padding: int | Sequence[int] = 0, + dilation: int | Sequence[int] = 1, + ceil_mode: bool = False, +) -> tuple[TensorProxy, TensorProxy]: + return _max_pool_helper(3, a, kernel_size, stride, padding, dilation, True, ceil_mode) + + @torchsymbol(torch.conv1d, torch.nn.functional.conv1d, id="torch.nn.functional.conv1d", is_method=False) def conv1d( a: TensorProxy, @@ -4668,8 +5477,30 @@ def dropout(a: TensorProxy, /, p: NumberLike = 0.5, training: bool = True, inpla _inplace_to_out_of_place[dropout] = dropout, 3 -@torchsymbol(torch.nn.functional.embedding, id="torch.nn.functional.embedding") -def embedding( +@torchsymbol(torch.ops.aten.native_dropout, torch.ops.aten.native_dropout.default, id="torch.ops.aten.native_dropout") +def core_aten_dropout(a: TensorProxy, p: NumberLike, train: bool) -> tuple[TensorProxy, TensorProxy]: + if not train: + return a + + utils.check( + p <= 1 and p >= 0, + lambda: f"Dropout probability has to be between 0 and 1, but got, {p}", + ) + + dropout_mask = _dropout_helper(a, 1 - p) + if p == 1: + return zeros_like(a), dropout_mask + + if p == 0: + return a, dropout_mask + + scale = 1 / (1 - p) + + out = a * dropout_mask * scale + return out, dropout_mask + + +def _embedding_impl( a: TensorLike, /, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False ) -> TensorLike: # TODO: add embedding_renorm_ so we can remove embedding prim @@ -4697,7 +5528,35 @@ def embedding( return reshape(flatten_output, output_shape) -@torchsymbol(torch.ops.aten.embedding_backward) +@torchsymbol(torch.nn.functional.embedding, id="torch.nn.functional.embedding") +def embedding( + a: TensorLike, /, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False +) -> TensorLike: + return _embedding_impl( + a, + weight, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + +@torchsymbol(torch.ops.aten.embedding, torch.ops.aten.embedding.default, id="torch.ops.aten.embedding") +def core_aten_embedding( + weight: TensorLike, + indices: TensorLike, + padding_idx: int = -1, + scale_grad_by_freq: bool = False, + sparse: bool = False, +) -> TensorLike: + return _embedding_impl( + indices, weight, padding_idx=padding_idx, scale_grad_by_freq=scale_grad_by_freq, sparse=sparse + ) + + +@torchsymbol(torch.ops.aten.embedding_backward, torch.ops.aten.embedding_dense_backward) def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse): result = prims.embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse) return result @@ -4721,15 +5580,17 @@ def one_hot(a: TensorLike, /, num_classes: int) -> TensorLike: return scatter_add(canvas, dim=-1, index=index, src=src) -@torchsymbol(torch.group_norm, torch.nn.functional.group_norm, id="torch.nn.functional.group_norm", is_method=False) -def group_norm( +# @torchsymbol(torch.group_norm, torch.nn.functional.group_norm, id="torch.nn.functional.group_norm", is_method=False) +# def group_norm( +def _group_norm_impl( a: TensorProxy, /, num_groups: int, weight: None | TensorProxy = None, bias: None | TensorProxy = None, eps: float = 1e-5, -) -> TensorProxy: + return_stats: bool = False, +) -> tuple[TensorProxy, TensorProxy, TensorProxy]: utils.check(a.ndim >= 2, lambda: f"group_norm: {a.ndim=} should be at least 2") batch_size, num_channels, *inner_dims = a.shape @@ -4757,7 +5618,7 @@ def group_norm( # Perform Normalization (yes, subtract mean, divide by sd) over all the dims # but the batch and the group dim. - res, *_ = _normalize(a_groupped, norm_dims=range(2, a_groupped.ndim), eps=eps) + res, mean, rstd = _normalize(a_groupped, norm_dims=range(2, a_groupped.ndim), eps=eps) # Restore the channel dimension res = view(res, a.shape) @@ -4773,7 +5634,38 @@ def group_norm( res = res + bias res = to(res, a.dtype) - return res + if return_stats: + return res, mean, rstd + else: + return res + + +@torchsymbol(torch.group_norm, torch.nn.functional.group_norm, id="torch.nn.functional.group_norm", is_method=False) +def group_norm( + a: TensorProxy, + /, + num_groups: int, + weight: None | TensorProxy = None, + bias: None | TensorProxy = None, + eps: float = 1e-5, +) -> TensorProxy: + return _group_norm_impl(a, num_groups, weight, bias, eps, return_stats=False) + + +@torchsymbol( + torch.ops.aten.native_group_norm, torch.ops.aten.native_group_norm.default, id="torch.ops.aten.native_group_norm" +) +def core_aten_group_norm( + a: TensorProxy, + weight: TensorProxy | None, + bias: TensorProxy | None, + N: int, + C: int, + HxW: int, + group: int, + eps: float, +) -> tuple[TensorProxy, TensorProxy, TensorProxy]: + return _group_norm_impl(a, num_groups, weight, bias, eps, return_stats=True) def _interpolate_scale_factor_helper( @@ -5232,6 +6124,11 @@ def sigmoid(a: TensorLike, /) -> TensorLike: return clang.sigmoid(a) +@torchsymbol(torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default, id="torch.ops.aten.sigmoid") +def core_aten_sigmoid(a: TensorLike) -> TensorLike: + return clang.sigmoid(a) + + # CompositeImplicitAutograd - don't register decomp @torchsymbol(torch.softmax, torch.nn.functional.softmax, is_method=True, id="torch.softmax") def _softmax( From db3a6c998598fb7a2a5464d8c5c0d3527035fc2c Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 23 Dec 2024 22:29:15 +0900 Subject: [PATCH 2/3] trace transform of tensor wrapper subclass to support `__torch_dispatch__`. Since it extends the behavior that is implemented in C++ level, we'd need to apply the transform to split forward and backward traces separately. Signed-off-by: Masaki Kozuki --- docs/source/reference/transforms/index.rst | 1 + thunder/__init__.py | 9 +- thunder/core/prims.py | 168 +++- thunder/core/proxies.py | 38 +- thunder/executors/torch_autograd.py | 7 + thunder/executors/torchex.py | 48 +- thunder/tests/test_tensor_subclass.py | 52 ++ thunder/transforms/__init__.py | 2 + thunder/transforms/tensor_wrapper_subclass.py | 739 ++++++++++++++++++ 9 files changed, 1056 insertions(+), 8 deletions(-) create mode 100644 thunder/transforms/tensor_wrapper_subclass.py diff --git a/docs/source/reference/transforms/index.rst b/docs/source/reference/transforms/index.rst index 8711275e14..be3db1c73f 100644 --- a/docs/source/reference/transforms/index.rst +++ b/docs/source/reference/transforms/index.rst @@ -8,3 +8,4 @@ thunder.transforms MaterializationTransform ConstantFolding + unroll_tensor_subclasses diff --git a/thunder/__init__.py b/thunder/__init__.py index 7a04a43fab..5c915cd714 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -73,6 +73,7 @@ from thunder.core.interpreter import print_interpreter_log, print_to_log from thunder.core.jit_ext import thunder_general_jit from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction +from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this import torch as pytorch @@ -369,7 +370,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]: data_ptr_to_tensor_group_index = {} tensor_group_index_to_tensor_indices = defaultdict(list) for idx, t in enumerate(flat_args): - if pytorch.is_tensor(t) and t.layout == pytorch.strided: + if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided: data_ptr = t.untyped_storage().data_ptr() if data_ptr not in data_ptr_to_tensor_group_index: data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) @@ -616,6 +617,7 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = dce(computation_trc) computation_traces.append(computation_trc) + _unroll_tensor_subclasses_applied = False backward_trc = None if not cd.disable_torch_autograd_support: tensor_cls = (pytorch.Tensor, TensorProxy) @@ -626,10 +628,15 @@ def get_computation_and_inputs(*args, **kwargs): # transform_for_execution and various sorting of symbols, # applying transform_for_execution after this would be # breaking the order of operations + _unroll_tensor_subclasses_applied = True computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward + if not _unroll_tensor_subclasses_applied: + computation_trc = unroll_tensor_subclasses(computation_trc) + computation_traces.append(computation_trc) + if backward_trc is None: from thunder.executors.passes import transform_for_execution as transform_for_execution_pass from thunder.executors.passes import _transform_for_operator_executor_execution diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 90b062eeb4..9b718317f7 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -280,6 +280,8 @@ class PrimIDs(Enum): SINK = auto() # Tensor Subclasses methods TENSOR_SUBCLASS_CTOR = auto() + FLATTEN_TENSOR_SUBCLASS = auto() + UNFLATTEN_TENSOR_SUBCLASS = auto() class OpTags(Enum): @@ -4098,7 +4100,7 @@ def check_types(coll): return tuple(types_set) -def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]: +def filter_types_for_tensor_wrapper_subclass(types: tuple[Any, ...]) -> tuple[Any, ...]: return tuple( filter( lambda t: ( @@ -4170,7 +4172,7 @@ def printer_of_tensor_subclass_ctor( filtered_types = (cls,) if non_tensors: types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors]) - filtered_types += filter_types(types) + filtered_types += filter_types_for_tensor_wrapper_subclass(types) new_imports = {t.__name__: t for t in filtered_types} bsym._import_ctx.update(new_imports) return s @@ -4183,7 +4185,7 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: filtered_types: tuple[Any, ...] = (cls,) if non_tensors: types = get_nested_types(non_tensors) - filtered_types += filter_types(types) + filtered_types += filter_types_for_tensor_wrapper_subclass(types) new_imports = {t.__name__: t for t in filtered_types} bsym._import_ctx.update(new_imports) @@ -4195,3 +4197,163 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: python_printer=printer_of_tensor_subclass_ctor, _bind_postprocess=bind_postprocess_of_tensor_subclass_ctor, ) + + +def printer_of_tensor_subclass_flatten( + bsym: BoundSymbol, + out_printables: Any, + arg_printables: Sequence[Printable], + kwarg_printables: dict[str, Printable], +) -> str | Iterable[str]: + from itertools import chain + + arg_str = ( + "" + if (arg_printables is None or len(arg_printables) == 0) + else ", ".join(codeutils.prettyprint(x) for x in arg_printables) + ) + + result_str: str + if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0): + result_str = "" + else: + result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = " + + # Creates a comment describing the output + comment_str = "" + if isinstance(bsym.output, Proxy): + comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}" + + s = f"{result_str}{arg_str}.__tensor_flatten__(){comment_str}" + + if bsym.header: + header_lines = ( + bsym.header + if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str) + else bsym.header.splitlines() + ) + header_lines = (f"# {line}" for line in header_lines) + return chain(header_lines, [s]) + + return s + + +# NOTE(crcrpar): The behavior is different from PyTorch `subclass_tensor.__tensor_flatten__()` +# that returns a list of tensor attr names and a dict of const metadata. In Thunder traces, +# const values could be obviated and actual tensor proxies would be more useful +# than tensor attr names. +def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]: + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = tuple(getattr(t, name) for name in tensor_attr_names) + return tensors + + +flatten_tensor_subclass = make_prim( + PrimIDs.FLATTEN_TENSOR_SUBCLASS, + "flatten_tensor_subclass", + meta=flatten_tensor_subclass_meta, + python_printer=printer_of_tensor_subclass_flatten, +) + + +def printer_of_unflatten_tensor_subclass( + bsym: BoundSymbol, + out_printables: Any, + arg_printables: Sequence[Printable], + kwarg_printables: dict[str, Printable], +) -> str | Iterable[str]: + from itertools import chain + + wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0] + if isinstance(wrapped_cls, torch._C._TensorMeta): + cls = wrapped_cls + else: + cls: torch._C._TensorMeta = wrapped_cls.obj + + arg_str = ( + "" + if (arg_printables is None or len(arg_printables) == 0) + else ", ".join(codeutils.prettyprint(x) for x in arg_printables[1:]) + ) + kwarg_str: str + + if len(kwarg_printables) == 0: + kwarg_str = "" + else: + kwarg_str = ", ".join(f"{k}={codeutils.prettyprint(v)}" for k, v in kwarg_printables.items()) + + result_str: str + if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0): + result_str = "" + else: + result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = " + + # Creates a comment describing the output + comment_str = "" + if isinstance(bsym.output, Proxy): + comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}" + + s = f"{result_str}{cls.__name__}.__tensor_unflatten__({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}" + + if bsym.header: + header_lines = ( + bsym.header + if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str) + else bsym.header.splitlines() + ) + header_lines = (f"# {line}" for line in header_lines) + return chain(header_lines, [s]) + + return s + + +def bind_postprocess_of_unflatten_tensor_subclass(bsym: BoundSymbol) -> None: + cls = bsym.args[0] + inner_tensors = bsym.args[1] + metadata = bsym.args[2] + + filtered_types: tuple[Any, ...] = (cls,) + if metadata: + types = get_nested_types(list(metadata.values())) + filtered_types += filter_types_for_tensor_wrapper_subclass(types) + new_imports = {t.__name__: t for t in filtered_types} + bsym._import_ctx.update(new_imports) + + +def unflatten_tensor_subclass_meta( + tensor_subclass_type, + inner_tensors: dict[str, TensorProxy], + metadata: dict[str, Any], +) -> SubclassTensorProxy: + first_tensor: TensorProxy = list(inner_tensors.values())[0] + a = SubclassTensorProxy( + shape=first_tensor.shape, + device=first_tensor.device, + dtype=first_tensor.dtype, + requires_grad=first_tensor.requires_grad, + tensors=list(inner_tensors.values()), + non_tensors=list(metadata.values()), + subclass_type=tensor_subclass_type, + ) + for name, value in inner_tensors.items(): + setattr(a, name, value) + for name, value in metadata.items(): + setattr(a, name, value) + return a + + +def unflatten_tensor_subclass_python_impl( + tensor_subclass_type, + inner_tensors: dict[str, TensorProxy], + metadata: dict[str, Any], +) -> torch.Tensor: + return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1) + + +unflatten_tensor_subclass = make_prim( + PrimIDs.UNFLATTEN_TENSOR_SUBCLASS, + "unflatten_tensor_subclass", + meta=unflatten_tensor_subclass_meta, + python_printer=printer_of_unflatten_tensor_subclass, + _bind_postprocess=bind_postprocess_of_unflatten_tensor_subclass, +) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index c6c8865675..439c9f95b6 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -2111,6 +2111,7 @@ def __setattr__(self, name, value): # TODO: move this function to jit_ext.py def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy: + from torch._subclasses.fake_tensor import FakeTensor from thunder.core.interpreter import ProvenanceRecord, PseudoInst, wrap_const if hasattr(t, "_thunder_device"): @@ -2145,8 +2146,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = else: # NOTE Without tuple(t.shape) then the shape would be a torch.Size object shape = tuple(t.shape) - return TensorProxy( - name, + ctor_kwargs = dict( + name=name, shape=tuple(shape), device=device, dtype=dtype, @@ -2156,6 +2157,39 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = history=history, thunder_fsdp_padding_size=_thunder_fsdp_padding_size, ) + # n.b.(crcrpar): :class:`thunder.dynamo.ThunderCompiler.__call__` takes torch.fx GraphModule + # where `FakeTensor` seems to be used, leading to failures observed in e.g. + # https://github.com/Lightning-AI/lightning-thunder/actions/runs/11689709564/job/32553053319#step:10:5747 + # https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=219328&view=logs&jobId=5b0799f7-725e-5b16-9b83-c0a5a25d03f0&j=5b0799f7-725e-5b16-9b83-c0a5a25d03f0 + if ( + isinstance(t, torch.Tensor) + and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor) + and hasattr(t, "__tensor_flatten__") + and hasattr(t, "__tensor_unflatten__") + ): + baseutils.check( + hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"), + lambda: f"{t=} seems to be a tensor subclass but not traceable", + ) + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = [tensorproxy(getattr(t, name), name=None, history=history) for name in tensor_attr_names] + ctor_kwargs.update( + { + "tensors": tensors, + "non_tensors": list(metadata.values()), + "subclass_type": type(t), + } + ) + p = SubclassTensorProxy(**ctor_kwargs) + p._tensor_attr_names = tensor_attr_names + p._non_tensor_attr_names = list(metadata.keys()) + for name, tensor in zip(tensor_attr_names, tensors): + setattr(p, name, tensor) + for name, value in metadata.items(): + setattr(p, name, value) + return p + else: + return TensorProxy(**ctor_kwargs) def futuretensorproxy( diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index ce9497125b..abf7a524bf 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops from thunder.executors.passes import del_last_used, transform_for_execution + from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -158,6 +159,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat fw_traces = [fw_trace] bw_traces = [bw_trace] + fw_trace = unroll_tensor_subclasses(fw_trace) + fw_traces.append(fw_trace) + from thunder.distributed import FSDPType # only enable rematerialize_params_in_backward when using FSDP ZeRO3 @@ -262,6 +266,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if getattr(compile_data.fn, "use_fsdp", False): bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) + bw_trace = unroll_tensor_subclasses(bw_trace) + bw_traces.append(bw_trace) + # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization bw_extrace = transform_for_execution( diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 826d9d8a01..e17c6f5841 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2235,13 +2235,13 @@ def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensor def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: - from thunder.core.prims import get_nested_types, filter_types + from thunder.core.prims import get_nested_types, filter_types_for_tensor_wrapper_subclass cls, _name, _shape, _device, _dtype, _requires_grad, _tensors, non_tensors = bsym.args filtered_types = (cls,) if non_tensors: types = get_nested_types(non_tensors) - filtered_types += filter_types(types) + filtered_types += filter_types_for_tensor_wrapper_subclass(types) new_imports = {t.__name__: t for t in filtered_types} bsym._import_ctx.update(new_imports) @@ -2254,3 +2254,47 @@ def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: python_printer=prims.printer_of_tensor_subclass_ctor, ) _register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable) + + +def flatten_tensor_subclass_impl(t): + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = tuple(getattr(t, name) for name in tensor_attr_names) + return tensors + + +flatten_tensor_subclass = ex.register_operator( + "flatten_tensor_subclass", + meta=prims.flatten_tensor_subclass.meta, + fn=flatten_tensor_subclass_impl, +) +_register_implementation( + prims.flatten_tensor_subclass, + flatten_tensor_subclass, + checker=_always_executable, +) + + +def unflatten_tensor_subclass_impl( + tensor_subclass_type: torch._C._TensorMeta, + inner_tensors: dict[str, TensorLike], + metadata: dict, +): + for key in metadata: + v = metadata[key] + if isinstance(v, dtypes.dtype): + metadata[key] = to_torch_dtype(v) + elif isinstance(v, devices.Device): + metadata[key] = to_torch_device(v) + return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1) + + +unflatten_tensor_subclass = ex.register_operator( + "unflatten_tensor_subclass", + meta=prims.unflatten_tensor_subclass.meta, + fn=unflatten_tensor_subclass_impl, +) +_register_implementation( + prims.unflatten_tensor_subclass, + unflatten_tensor_subclass, + checker=_always_executable, +) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index 7e3bad2460..da5cf1b8ac 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -180,3 +180,55 @@ def g(x: torch.Tensor) -> ScaleTensorSubclass: expected = g(x) actual = jitted(x) torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), + decorators=(pytest.mark.parametrize("requires_grad", (False, True), ids=("fwd_only", "with_bwd")),), +) +def test_func_of_subclass_simple_math(executor, device, _, requires_grad): + + def f(x: ScaleTensorSubclass, y: ScaleTensorSubclass) -> torch.Tensor: + out = x + y + return out + + jitted = executor.make_callable(f) + + dtype = torch.float32 + shape = (2, 2) + x = ScaleTensorSubclass( + make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad), + make_tensor((), device=device, dtype=dtype), + ) + y = ScaleTensorSubclass( + make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad), + make_tensor((), device=device, dtype=dtype), + ) + + expected = f(x, y) + actual = jitted(x, y) + assert type(expected) is type(actual) + torch.testing.assert_close(expected, actual) + if requires_grad: + actual.mean().backward() + + def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + y = EncapsulateXandScale.apply(data, scale) + out = x + y + return out + + jitted = executor.make_callable(g) + + x = ScaleTensorSubclass( + make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad), + make_tensor((), device=device, dtype=dtype), + ) + data = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad) + scale = make_tensor((), device=device, dtype=dtype) + + expected = g(x, data, scale) + actual = jitted(x, data, scale) + assert type(expected) is type(actual) + torch.testing.assert_close(expected, actual) + if requires_grad: + actual.mean().backward() diff --git a/thunder/transforms/__init__.py b/thunder/transforms/__init__.py index 2ae556c2a0..ffaf8b3652 100644 --- a/thunder/transforms/__init__.py +++ b/thunder/transforms/__init__.py @@ -1,10 +1,12 @@ from .constant_folding import ConstantFolding from .materialization import MaterializationTransform from .qlora import LORATransform +from .tensor_wrapper_subclass import unroll_tensor_subclasses __all__ = [ "ConstantFolding", "LORATransform", "MaterializationTransform", + "unroll_tensor_subclasses", ] diff --git a/thunder/transforms/tensor_wrapper_subclass.py b/thunder/transforms/tensor_wrapper_subclass.py new file mode 100644 index 0000000000..1379a4bd4c --- /dev/null +++ b/thunder/transforms/tensor_wrapper_subclass.py @@ -0,0 +1,739 @@ +from __future__ import annotations +from dataclasses import dataclass +from dataclasses import field +from numbers import Number +from typing import TYPE_CHECKING, NamedTuple +import time +import warnings + +import torch +from torch.fx import Node +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.experimental.proxy_tensor import make_fx +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from thunder.core.baseutils import run_once +from thunder.core.codeutils import SigInfo +from thunder.core import devices +from thunder.core import dtypes +from thunder.core import prims +from thunder.core import utils +from thunder.core.proxies import ProxyInterface +from thunder.core.proxies import SubclassTensorProxy +from thunder.core.proxies import TensorProxy +from thunder.core.proxies import Variable +from thunder.core.proxies import variableify +from thunder.core.pytree import tree_flatten +from thunder.core.pytree import tree_map +from thunder.core.pytree import tree_unflatten +from thunder.core.trace import TraceCtx +from thunder.core.trace import TraceProvenance +from thunder.core.trace import from_trace +from thunder.core.trace import tracectx + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Any + from optree import PyTreeSpec + from torch.fx import GraphModule + from torch._ops import OpOverload + from thunder.core.symbol import Symbol, BoundSymbol + + +__all__ = [ + "unroll_tensor_subclasses", +] + + +PLACEHOLDER: str = "placeholder" +CALL_FUNCTION: str = "call_function" +OUTPUT: str = "output" + + +@run_once +def warn_tensor_subclass_support() -> None: + warnings.warn("Tensor Subclasses with `__torch_dispatch__` defined support is experimental") + + +class OutputWrapperForFxTracing(NamedTuple): + inner_tensors: dict[str, torch.Tensor] | torch.Tensor + metadata: dict[str, Any] | None + + +def _materialize_tensor_proxy(t: TensorProxy, fake_tensor_mode: FakeTensorMode | None) -> torch.Tensor: + shape = t.shape + device = devices.to_torch_device(t.device) + dtype = dtypes.to_torch_dtype(t.dtype) + requires_grad = t.requires_grad + + with torch.device("meta"): + t = torch.empty(shape, dtype=dtype, requires_grad=requires_grad) + if fake_tensor_mode is None: + return t + fakified_empty_tensor = fake_tensor_mode.fake_tensor_converter.from_meta_and_device( + fake_mode=fake_tensor_mode, t=t, device=device + ) + return fakified_empty_tensor + + +def _make_fake_subclass_tensor_from_subclass_tensor_proxy( + tensor_proxy: SubclassTensorProxy, + fake_tensor_mode: FakeTensorMode, +) -> torch.Tensor: + utils.check( + (subclass_type := getattr(tensor_proxy, SubclassTensorProxy.SUBCLASS_TYPE_ATTR, None)) is not None, + lambda: f"{tensor_proxy} does not have `{SubclassTensorProxy.SUBCLASS_TYPE_ATTR}`", + ) + utils.check( + tensor_proxy._tensors, + lambda: f"{tensor_proxy} has an empty `{tensor_proxy._tensors=}`", + ) + tensor_attr_names = tensor_proxy._tensor_attr_names + non_tensor_attr_names = tensor_proxy._non_tensor_attr_names + inner_tensors = dict( + zip( + tensor_attr_names, + [_materialize_tensor_proxy(t, fake_tensor_mode=fake_tensor_mode) for t in tensor_proxy._tensors], + ) + ) + new_non_tensors = [] + for a in tensor_proxy._non_tensors: + if isinstance(a, dtypes.dtype): + new_non_tensors.append(dtypes.to_torch_dtype(a)) + elif isinstance(a, devices.Device): + new_non_tensors.append(devices.to_torch_device(a)) + else: + new_non_tensors.append(a) + metadata = dict(zip(non_tensor_attr_names, new_non_tensors)) + subclass_tensor = subclass_type.__tensor_unflatten__( + inner_tensors, + metadata, + outer_size=-1, + outer_stride=-1, + ) + fakified = fake_tensor_mode.from_tensor(subclass_tensor, static_shapes=True) + return fakified + + +def materialize_tensor_proxy( + t: TensorProxy | SubclassTensorProxy, + fake_tensor_mode: FakeTensorMode, +) -> torch.Tensor: + if isinstance(t, SubclassTensorProxy): + return _make_fake_subclass_tensor_from_subclass_tensor_proxy(t, fake_tensor_mode) + return _materialize_tensor_proxy(t, fake_tensor_mode) + + +def maybe_materialize_tensor( + t: ProxyInterface, + fake_tensor_mode: FakeTensorMode, +) -> ProxyInterface | torch.Tensor: + if isinstance(t, (TensorProxy, SubclassTensorProxy)): + return materialize_tensor_proxy(t, fake_tensor_mode) + if isinstance(t, (Number, str)): + return t + return t.value + + +def proxy_fake_tensor(t: torch.Tensor | FakeTensor) -> ProxyInterface: + if isinstance(t, FakeTensor) or (isinstance(t, torch.Tensor) and not issubclass(type(t), torch.Tensor)): + return TensorProxy( + None, + shape=list(t.shape), + dtype=dtypes.to_dtype(t.dtype), + device=devices.to_device(t.device), + requires_grad=t.requires_grad, + ) + if torch.utils._python_dispatch.is_traceable_wrapper_subclass(t): + tensor_attr_names, metadata = t.__tensor_flatten__() + tensor_proxies = [proxy_fake_tensor(getattr(t, name)) for name in tensor_attr_names] + non_tensor_attr_names = list(metadata.keys()) + non_tensors = list(metadata.values()) + p = SubclassTensorProxy( + None, + shape=list(t.shape), + dtype=dtypes.to_dtype(t.dtype), + device=devices.to_device(t.device), + requires_grad=t.requires_grad, + tensors=tensor_proxies, + non_tensors=non_tensors, + subclass_type=type(t), + ) + p._tensor_attr_names = tensor_attr_names + p._non_tensor_attr_names = non_tensor_attr_names + for name, value in zip(tensor_attr_names + non_tensor_attr_names, tensor_proxies + non_tensors): + setattr(p, name, value) + return p + return t + + +def trace_from_bsym_or_bsyms(bsym_or_bsyms: BoundSymbol | Sequence[BoundSymbol]) -> TraceCtx: + from thunder.core.compile_data import get_compile_data + + cd = get_compile_data() + ad_hoc_executor = None + if cd is not None: + from thunder.extend import AdHocExecutor + + executors_list = list(filter(lambda t: isinstance(t, AdHocExecutor), cd.executors_list)) + if executors_list: + ad_hoc_executor = executors_list[0] + + bsyms = list(utils.sequencify(bsym_or_bsyms)) + trace_args = bsyms[0].flat_proxy_args + trace_name = bsyms[0].sym.name + + if ad_hoc_executor is not None and ad_hoc_executor._implmap: + tmp_bsyms = [] + for bsym in bsyms: + if ad_hoc_executor.can_execute(bsym) and bsym.subsymbols: + tmp_bsyms.extend(bsym.subsymbols) + else: + tmp_bsyms.append(bsym) + bsyms = tmp_bsyms + unpack_bsyms = [ + prims.unpack_trivial.bind(a, name=a.name, output=a) + for a in filter(lambda a: isinstance(a, ProxyInterface), trace_args) + ] + + trace = TraceCtx() + trace.bound_symbols.extend(unpack_bsyms + bsyms) + trace.args = trace_args + with tracectx(trace): + prims.python_return(bsyms[-1].output) + with tracectx(trace): + # note(crcrpar): Give prefix `tmp` to avoid infinite recursion due to the same name + trace._siginfo = SigInfo.from_name_and_args(f"tmp_{trace_name}", trace.args) + return trace + + +def make_trace_executable(trace_to_convert: TraceCtx, *args_for_eval, **kwargs_for_eval): + from functools import wraps + from thunder import trace + from thunder.core.transforms import eval_trace + from thunder.executors.torch_compile import to_torch_translator + + @wraps(trace_to_convert.python_callable()) + def torch_interpreted_func(*args, **kwargs): + return eval_trace(trace_to_convert, *args, **kwargs, symbol_mapper=to_torch_translator) + + torch_trace = trace(inline_trace=False)(torch_interpreted_func, *args_for_eval, **kwargs_for_eval) + return torch_trace + + +@dataclass +class DesugarTensorSubclass: + computation_trace: TraceCtx + swap_map: dict[Variable, ProxyInterface] = field(init=False, default_factory=dict) + fake_tensor_mode: FakeTensorMode = field(init=False, default_factory=FakeTensorMode) + flat_trace_args: Sequence[ProxyInterface] = field(init=False, default=None) + subclass_proxy_to_flatten: set[Variable] = field(init=False, default_factory=set) + bsym_to_new_outputs: dict[BoundSymbol, list[TensorProxy]] = field(init=False, default_factory=dict) + + def __post_init__(self) -> None: + # Check if this trace is backward trace + is_backward_trace: bool = False + if len(self.computation_trace.bound_symbols) > 6: + maybe_unpack_C0_bsym = self.computation_trace.bound_symbols[4] + maybe_unpack_C1_bsym = self.computation_trace.bound_symbols[5] + is_backward_trace = ( + maybe_unpack_C0_bsym.args + and maybe_unpack_C1_bsym.args + and ( + maybe_unpack_C0_bsym.sym.id, + maybe_unpack_C1_bsym.sym.id, + getattr(maybe_unpack_C0_bsym.args[0], "name", ""), + getattr(maybe_unpack_C1_bsym.args[0], "name", ""), + ) + == ( + prims.PrimIDs.UNPACK_SEQUENCE, + prims.PrimIDs.UNPACK_SEQUENCE, + "C0", + "C1", + ) + ) + if is_backward_trace: + self.flat_trace_args, _ = tree_flatten((maybe_unpack_C0_bsym.output, maybe_unpack_C1_bsym.output)) + if not is_backward_trace: + self.flat_trace_args, _ = tree_flatten((self.computation_trace.args, self.computation_trace.kwargs)) + for arg in self.flat_trace_args: + if isinstance(arg, SubclassTensorProxy): + self.subclass_proxy_to_flatten.add(variableify(arg)) + + def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]: + return p._tensor_attr_names + + def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]: + return p._non_tensor_attr_names + + def translate_fx_graph_into_bsym( + self, + bsym: BoundSymbol, + fx_graph: GraphModule, + ) -> BoundSymbol | tuple[BoundSymbol, ...]: + import thunder.torch as ltorch + from thunder.torch import _torch_to_thunder_function_map + + unwrapped_bsym_args: dict[int, ProxyInterface] = {} + list_of_flattening_bsyms: list[BoundSymbol] = [] + for a in bsym.flat_args: + if isinstance(a, SubclassTensorProxy): + if variableify(a) in self.subclass_proxy_to_flatten: + self.computation_trace.push_scope([]) + with tracectx(self.computation_trace): + prims.flatten_tensor_subclass(a) + flattening_bsym = self.computation_trace.pop_scope()[0] + list_of_flattening_bsyms.append(flattening_bsym) + tensor_attr_names = self._get_tensor_attr_names(a) + tensors = a._tensors + + non_tensor_attr_names = self._get_non_tensor_attr_names(a) + non_tensors = a._non_tensors + metadata = dict(zip(non_tensor_attr_names, non_tensors)) + for name, t in zip(tensor_attr_names, tensors): + utils.check( + isinstance(t, TensorProxy), + lambda: f"{a=}, {tensor_attr_names = }, {tensors=}", + ) + unwrapped_bsym_args[len(unwrapped_bsym_args)] = t + # TODO(crcrpar): Think about how to verify the correctness of this flattening + flat_metadata, _ = tree_flatten(metadata) + for v in flat_metadata: + unwrapped_bsym_args[len(unwrapped_bsym_args)] = v + else: + if not isinstance(a, ProxyInterface): + from thunder.core.proxies import proxy + + with tracectx(self.computation_trace): + a = proxy(a) + unwrapped_bsym_args[len(unwrapped_bsym_args)] = a + + node: Node + list_of_placeholder_node: list[Node] = [] + list_of_function_call_node: list[Node] = [] + node_of_output: Node + for node in fx_graph.graph.nodes: + if node.op == PLACEHOLDER: + list_of_placeholder_node.append(node) + if node.op == CALL_FUNCTION: + list_of_function_call_node.append(node) + if node.op == OUTPUT: + node_of_output = node + args = [n.target for n in list_of_placeholder_node] + arg_name_to_index = {a: i for i, a in enumerate(args)} + ltorch_ops_for_node_of_ops = [] + for node in list_of_function_call_node: + op: OpOverload = node.target + if op not in _torch_to_thunder_function_map: + msg = ( + f"`thunder.torch` does not have corresponding op for {op}. " + "Think about adding it to thunder/torch/default_torch_ops.py" + f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}" + f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}" + ) + raise RuntimeError(msg) + ltorch_ops_for_node_of_ops.append(_torch_to_thunder_function_map[op]) + + bsyms: list[BoundSymbol] = [] + if list_of_flattening_bsyms: + bsyms.extend(list_of_flattening_bsyms) + fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {} + for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops): + args: list[Node] = node.args + + arg_proxies: list[ProxyInterface] = [] + for a in args: + if isinstance(a, Node): + if isinstance(a.target, str): + arg_proxies.append(unwrapped_bsym_args[arg_name_to_index[a.target]]) + else: + arg_proxies.append(fxnode_output_name_to_tensor_proxy[str(a)]) + else: + if isinstance(a, immutable_dict): + arg_proxies.append(dict(a)) + elif isinstance(a, immutable_list): + arg_proxies.append(list(a)) + else: + arg_proxies.append(a) + + self.computation_trace.push_scope([]) + + try: + with tracectx(self.computation_trace): + out = ltorch_op(*arg_proxies) + except Exception as e: + msg = ( + f"Failing to map `torch.{node}` to `thunder.torch` op of " + f"{ltorch_op} with args of {arg_proxies}\n" + f"BoundSymbol in question is\n```python\n{bsym}\n```\n" + f"Corresponding torch.fx Graph is\n```python\n{fx_graph.print_readable(print_output=False)}\n```\n" + f"Original error is {e}" + ) + raise type(e)(msg) + else: + fxnode_output_name_to_tensor_proxy[str(node)] = out + bsyms.extend(self.computation_trace.pop_scope()) + if len(bsyms) == 0: + return [bsym] + + orig_output = bsym.flat_outs[0] + if is_subclass_ctor_bsym := bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR: + utils.check_type(orig_output, SubclassTensorProxy) + if isinstance(orig_output, SubclassTensorProxy): + # note(crcrpar): args[0] would be list of tensors, and args[1] could be list of non-tensors. + args: list[Node] = node_of_output.args[0] + new_tensor_proxies = [] + new_non_tensor_values = [] + for a in args: + value = a + if isinstance(a, Node): + if isinstance(a.target, str): + value = unwrapped_bsym_args[arg_name_to_index[a.target]] + else: + value = fxnode_output_name_to_tensor_proxy[str(a)] + if isinstance(value, TensorProxy): + new_tensor_proxies.append(value) + elif isinstance(value, (immutable_dict, immutable_list)): + if isinstance(value, immutable_dict): + new_non_tensor_values.append(dict(value)) + else: + new_non_tensor_values.append(list(v)) + else: + new_non_tensor_values.append(value) + utils.check( + len(orig_output._tensors) == len(new_tensor_proxies), + lambda: ( + f"The number of new tensor proxies for {orig_output=} does not match: " + f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}" + ), + ) + with tracectx(self.computation_trace): + new_subclass = orig_output.replace() + new_subclass._tensors = new_tensor_proxies + for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies): + setattr(new_subclass, name, value) + bsyms.append( + prims.unflatten_tensor_subclass.bind( + new_subclass._subclass_type, + dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)), + dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)), + output=new_subclass, + ) + ) + + self.swap_map[variableify(orig_output)] = new_subclass + self.subclass_proxy_to_flatten.add(variableify(new_subclass)) + + else: + non_none_args = [n for n in node_of_output.args[0] if n is not None] + utils.check(len(non_none_args) == 1, lambda: f"{node_of_output.args = }") + new_out_node = non_none_args[0] + self.swap_map[variableify(orig_output)] = fxnode_output_name_to_tensor_proxy[str(new_out_node)] + + args = ", ".join([t.name if isinstance(t, ProxyInterface) else f"{t}" for t in bsym.flat_args]) + header = f"{bsym.sym.id}({args})" + for i, sbsym in enumerate(bsyms, 1): + sbsym.header = f"[{i}/{len(bsyms)}] unrolled `__torch_dispatch__` of `{header}`" + return bsyms + + def convert_trace_to_fx_graph_and_get_fake_result( + self, + trace: TraceCtx, + ) -> tuple[GraphModule, tuple[OutputWrapperForFxTracing, ...], tuple[torch.Tensor, ...], PyTreeSpec]: + def create_ctor(unflatten_method, tensor_names): + def ctor(tensors, metadata): + inner_tensors = dict(zip(tensor_names, tensors)) + return unflatten_method(inner_tensors, metadata, -1, -1) + + return ctor + + args = tree_map( + lambda t: maybe_materialize_tensor( + t, + self.fake_tensor_mode, + ), + trace.args, + ) + desugared_args = [] + arg_idx_to_sugar: dict[int, tuple[int, Any]] = {} + for a in args: + if is_traceable_wrapper_subclass(a): + start_idx = len(desugared_args) + attrs, metadta = a.__tensor_flatten__() + desugared_args.extend([getattr(a, name) for name in attrs]) + desugared_args.append(metadta) + end_idx = len(desugared_args) + arg_idx_to_sugar[start_idx] = end_idx, create_ctor(type(a).__tensor_unflatten__, attrs) + else: + desugared_args.append(a) + + out_specs: list[Any] = [] + orig_output: list[torch.Tensor] = [] + + def transform_out(out: torch.Tensor) -> OutputWrapperForFxTracing: + orig_output.append(out) + if is_traceable_wrapper_subclass(out): + from enum import Enum + + attrs, metadata = out.__tensor_flatten__() + tensors = [getattr(out, name) for name in attrs] + for key in metadata: + v = metadata[key] + if issubclass(type(v), Enum) and not isinstance(v, (torch.dtype, torch.device)): + metadata[key] = str(metadata[key]) + output = OutputWrapperForFxTracing(dict(zip(attrs, tensors)), metadata) + else: + output = OutputWrapperForFxTracing(out, None) + return output + + desugared_proxy_args = [] + for a in trace.args: + if isinstance(a, SubclassTensorProxy): + names, metadata = a.__tensor_flatten__() + desugared_proxy_args.extend([getattr(a, name) for name in names]) + desugared_proxy_args.append(metadata) + else: + desugared_proxy_args.append(a) + + extrace = make_trace_executable(trace, *trace.args, **trace.kwargs) + utils.check( + (len(extrace.bound_symbols) == len(trace.bound_symbols)) + or ( + len(extrace.bound_symbols) == len(trace.bound_symbols) - 1 + and any(bsym.sym.id == prims.PrimIDs.SHALLOW_COPY for bsym in trace.bound_symbols) + ), + lambda: ( + f"Input trace is\n{trace}\nExecution trace is\n{extrace}\n" + f"Input has {len(trace.bound_symbols)} syms but execution trace has {len(extrace.bound_symbols)}" + ), + ) + f = extrace.python_callable(include_decorators=False) + + def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing, ...]: + args = [] + cur_idx = 0 + while cur_idx < len(desugared_args): + if cur_idx in arg_idx_to_sugar: + end_idx, construct_subclass = arg_idx_to_sugar[cur_idx] + args_of_subclass = desugared_args[cur_idx:end_idx] + tensors = args_of_subclass[:-1] + metadata = args_of_subclass[-1] + subclass = construct_subclass(tensors, metadata) + args.append(subclass) + + cur_idx = end_idx + else: + args.append(desugared_args[cur_idx]) + cur_idx += 1 + + out = f(*args) + # Specialcasing the output of initial computation trace + if isinstance(out, dict) and len(out) == 2 and ("output", "flat_args") == tuple(out.keys()): + sequencified_out = out + else: + sequencified_out = utils.sequencify(out) + flat_out, out_spec = tree_flatten(sequencified_out) + out_specs.append(out_spec) + flat_cosmeticized_out = tree_map(transform_out, flat_out) + return tree_unflatten(flat_cosmeticized_out, out_spec) + + with ( + enable_python_dispatcher(), + FunctionalTensorMode( + pre_dispatch=False, + export=False, + _allow_token_discovery=True, + ), + ): + fx: GraphModule = make_fx(f_with_wrap_and_unwrap)(*desugared_args) + + return fx, fx(*desugared_args), tuple(orig_output), out_specs[0] + + def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: + updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map) + if bsym.sym.id == prims.PrimIDs.RETURN: + new_swap_map = {} + for k, v in self.swap_map.items(): + if isinstance(v, SubclassTensorProxy): + continue + new_swap_map[k] = v + if not self.subclass_proxy_to_flatten or True: + return [updated_bsym] + + is_bsym_of_subclass_ctor = bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR + returns_subclass = any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_outs) + no_subclass_args = all(not isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args) + is_unpack = bsym.sym.id in {prims.PrimIDs.UNPACK_TRIVIAL, prims.PrimIDs.UNPACK_SEQUENCE} + is_subclass_ctor = is_bsym_of_subclass_ctor or (no_subclass_args and returns_subclass and not is_unpack) + if not is_subclass_ctor and no_subclass_args: + return [updated_bsym] + + utils.check( + len(updated_bsym.flat_outs) < 2, + lambda: f"bsym has {len(updated_bsym.flat_outs)} outputs", + exception_type=NotImplementedError, + ) + + trace = trace_from_bsym_or_bsyms(updated_bsym) + fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace) + utils.check( + len(sequencified_cosmeticized_out) == len(orig_output), + lambda: f"{len(sequencified_cosmeticized_out)=}, {len(orig_output)=}", + ) + if is_subclass_ctor: + utils.check(len(sequencified_cosmeticized_out) == 1 and len(orig_output) == 1, lambda: "") + fake_tensor_subclass = orig_output[0] + subclass_proxy = updated_bsym.flat_outs[0] + tensor_attr_names, metadata = fake_tensor_subclass.__tensor_flatten__() + subclass_proxy._tensor_attr_names = tensor_attr_names + subclass_proxy._non_tensor_attr_names = list(metadata.keys()) + self.subclass_proxy_to_flatten.add(variableify(subclass_proxy)) + for name, value in zip( + tensor_attr_names + subclass_proxy._non_tensor_attr_names, + subclass_proxy._tensors + subclass_proxy._non_tensor_attr_names, + ): + setattr(subclass_proxy, name, value) + return [updated_bsym] + + out = [] + for i, (cosmeticized_out, orig_out) in enumerate(zip(sequencified_cosmeticized_out, orig_output)): + if isinstance(cosmeticized_out.inner_tensors, dict): + utils.check( + is_traceable_wrapper_subclass(orig_out), lambda: f"{cosmeticized_out=} don't match {orig_out=}" + ) + out.append(orig_out) + else: + out.append(orig_out) + + with tracectx(self.computation_trace): + out_proxy = tree_map(proxy_fake_tensor, out) + + utils.check( + len(updated_bsym.flat_outs) == len(out_proxy), + lambda: f"{len(bsym.flat_outs)=}, {len(out_proxy)=}, {out_proxy=}, {bsym.flat_outs=}", + ) + sequence_out = [variableify(a) for a in updated_bsym.flat_outs] + self.swap_map.update(dict(zip(sequence_out, utils.sequencify(out_proxy)))) + + bsym_with_modified_output = updated_bsym.from_bsym_swap_proxies(self.swap_map) + self.bsym_to_new_outputs[bsym_with_modified_output] = bsym_with_modified_output + return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx) + + +def tensor_subclass_dce(trace: TraceCtx) -> TraceCtx: + """Remove ``tensor.__tensor_flatten__``s as possible. + + This function tries to remove flattening of tensor subclass + by replacing their outputs with tensor args of ``tensor``\'s constructor, + either '`TensorSubclass(...)` or `TensorSubclass.__tensor_unflatten__(...)`. + + This function does not remove ``TensorSubclass(...)`` nor ``TensorSubclass.__tensor_unflatten__(...)`` + as they could be a saved tensor for backward. + """ + start_time_ns = time.perf_counter_ns() + swap_map: dict[Variable, TensorProxy] = {} + producer_map = utils.producers(trace) + bsym_to_exclude: set[BoundSymbol] = set() + + subclass_flatten_bsym: BoundSymbol + for subclass_flatten_bsym in filter( + lambda bsym: bsym.sym.id == prims.PrimIDs.FLATTEN_TENSOR_SUBCLASS, + trace.bound_symbols, + ): + subclass_tensor_proxy: SubclassTensorProxy = subclass_flatten_bsym.flat_args[0] + flatten_tensors: tuple[TensorProxy, ...] = subclass_flatten_bsym.output + ctor_bsym: BoundSymbol = producer_map[subclass_tensor_proxy] + match ctor_bsym.sym.id: + case prims.PrimIDs.TENSOR_SUBCLASS_CTOR: + ctor_tensors: list[TensorProxy] = ctor_bsym.args[6] + case prims.PrimIDs.UNFLATTEN_TENSOR_SUBCLASS: + ctor_tensors: list[TensorProxy] = list(ctor_bsym.args[1].values()) + case _: + continue + utils.check( + len(flatten_tensors) == len(ctor_tensors), + lambda: f"{flatten_tensors} and {ctor_tensors} have different number of tensors", + ) + + for k, v in zip(flatten_tensors, ctor_tensors): + if k.name == v.name: + continue + swap_map[variableify(k)] = v + bsym_to_exclude.add(subclass_flatten_bsym) + + if not swap_map: + return trace + + new_bsyms: list[BoundSymbol] = [] + bsym: BoundSymbol + for bsym in trace.bound_symbols: + if bsym in bsym_to_exclude: + continue + new_bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True)) + + new_trace = from_trace(trace) + new_trace.bound_symbols = new_bsyms + end_time_ns = time.perf_counter_ns() + elapsed_time_ns = end_time_ns - start_time_ns + elapsed_time_millis = elapsed_time_ns // 1000000 + new_trace.set_provenance( + TraceProvenance(f"DCE of Tensor Subclass Flattening/Unflattening (took {elapsed_time_millis} milliseconds)") + ) + + return new_trace + + +def unroll_tensor_subclasses(trace: TraceCtx) -> TraceCtx: + """Unroll tensor subclasses in ``computation_trace``. + + Two things are happening inside of this function: + * Reevaluate every single bsym of ``computation_trace.bound_symbols``. + * Flatten tensor subclasses + + Each :class:`thunder.core.symbol.BoundSymbol` is reevaluated with torch.fx tracing and + ``FakeTensorMode``. This is necessary because Thunder's initial trace cannot correctly infer the output + type of an op with tensor subclasses. By translating each bsym into a callable and tracing it with + ``torch.fx`` and ``FakeTensorMode``, we can tell the output type and the exact behavior of the bsym + which is extended by subclass's ``__torch_dispatch__`` (note that the sequence of observed operations + are free from tensor subclasses, everything is flattened). + The output type information is then reflected to the output :class:`thunder.core.proxies.Proxy`. + + With this function applied, the :class:`thunder.core.trace.TraceCtx` is free from tensor subclasses. + Exceptions are prologue (meaning the first few lines of the trace, before any math) and epilogue (meaning + the last few lines of the trace, right before return statement). + + Args: + trace: + + Returns: + TraceCtx: transformed trace that is free from tensor subclasses, every ``__torch_dispatch__`` + behavior is spelled out. + """ + start_time_ns = time.perf_counter_ns() + + desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=trace) + updated_bsyms: list[BoundSymbol] = [] + bsym: BoundSymbol + for bsym in trace.bound_symbols: + maybe_desugared_bsyms = desugar_tensor_subclass(bsym) + updated_bsyms.extend(maybe_desugared_bsyms) + + if not desugar_tensor_subclass.subclass_proxy_to_flatten: + return trace + + end_time_ns = time.perf_counter_ns() + elapsed_time_ns = end_time_ns - start_time_ns + elapsed_time_millis = elapsed_time_ns // 1000000 + + computation_trace_with_subclass_tensor_unrolled = from_trace(trace) + computation_trace_with_subclass_tensor_unrolled.bound_symbols.extend(updated_bsyms) + computation_trace_with_subclass_tensor_unrolled.set_provenance( + TraceProvenance(f"tensor subclasses unrolled (took {elapsed_time_millis} milliseconds)") + ) + dced_computation_trace = tensor_subclass_dce(computation_trace_with_subclass_tensor_unrolled) + warn_tensor_subclass_support() + return dced_computation_trace From 1053d3c33c9311f31ae50e623dd98a50aa80a0c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:24:21 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/torch/__init__.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 34245270b8..6854903774 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -3143,18 +3143,15 @@ def amin(a, /, dim=None, keepdim: bool = False): # NOTE: Using name `torch_max` to avoid conflict with Python's `max` @overload -def torch_max(a: TensorLike, /) -> TensorLike: - ... +def torch_max(a: TensorLike, /) -> TensorLike: ... @overload -def torch_max(a: TensorLike, /, dim: NumberLike, keepdim: bool = False) -> tuple[TensorLike, TensorLike]: - ... +def torch_max(a: TensorLike, /, dim: NumberLike, keepdim: bool = False) -> tuple[TensorLike, TensorLike]: ... @overload -def torch_max(a: TensorLike, b: TensorLike, /) -> TensorLike: - ... +def torch_max(a: TensorLike, b: TensorLike, /) -> TensorLike: ... @torchsymbol(torch.max, is_method=True, method_name="max", id="torch.max")