Skip to content

Commit

Permalink
remove subclass_type_to_attr_names
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 7, 2024
1 parent 03e1219 commit bed021e
Showing 1 changed file with 7 additions and 36 deletions.
43 changes: 7 additions & 36 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,21 +77,17 @@ def _materialize_tensor_proxy(t: TensorProxy, fake_tensor_mode: FakeTensorMode |
def _make_fake_subclass_tensor_from_subclass_tensor_proxy(
tensor_proxy: SubclassTensorProxy,
fake_tensor_mode: FakeTensorMode,
subclass_to_attr_names: dict[_TensorMeta, tuple[list[str], list[str]]],
) -> torch.Tensor:
utils.check(
(subclass_type := getattr(tensor_proxy, SubclassTensorProxy.SUBCLASS_TYPE_ATTR, None)) is not None,
lambda: f"{tensor_proxy} does not have `{SubclassTensorProxy.SUBCLASS_TYPE_ATTR}`",
)
utils.check(
subclass_type in subclass_to_attr_names,
lambda: f"{tensor_proxy}'s `{subclass_type=}` has never been observed",
)
utils.check(
tensor_proxy._tensors,
lambda: f"{tensor_proxy} has an empty `{tensor_proxy._tensors=}`",
)
tensor_attr_names, non_tensor_attr_names = subclass_to_attr_names[subclass_type]
tensor_attr_names = tensor_proxy._tensor_attr_names
non_tensor_attr_names = tensor_proxy._non_tensor_attr_names
inner_tensors = dict(
zip(
tensor_attr_names,
Expand All @@ -112,20 +108,18 @@ def _make_fake_subclass_tensor_from_subclass_tensor_proxy(
def materialize_tensor_proxy(
t: TensorProxy | SubclassTensorProxy,
fake_tensor_mode: FakeTensorMode,
subclass_to_attr_names: dict[_TensorMeta, tuple[list[str], list[str]]],
) -> torch.Tensor:
if isinstance(t, SubclassTensorProxy):
return _make_fake_subclass_tensor_from_subclass_tensor_proxy(t, fake_tensor_mode, subclass_to_attr_names)
return _make_fake_subclass_tensor_from_subclass_tensor_proxy(t, fake_tensor_mode)
return _materialize_tensor_proxy(t, fake_tensor_mode)


def maybe_materialize_tensor(
t: ProxyInterface,
fake_tensor_mode: FakeTensorMode,
subclass_to_attr_names: dict[_TensorMeta, tuple[list[str], list[str]]],
) -> ProxyInterface | torch.Tensor:
if isinstance(t, (TensorProxy, SubclassTensorProxy)):
return materialize_tensor_proxy(t, fake_tensor_mode, subclass_to_attr_names)
return materialize_tensor_proxy(t, fake_tensor_mode)
if isinstance(t, (Number, str)):
return t
return t.value
Expand Down Expand Up @@ -200,17 +194,14 @@ class DesugarTensorSubclass:
flat_trace_args: Sequence[ProxyInterface] = field(init=False, default=None)
flat_trace_args_spec: Any = field(init=False, default=None)
requires_desugarring: bool = field(init=False, default=False)
subclass_type_to_attr_names: dict[_TensorMeta, tuple[list[str], list[str]]] = field(
init=False, default_factory=dict
)
subclass_proxy_to_flatten: set[Variable] = field(init=False, default_factory=set)

def __post_init__(self) -> None:
self.flat_trace_args, self.flat_trace_args_spec = tree_flatten(
(self.computation_trace.args, self.computation_trace.kwargs)
)
for arg in self.flat_trace_args:
self.maybe_update_subclass_type_dict(arg)
# self.maybe_update_subclass_type_dict(arg)
if isinstance(arg, SubclassTensorProxy):
self.subclass_proxy_to_flatten.add(variableify(arg))

Expand All @@ -230,30 +221,11 @@ def __post_init__(self) -> None:
self.computation_trace,
)

def maybe_update_subclass_type_dict(self, proxy_arg: ProxyInterface) -> None:
if not isinstance(proxy_arg, SubclassTensorProxy):
return
subclass_type = getattr(proxy_arg, SubclassTensorProxy.SUBCLASS_TYPE_ATTR)
if subclass_type in self.subclass_type_to_attr_names and not hasattr(subclass_type, "_tensor_attr_names"):
tensor_attr_names, non_tensor_attr_names = self.subclass_type_to_attr_names[subclass_type]
for name, value in zip(tensor_attr_names, subclass_type._tensors):
setattr(proxy_arg, name, value)
for name, value in zip(non_tensor_attr_names, subclass_type._non_tensors):
setattr(proxy_arg, name, value)
elif subclass_type not in self.subclass_type_to_attr_names:
tensor_attr_names = proxy_arg._tensor_attr_names
non_tensor_attr_names = proxy_arg._non_tensor_attr_names
self.subclass_type_to_attr_names[subclass_type] = tensor_attr_names, non_tensor_attr_names
else:
utils.check(False, lambda: f"{proxy_arg} hasn't gotten attribute names -- {subclass_type}")

def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
subclass_type = p._subclass_type
return self.subclass_type_to_attr_names[subclass_type][0]
return p._tensor_attr_names

def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]:
subclass_type = p._subclass_type
return self.subclass_type_to_attr_names[subclass_type][1]
return p._non_tensor_attr_names

def translate_fx_graph_into_bsym(
self,
Expand Down Expand Up @@ -376,7 +348,6 @@ def ctor(tensors, metadata):
lambda t: maybe_materialize_tensor(
t,
self.fake_tensor_mode,
self.subclass_type_to_attr_names,
),
trace.args,
)
Expand Down

0 comments on commit bed021e

Please sign in to comment.