[tensor wrapper subclass] Add Proxy, prim, and lookaside #1583

Draft · wants to merge 1 commit into base: main
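For context: the lookaside added in thunder/core/jit_ext.py below targets `torch.Tensor._make_wrapper_subclass`, the call a tensor wrapper subclass makes in its `__new__` to allocate a storage-less tensor carrying the subclass's metadata. A minimal sketch of such a subclass, assuming nothing beyond stock PyTorch (the class name and fields are illustrative, not taken from this PR):

import torch

class ScaleTensor(torch.Tensor):
    # Illustrative wrapper subclass: wraps an inner tensor plus a Python scalar.
    # A real subclass would also implement __torch_dispatch__; omitted here.
    def __new__(cls, inner: torch.Tensor, scale: float):
        # _make_wrapper_subclass allocates a tensor with the right metadata but
        # no storage; this is exactly the call the new lookaside intercepts.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )
        r.inner = inner
        r.scale = scale
        return r
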
37 changes: 37 additions & 0 deletions thunder/core/jit_ext.py
@@ -63,6 +63,7 @@
StringProxy,
TensorProxy,
FutureTensorProxy,
SubclassTensorProxy,
make_proxy_name,
Variable,
variableify,
@@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs):
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(
cls: torch._C._TensorMeta,
size: Sequence[int],
strides: Sequence[int] | None = None,
storage_offset: int | None = None,
memory_format: torch.memory_format | None = None,
dtype: torch.dtype | None = None,
layout: torch.layout | None = torch.strided,
device: torch.device | None = None,
pin_memory: bool = False,
requires_grad: bool = False,
dispatch_sizes_strides_policy: str | None = None,
dispatch_device: bool = False,
dispatch_layout: bool = False,
_extra_dispatch_keys: torch.DispatchKeySet | None = None,
storage_size: int | None = None,
):
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)

subclass = SubclassTensorProxy(
None,
shape=usize,
device=udevice,
dtype=udtype,
requires_grad=urequires_grad,
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
subclass_type=ucls,
)
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
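A rough sketch of where this lookaside is expected to fire during tracing, assuming the illustrative ScaleTensor above (whether the rest of the trace completes depends on the broader subclass support this draft builds on):

import thunder

def f(x: torch.Tensor):
    return ScaleTensor(x, 2.0)

jitted = thunder.jit(f)
# When the interpreter reaches torch.Tensor._make_wrapper_subclass inside
# ScaleTensor.__new__, the lookaside returns a wrapped SubclassTensorProxy
# instead of allocating a real tensor.
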
151 changes: 149 additions & 2 deletions thunder/core/prims.py
@@ -1,18 +1,23 @@
from __future__ import annotations
from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
import operator
import builtins
import math
from types import NoneType
from typing import Union, Type, Any, List, Dict, Tuple, Optional
from typing import Union, Type, Any, List, Dict, Tuple, Optional, TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Callable, Hashable, Sequence

import torch

from thunder.core.langctxs import LanguageContext, register_langctx, Languages, langctx

if TYPE_CHECKING:
from collections.abc import Iterable
from thunder.core.codeutils import ContextObject

#
# Creates and registers the torch language context
#
@@ -77,6 +82,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
TupleProxy,
AnyProxy,
IntegerProxy,
SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -272,6 +278,8 @@ class PrimIDs(Enum):
COPY_ = auto()
#
SINK = auto()
    # Tensor subclass methods
TENSOR_SUBCLASS_CTOR = auto()


class OpTags(Enum):
@@ -3543,7 +3551,10 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy:
if isinstance(a, SubclassTensorProxy):
        # SubclassTensorProxy(like=...) would not copy some attrs such as `_tensors`, while `replace` does.
return a.replace()
return TensorProxy(like=a)


@@ -4048,3 +4059,139 @@ def sink_meta(*args, **kwargs):

# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
s = SubclassTensorProxy(
name,
subclass_type=cls,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
)
return s


def get_nested_types(collection):
collection = utils.sequencify(collection)
types_set = {type(t) for t in collection}

def check_types(coll):
for item in coll:
types_set.add(type(item))
# Check if the item is a nested collection
if baseutils.is_collection(item):
# If it's a dictionary, check its values
if isinstance(item, dict):
check_types(item.values())
# Recursively check nested collections
else:
check_types(item)

check_types(collection)
return tuple(types_set)


def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(
filter(
lambda t: (
t.__module__ != "builtins"
and t != Number
# note(crcrpar): maybe `thunder.core`?
and not t.__module__.startswith("thunder.")
and not t.__module__.startswith("torch.")
),
types,
)
)


def printer_of_tensor_subclass_ctor(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

baseutils.check(not kwarg_printables, lambda: f"No kwargs are supported but {kwarg_printables = }")

    # NOTE(crcrpar): It's not a context, but at the moment a Tensor subclass is treated as a `ContextObject`.
wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
if isinstance(wrapped_cls, torch._C._TensorMeta):
cls = wrapped_cls
else:
cls: torch._C._TensorMeta = wrapped_cls.obj
tensors, non_tensors = arg_printables[-2:]
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(dtypes.to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(devices.to_torch_device(a))
else:
new_non_tensors.append(a)

arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
kwarg_str = ""

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str: str
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"
else:
comment_str = ""

cls_with_module = f"{cls.__name__}"
s = f"{result_str}{cls_with_module}({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

filtered_types = (cls,)
if non_tensors:
types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors])
filtered_types += filter_types(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)
return s


def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
non_tensors = bsym.args[-1]

filtered_types: tuple[Any, ...] = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


tensor_subclass_ctor = make_prim(
PrimIDs.TENSOR_SUBCLASS_CTOR,
"tensor_subclass_ctor",
meta=tensor_subclass_ctor_meta,
python_printer=printer_of_tensor_subclass_ctor,
_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
)
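
As a rough illustration of the printer and bind postprocess above (names and metadata hypothetical): a bound symbol whose arguments are (ScaleTensor, "t1", shape, device, dtype, requires_grad, [t0], [2.0]) would be printed into the generated trace as a direct constructor call, with ScaleTensor added to the trace's import context so the name resolves:

t1 = ScaleTensor(t0, 2.0)  # t1: "cuda:0 f32[2, 2]"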