diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 12a5bac4b2..f8fd555a43 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -2600,7 +2600,8 @@ def _collections_namedtuple_lookaside( *, rename: bool = False, defaults: None | Iterable[Any] = None, - module: None | str = None): + module: None | str = None, +): # Type checks { assert wrapped_isinstance(typename, str) assert wrapped_isinstance(field_names, Iterable) @@ -2628,7 +2629,7 @@ def _collections_namedtuple_lookaside( # Run opaque namedtuple { @interpreter_needs_wrap - def create_namedtuple(typename: str, field_names: str, **kwargs): + def create_namedtuple(typename: str, field_names: str, **kwargs): namedtuple_type = collections.namedtuple(typename, field_names, **kwargs) return namedtuple_type @@ -2727,7 +2728,7 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /): try: ures = tuple(w.value for w in item_wrappers) # Named tuples expect varargs, not iterables at new/init - if hasattr(new_tuple_type, 'is_namedtuple'): + if hasattr(new_tuple_type, "is_namedtuple"): ures = new_tuple_type(*ures) build_inst = PseudoInst.BUILD_NAMEDTUPLE else: diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 76613dd981..fafe82f770 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -1035,7 +1035,7 @@ def test_namedtuple_lookaside(jit): from collections import namedtuple typename = "MyNamedTuple" - field_names = ('a', 'b', 'c') + field_names = ("a", "b", "c") # Test returnign just the type { def f(): @@ -1050,6 +1050,7 @@ def f(): # Check module name import inspect + assert jtype.__module__ == inspect.currentframe().f_globals["__name__"] # } @@ -1064,7 +1065,7 @@ def f(a, b, c): return obj[0] jf = jit(f) - + assert f(a, b, c) is a assert jf(a, b, c) is a @@ -1074,7 +1075,7 @@ def f(a, b, c): return obj.a jf = jit(f) - + assert f(a, b, c) is a assert jf(a, b, c) is a # }