diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index cb0491be0b..e80cdff692 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -1342,7 +1342,12 @@ def is_jitting_with_raise(): # Guard against opaque functions which interrupt jitting. if (ctx := get_interpretercompilectx_if_available()) is not None: - raise InterpreterError(f"Lookaside was not triggered, but there is an active compile context: {ctx}") + # nested try to delete ctx from locals + try: + raise InterpreterError(f"Lookaside was not triggered, but there is an active compile context: {ctx}") + except InterpreterError: + del ctx + raise return False @@ -1495,8 +1500,9 @@ def set_builtins(globals, builtins_dict): except Exception as e: # We need to cheat a bit to get a Python frame here... python_frame = frame.get_or_make_python_frame() - tb = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno) - raise e.with_traceback(tb) + e.__traceback__ = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno) + del e + raise # re-raises e if mode == "eval": return res @@ -6254,14 +6260,24 @@ def thunder_interpreter_generator(): res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value) except Exception as e: msg = f"Encountered exception {type(e).__name__}: {e}" - raise InterpreterError(msg) from e + # nested try ... raise to delete e from locals + try: + raise InterpreterError(msg) from e + except InterpreterError: + del e + raise if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED: e = runtimectx.curexc assert isinstance(e, BaseException) runtimectx.curexc = None if isinstance(e, StopIteration): return unwrap(e.value) - raise e + # nested try except to delete e from locals + try: + raise e + except BaseException: + del e + raise if status == INTERPRETER_SIGNALS.RETURN_VALUE: return # TODO: should this return res? assert status == INTERPRETER_SIGNALS.YIELD_VALUE @@ -6284,14 +6300,24 @@ async def thunder_interpreter_async_generator(): res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value) except Exception as e: msg = f"Encountered exception {type(e).__name__}: {e}" - raise InterpreterError(msg) from e + # nested try ... raise to delete e from locals + try: + raise InterpreterError(msg) from e + except InterpreterError: + del e + raise if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED: e = runtimectx.curexc assert isinstance(e, BaseException) runtimectx.curexc = None if isinstance(e, StopIteration): return - raise e + # nested try except to delete e from locals + try: + raise e + except BaseException: + del e + raise if status == INTERPRETER_SIGNALS.RETURN_VALUE: return # TODO: should this return res? assert status == INTERPRETER_SIGNALS.YIELD_VALUE @@ -6314,14 +6340,24 @@ async def thunder_interpreter_coroutine(): res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value) except Exception as e: msg = f"Encountered exception {type(e).__name__}: {e}" - raise InterpreterError(msg) from e + # nested try ... raise to delete e from locals + try: + raise InterpreterError(msg) from e + except InterpreterError: + del e + raise if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED: e = runtimectx.curexc assert isinstance(e, BaseException) runtimectx.curexc = None if isinstance(e, StopIteration): return unwrap(e.value) - raise e + # nested try except to delete e from locals + try: + raise e + except BaseException: + del e + raise if status == INTERPRETER_SIGNALS.RETURN_VALUE: return unwrap(res) assert status == INTERPRETER_SIGNALS.YIELD_VALUE @@ -7134,7 +7170,12 @@ def fn_2(args, kwargs): msg = ( f"Encountered exception {type(e).__name__}: {e} while tracing {fn}:{os.linesep}" f"{traceback_str}" ) - raise InterpreterError(msg) from e + # nested try ... raise to delete e from locals + try: + raise InterpreterError(msg) from e + except InterpreterError: + del e + raise # NOTE: Wrapped functions are valid to assign new attributes to. fn_._last_interpreter_log = runtimectx.interp_log # type: ignore @@ -7143,7 +7184,12 @@ def fn_2(args, kwargs): e = runtimectx.curexc assert isinstance(e, BaseException), e runtimectx.curexc = None - raise e + # The below is "raise e" but deleting e from the scope + try: + raise e + except Exception: + del e + raise return interpretation_result diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index d0db7d8709..fbb3a32ee1 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -827,6 +827,29 @@ def foo(): assert weak_x() is None +def test_uncaught_exception_no_leak(): + + class Identity(torch.nn.Module): + def forward(self, x): + raise RuntimeError("FOOBAR") + return x + + def main(): + with torch.device("cpu"): + model = thunder.jit(Identity()) + x = torch.randn(16, 16) + + try: + model(x) + except: + pass + return weakref.ref(x) + + weak_x = main() + + assert weak_x() is None + + def test_walrus_operator(jit): def foo(a, b): c = (a := b)