Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 25, 2024
1 parent 1752d10 commit 18dedec
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 4 additions & 3 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,7 +2600,8 @@ def _collections_namedtuple_lookaside(
*,
rename: bool = False,
defaults: None | Iterable[Any] = None,
module: None | str = None):
module: None | str = None,
):
# Type checks {
assert wrapped_isinstance(typename, str)
assert wrapped_isinstance(field_names, Iterable)
Expand Down Expand Up @@ -2628,7 +2629,7 @@ def _collections_namedtuple_lookaside(

# Run opaque namedtuple {
@interpreter_needs_wrap
def create_namedtuple(typename: str, field_names: str, **kwargs):
def create_namedtuple(typename: str, field_names: str, **kwargs):
namedtuple_type = collections.namedtuple(typename, field_names, **kwargs)
return namedtuple_type

Expand Down Expand Up @@ -2727,7 +2728,7 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
try:
ures = tuple(w.value for w in item_wrappers)
# Named tuples expect varargs, not iterables at new/init
if hasattr(new_tuple_type, 'is_namedtuple'):
if hasattr(new_tuple_type, "is_namedtuple"):
ures = new_tuple_type(*ures)
build_inst = PseudoInst.BUILD_NAMEDTUPLE
else:
Expand Down
7 changes: 4 additions & 3 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def test_namedtuple_lookaside(jit):
from collections import namedtuple

typename = "MyNamedTuple"
field_names = ('a', 'b', 'c')
field_names = ("a", "b", "c")

# Test returnign just the type {
def f():
Expand All @@ -1050,6 +1050,7 @@ def f():

# Check module name
import inspect

assert jtype.__module__ == inspect.currentframe().f_globals["__name__"]
# }

Expand All @@ -1064,7 +1065,7 @@ def f(a, b, c):
return obj[0]

jf = jit(f)

assert f(a, b, c) is a
assert jf(a, b, c) is a

Expand All @@ -1074,7 +1075,7 @@ def f(a, b, c):
return obj.a

jf = jit(f)

assert f(a, b, c) is a
assert jf(a, b, c) is a
# }
Expand Down

0 comments on commit 18dedec

Please sign in to comment.