Skip to content

Commit

Permalink
always use dict.__getitem__ in dict/OrderedDict iteration (#676)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jun 28, 2024
1 parent 27a573d commit 72e033a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,7 +2345,7 @@ def __iter__(self):
return self

def __next__(self):
return self._mapping[next(self._key_iter)]
return dict.__getitem__(self._mapping, next(self._key_iter))


class MappingValuesWrapper:
Expand All @@ -2369,7 +2369,7 @@ def __iter__(self):

def __next__(self):
k = next(self._key_iter)
return k, self._mapping[k]
return k, dict.__getitem__(self._mapping, k)


class MappingItemsWrapper:
Expand Down
16 changes: 16 additions & 0 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3214,3 +3214,19 @@ def test_litgpt(jit):
result = jfn(*args, **kwargs)

assert_close(result, fn(*args, **kwargs))


def test_transformer_model_output():
pytest.importorskip("transformers")
from transformers.utils.generic import ModelOutput

def fn(x):
mo = ModelOutput(foo=x)
return mo["foo"]

x = torch.randn(3)
expected = fn(x)

actual = thunder.jit(fn)(x)

assert expected is actual

0 comments on commit 72e033a

Please sign in to comment.