From b99603d03a4e96ee3ce332a6404d9219cde2bf44 Mon Sep 17 00:00:00 2001 From: Ali Alshaarawy Date: Mon, 18 Nov 2024 11:19:05 +0000 Subject: [PATCH 1/4] tree_flatten supports subclasses of tuple (named tuples) --- thunder/core/pytree.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 6a3d3d3a76..25d5161871 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -61,6 +61,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE): torch.autograd.function.FunctionCtx, } and not isinstance(args, (ProxyInterface)) + and not isinstance(args, tuple) and not dataclasses.is_dataclass(args) and not type(args).__module__.startswith("torch.return_types") ): From ba33789b1312532868c792f983ffd66bcf09af61 Mon Sep 17 00:00:00 2001 From: Ali Alshaarawy Date: Mon, 18 Nov 2024 11:41:06 +0000 Subject: [PATCH 2/4] use existing namedtuple check util --- thunder/core/baseutils.py | 10 ++++++++++ thunder/core/interpreter.py | 12 +----------- thunder/core/pytree.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/thunder/core/baseutils.py b/thunder/core/baseutils.py index 42c63aafd6..005a8f3fd4 100644 --- a/thunder/core/baseutils.py +++ b/thunder/core/baseutils.py @@ -209,6 +209,16 @@ def sequencify(x: Any) -> Sequence: def get_module(name: str) -> Any: return sys.modules[name] +def is_likely_from_collections_namedtuple(tuple_type): + from collections import namedtuple + + # Check if tuple_type code object is coming from namedtuple + return ( + hasattr(tuple_type, "__repr__") + and hasattr(tuple_type.__repr__, "__code__") + and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts + ) + # # Functions related to printing and debugging diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index bc40ce12b7..992d1e27ec 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -42,7 +42,7 @@ TracebackType, ) -from thunder.core.baseutils import Singleton, init_colors, extract_callable_name +from thunder.core.baseutils import Singleton, init_colors, extract_callable_name, is_likely_from_collections_namedtuple from thunder.core.codeutils import Positions @@ -2848,16 +2848,6 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /): else: item_wrappers.append(wv) - def is_likely_from_collections_namedtuple(tuple_type): - from collections import namedtuple - - # Check if tuple_type code object is coming from namedtuple - return ( - hasattr(tuple_type, "__repr__") - and hasattr(tuple_type.__repr__, "__code__") - and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts - ) - # Construction of namedtuples may raise try: ures = tuple(w.value for w in item_wrappers) diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 25d5161871..8c92a38555 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -6,7 +6,7 @@ import torch import thunder.core.dtypes as dtypes import thunder.core.devices as devices -from thunder.core.baseutils import ProxyInterface +from thunder.core.baseutils import ProxyInterface, is_likely_from_collections_namedtuple from types import FunctionType OPTREE_NAMESPACE = "thunder" @@ -61,7 +61,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE): torch.autograd.function.FunctionCtx, } and not isinstance(args, (ProxyInterface)) - and not isinstance(args, tuple) + and not is_likely_from_collections_namedtuple(args) and not dataclasses.is_dataclass(args) and not type(args).__module__.startswith("torch.return_types") ): From bd825baf368de4eb3d5f845060a21cd0b0cac5aa Mon Sep 17 00:00:00 2001 From: Ali Alshaarawy Date: Mon, 18 Nov 2024 12:09:08 +0000 Subject: [PATCH 3/4] add test --- thunder/tests/test_core.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 29c491be08..7698219e78 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -662,6 +662,15 @@ def test_to_printable_not_collection(): assert inp is out +def test_to_printable_collection(): + from collections import namedtuple + MyTuple = namedtuple('MyTuple', ['x', 'y']) + + inps = (MyTuple("abc", "def"),) + for inp in inps: + out = codeutils.to_printable(None, inp) + assert inp == out + # # Type promotion tests # From 5ea6748a5de6d43add7f955c95dc9b6bdd797f6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 12:38:14 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/core/baseutils.py | 1 + thunder/tests/test_core.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/thunder/core/baseutils.py b/thunder/core/baseutils.py index 005a8f3fd4..08cd4fd159 100644 --- a/thunder/core/baseutils.py +++ b/thunder/core/baseutils.py @@ -209,6 +209,7 @@ def sequencify(x: Any) -> Sequence: def get_module(name: str) -> Any: return sys.modules[name] + def is_likely_from_collections_namedtuple(tuple_type): from collections import namedtuple diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 7698219e78..9acde1615b 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -664,13 +664,15 @@ def test_to_printable_not_collection(): def test_to_printable_collection(): from collections import namedtuple - MyTuple = namedtuple('MyTuple', ['x', 'y']) + + MyTuple = namedtuple("MyTuple", ["x", "y"]) inps = (MyTuple("abc", "def"),) for inp in inps: out = codeutils.to_printable(None, inp) assert inp == out + # # Type promotion tests #