diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index b6563bbe13..aea2baf52e 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -2101,8 +2101,9 @@ def __new__(cls, iterable=()): return wrap_const(cls.value()) def __init__(self, iterable=()): - SequenceWrapperMethods.__init__(self, iterable) - return wrap_const(None) + # We need to propagate the return value because it could be JIT_SIGNALS + res = SequenceWrapperMethods.__init__(self, iterable) + return res def __setitem__(self, key, value, /): self.track_items() diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 3f0dfb0a94..ab1b45289c 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -3001,6 +3001,21 @@ def foo(): assert foo() == jit(foo)() +def test_exception_in_list_init(jit): + def foo(l): + for i in l: + yield i + + def bar(): + return list(foo(2)) + + with pytest.raises(TypeError): + bar() + + with pytest.raises(TypeError): + jit(bar)() + + # # Network tests #