Skip to content

Commit

Permalink
tree_flatten supports named tuples (#1446)
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-alshaar7 authored Nov 18, 2024
1 parent 11a32a4 commit a5c523d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
11 changes: 11 additions & 0 deletions thunder/core/baseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ def get_module(name: str) -> Any:
return sys.modules[name]


def is_likely_from_collections_namedtuple(tuple_type):
from collections import namedtuple

# Check if tuple_type code object is coming from namedtuple
return (
hasattr(tuple_type, "__repr__")
and hasattr(tuple_type.__repr__, "__code__")
and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts
)


#
# Functions related to printing and debugging
#
Expand Down
12 changes: 1 addition & 11 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
TracebackType,
)

from thunder.core.baseutils import Singleton, init_colors, extract_callable_name
from thunder.core.baseutils import Singleton, init_colors, extract_callable_name, is_likely_from_collections_namedtuple
from thunder.core.codeutils import Positions


Expand Down Expand Up @@ -2848,16 +2848,6 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
else:
item_wrappers.append(wv)

def is_likely_from_collections_namedtuple(tuple_type):
from collections import namedtuple

# Check if tuple_type code object is coming from namedtuple
return (
hasattr(tuple_type, "__repr__")
and hasattr(tuple_type.__repr__, "__code__")
and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts
)

# Construction of namedtuples may raise
try:
ures = tuple(w.value for w in item_wrappers)
Expand Down
3 changes: 2 additions & 1 deletion thunder/core/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import thunder.core.dtypes as dtypes
import thunder.core.devices as devices
from thunder.core.baseutils import ProxyInterface
from thunder.core.baseutils import ProxyInterface, is_likely_from_collections_namedtuple
from types import FunctionType

OPTREE_NAMESPACE = "thunder"
Expand Down Expand Up @@ -61,6 +61,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
torch.autograd.function.FunctionCtx,
}
and not isinstance(args, (ProxyInterface))
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
):
Expand Down
11 changes: 11 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,17 @@ def test_to_printable_not_collection():
assert inp is out


def test_to_printable_collection():
from collections import namedtuple

MyTuple = namedtuple("MyTuple", ["x", "y"])

inps = (MyTuple("abc", "def"),)
for inp in inps:
out = codeutils.to_printable(None, inp)
assert inp == out


#
# Type promotion tests
#
Expand Down

0 comments on commit a5c523d

Please sign in to comment.