Proxy and prims to support programs that only call the constructor of tensor wrapper subclasses

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Dec 23, 2024
1 parent 9d79b8d commit 17b5570
Showing 5 changed files with 564 additions and 4 deletions.
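
For context, here is a hedged sketch of the kind of program this commit enables: one that merely constructs a tensor wrapper subclass and returns it, without running any ops on the result. `ScaleTensor`, its fields, and `f` are hypothetical illustrations, not part of this commit; real wrapper subclasses typically also define `__torch_dispatch__` and flatten/unflatten hooks, omitted here for brevity.

import torch
import thunder

class ScaleTensor(torch.Tensor):
    # Hypothetical wrapper subclass: an inner tensor plus a Python-number scale.
    def __new__(cls, data: torch.Tensor, scale: float):
        # This is the call that the new lookaside below intercepts under thunder.jit.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scale: float):
        self._data = data
        self._scale = scale

def f(x: torch.Tensor) -> ScaleTensor:
    # Only calls the subclass ctor; no ops are applied to the result.
    return ScaleTensor(x, 2.0)

jitted = thunder.jit(f)
out = jitted(torch.randn(2, 2))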
37 changes: 37 additions & 0 deletions thunder/core/jit_ext.py
@@ -63,6 +63,7 @@
StringProxy,
TensorProxy,
FutureTensorProxy,
SubclassTensorProxy,
make_proxy_name,
Variable,
variableify,
@@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs):
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(
cls: torch._C._TensorMeta,
size: Sequence[int],
strides: Sequence[int] | None = None,
storage_offset: int | None = None,
memory_format: torch.memory_format | None = None,
dtype: torch.dtype | None = None,
layout: torch.layout | None = torch.strided,
device: torch.device | None = None,
pin_memory: bool = False,
requires_grad: bool = False,
dispatch_sizes_strides_policy: str | None = None,
dispatch_device: bool = False,
dispatch_layout: bool = False,
_extra_dispatch_keys: torch.DispatchKeySet | None = None,
storage_size: int | None = None,
):
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)

subclass = SubclassTensorProxy(
None,
shape=usize,
device=udevice,
dtype=udtype,
requires_grad=urequires_grad,
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
subclass_type=ucls,
)
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
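For reference, the lookaside mirrors the eager semantics of `_make_wrapper_subclass`: eager PyTorch returns a metadata-only shell of the subclass type, while under interpretation the call now yields a `SubclassTensorProxy` carrying the same shape/device/dtype/requires_grad metadata (arguments arrive wrapped and are `unwrap`ed, and the proxy is `wrap`ed with a LOOKASIDE provenance record). A quick eager sanity check, with a hypothetical `ShellTensor` assumed only for this snippet:

import torch

class ShellTensor(torch.Tensor):
    pass  # a metadata-only shell is enough for _make_wrapper_subclass

t = torch.Tensor._make_wrapper_subclass(
    ShellTensor, (2, 3), dtype=torch.float16, requires_grad=True
)
# The shell carries metadata only; these accessors do not dispatch any ops.
print(type(t).__name__, tuple(t.shape), t.dtype, t.requires_grad)
# -> ShellTensor (2, 3) torch.float16 True
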
151 changes: 149 additions & 2 deletions thunder/core/prims.py
@@ -1,18 +1,23 @@
from __future__ import annotations
from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
import operator
import builtins
import math
from types import NoneType
-from typing import Union, Type, Any, List, Dict, Tuple, Optional
+from typing import Union, Type, Any, List, Dict, Tuple, Optional, TYPE_CHECKING
-from collections.abc import Callable
+from collections.abc import Callable, Hashable, Sequence

import torch

from thunder.core.langctxs import LanguageContext, register_langctx, Languages, langctx

if TYPE_CHECKING:
from collections.abc import Iterable
from thunder.core.codeutils import ContextObject

#
# Creates and registers the torch language context
#
@@ -77,6 +82,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
TupleProxy,
AnyProxy,
IntegerProxy,
SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -272,6 +278,8 @@ class PrimIDs(Enum):
COPY_ = auto()
#
SINK = auto()
# Tensor subclass methods
TENSOR_SUBCLASS_CTOR = auto()


class OpTags(Enum):
@@ -3543,7 +3551,10 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


-def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
+def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy:
if isinstance(a, SubclassTensorProxy):
# SubclassTensorProxy(like=...) would not copy some attrs such as `_tensors` while replace does.
return a.replace()
return TensorProxy(like=a)


@@ -4048,3 +4059,139 @@ def sink_meta(*args, **kwargs):

# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
s = SubclassTensorProxy(
name,
subclass_type=cls,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
)
return s


def get_nested_types(collection):
collection = utils.sequencify(collection)
types_set = {type(t) for t in collection}

def check_types(coll):
for item in coll:
types_set.add(type(item))
# Check if the item is a nested collection
if baseutils.is_collection(item):
# If it's a dictionary, check its values
if isinstance(item, dict):
check_types(item.values())
# Recursively check nested collections
else:
check_types(item)

check_types(collection)
return tuple(types_set)


def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(
filter(
lambda t: (
t.__module__ != "builtins"
and t != Number
# note(crcrpar): maybe `thunder.core`?
and not t.__module__.startswith("thunder.")
and not t.__module__.startswith("torch.")
),
types,
)
)


def printer_of_tensor_subclass_ctor(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

baseutils.check(not kwarg_printables, lambda: f"No kwargs are supported but {kwarg_printables = }")

# NOTE(crcrpar): It's not a context but at the moment Tensor subclass is treated as `ContextObject`.
wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
if isinstance(wrapped_cls, torch._C._TensorMeta):
cls = wrapped_cls
else:
cls: torch._C._TensorMeta = wrapped_cls.obj
tensors, non_tensors = arg_printables[-2:]
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(dtypes.to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(devices.to_torch_device(a))
else:
new_non_tensors.append(a)

arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
kwarg_str = ""

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str: str
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"
else:
comment_str = ""

cls_with_module = f"{cls.__name__}"
s = f"{result_str}{cls_with_module}({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

filtered_types = (cls,)
if non_tensors:
types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors])
filtered_types += filter_types(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)
return s


def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
non_tensors = bsym.args[-1]

filtered_types: tuple[Any, ...] = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


tensor_subclass_ctor = make_prim(
PrimIDs.TENSOR_SUBCLASS_CTOR,
"tensor_subclass_ctor",
meta=tensor_subclass_ctor_meta,
python_printer=printer_of_tensor_subclass_ctor,
_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
)
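
Taken together: constructing a wrapper subclass under the jit is now recorded as a `tensor_subclass_ctor` bound symbol, and the custom printer regenerates a plain constructor call in the emitted Python, converting thunder dtypes/devices among the non-tensor arguments back to their torch counterparts and registering the subclass type (plus any user-defined types found in the non-tensor arguments) in the trace's import context. For the hypothetical `ScaleTensor` above, the printed trace line would look roughly like:

t1 = ScaleTensor(t0, 2.0)  # t1: "cpu f32[2, 2]"   (illustrative; proxy names and the type comment are hypothetical)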