
Adding shape prim #1113

Merged · 12 commits · Sep 11, 2024
12 changes: 12 additions & 0 deletions thunder/core/prims.py
@@ -129,6 +129,7 @@ class PrimIDs(Enum):
CONSTRUCT_TUPLE = auto()
PACK_BUFFER = auto()
PACK_SETITEM = auto()
SHAPE = auto()
# TODO: UNPACK_SET
# Utility prims
COMMENT = auto()
@@ -1238,6 +1239,17 @@ def pack_setitem_impl(o: Any, key: Any, v: Any) -> None:
)


def shape_meta(t: TensorProxy) -> Sequence[int | NumberProxy]:
return t._shape


shape = make_prim(
PrimIDs.SHAPE,
"shape",
meta=shape_meta,
)


# NOTE UNPACK_GETITEM is intended only to be bound to directly, and not called
def unpack_getitem_meta(o: Any, key: Any) -> Any:
raise NotImplementedError
3 changes: 3 additions & 0 deletions thunder/core/transforms.py
@@ -1368,6 +1368,9 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
# This operation creates no grad associations
register_grad(pids.ARGMAX, prims.argmax)

# This operation creates no grad associations
register_grad(pids.SHAPE, prims.shape)

#
# Phantom grad transform helpers
#
26 changes: 26 additions & 0 deletions thunder/executors/nvfuserex_impl.py
@@ -2240,6 +2240,32 @@ def matmul(
register_supported(PrimIDs.MATMUL, matmul, _matmul_check)


def _shape_check(
a: TensorProxy,
) -> bool:
# TODO: we cannot support this yet. fusion_pass needs to be updated to
# ensure that the fused region consumes every NumberProxy it produces and
# does not leak any of them out as fusion outputs, since nvfuser cannot yet
# produce scalar outputs.
return False


def shape(
a: TensorProxy,
*,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:
nva = getnv(a, fd, lc_to_nv_map)
ret = []
for i in range(a.ndim):
ret.append(fd.ops.size(nva, i))
return ret


register_supported(PrimIDs.SHAPE, shape, _shape_check)
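As an aside, a minimal sketch of the limitation _shape_check guards against. The function below is hypothetical and not part of this PR; only prims.shape and Tensor.size are:

def leaks_a_scalar(x):
    s = x.size(0)   # now lowered through prims.shape, producing a NumberProxy
    y = x * 2.0     # tensor math that nvFuser could fuse
    # returning `s` would force the fused region to expose the NumberProxy as a
    # scalar fusion output, which nvFuser cannot produce yet, so the check
    # returns False for now
    return y, s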


# Registering SDPA operators for nvFuser
# SDPA requires an execution and grad transform since the forward and backward passes are called through different implementations.
# For both execution and grad transform, a new operator is registered with nvfuserex (ex.register_operator) and then added to the translation map (register_supported).
1 change: 0 additions & 1 deletion thunder/executors/pythonex.py
@@ -13,7 +13,6 @@
import torch

import thunder.core.prims as prims
-from thunder.core.prims import PrimIDs
from thunder.core.proxies import NumberProxy, NumberLike, TensorProxy, CollectionProxy
from thunder.core.symbol import Symbol, BoundSymbol
from thunder.core import baseutils
9 changes: 8 additions & 1 deletion thunder/executors/torchex.py
@@ -23,7 +23,6 @@
import thunder.core.devices as devices
from thunder.core.devices import to_torch_device, to_device
import thunder.core.prims as prims
-from thunder.core.prims import PrimIDs
from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype
from thunder.core.pytree import tree_flatten, tree_unflatten
@@ -2128,3 +2127,11 @@ def _copy__impl(copy_from, copy_to):

copy_ = ex.register_operator("copy_", meta=prims.copy_, tags=(prims.OpTags.DONT_DCE,), fn=_copy__impl)
_register_implementation(prims.copy_, copy_, checker=_always_executable)


def _shape_impl(t):
return t.shape


shape = ex.register_operator("shape", meta=prims.shape_meta, fn=_shape_impl)
_register_implementation(prims.shape, shape, checker=_always_executable)
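For reference, the Torch executor's fallback just forwards to the tensor's shape attribute; a standalone sketch of the same behavior in plain PyTorch, outside thunder:

import torch

def _shape_impl(t):
    return t.shape

t = torch.empty(2, 3, 4)
print(_shape_impl(t))  # torch.Size([2, 3, 4])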
7 changes: 4 additions & 3 deletions thunder/torch/__init__.py
@@ -239,10 +239,11 @@ def is_floating_point(a: TensorLike, /) -> bool:


# Handles the size method
@torchsymbol(torch.Tensor.size)
def size(a: TensorLike, /, dim: None | int = None) -> int | Sequence[int]:
    if dim is not None:
-        return a.shape[dim]
+        return prims.shape(a)[dim]
-    return a.shape
+    return prims.shape(a)


register_method("size", size)
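A usage sketch of the rewritten method, assuming the standard thunder.jit entry point; the shapes and exact return types are illustrative:

import torch
import thunder

def f(x):
    return x.size(), x.size(0)  # both queries now lower to prims.shape

jf = thunder.jit(f)
print(jf(torch.randn(4, 8)))  # expected along the lines of ((4, 8), 4)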
@@ -1286,7 +1287,7 @@ def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
@torchsymbol(torch.unbind, is_method=True)
def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
utils.check(
-        len(a.size()) > 0,
+        a.ndim > 0,
lambda: f"Dimension specified as={dim} but tensor has no dimensions.",
)
return tuple(s.squeeze(dim) for s in tensor_split(a, a.shape[dim], dim))