Skip to content

Commit

Permalink
fix leak with uncaught exceptions (#1193)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Sep 24, 2024
1 parent 4695654 commit a92ed64
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 11 deletions.
68 changes: 57 additions & 11 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,12 @@ def is_jitting_with_raise():

# Guard against opaque functions which interrupt jitting.
if (ctx := get_interpretercompilectx_if_available()) is not None:
raise InterpreterError(f"Lookaside was not triggered, but there is an active compile context: {ctx}")
# nested try to delete ctx from locals
try:
raise InterpreterError(f"Lookaside was not triggered, but there is an active compile context: {ctx}")
except InterpreterError:
del ctx
raise

return False

Expand Down Expand Up @@ -1495,8 +1500,9 @@ def set_builtins(globals, builtins_dict):
except Exception as e:
# We need to cheat a bit to get a Python frame here...
python_frame = frame.get_or_make_python_frame()
tb = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno)
raise e.with_traceback(tb)
e.__traceback__ = TracebackType(e.__traceback__, python_frame, python_frame.f_lasti, python_frame.f_lineno)
del e
raise # re-raises e

if mode == "eval":
return res
Expand Down Expand Up @@ -6254,14 +6260,24 @@ def thunder_interpreter_generator():
res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value)
except Exception as e:
msg = f"Encountered exception {type(e).__name__}: {e}"
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise
if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
e = runtimectx.curexc
assert isinstance(e, BaseException)
runtimectx.curexc = None
if isinstance(e, StopIteration):
return unwrap(e.value)
raise e
# nested try except to delete e from locals
try:
raise e
except BaseException:
del e
raise
if status == INTERPRETER_SIGNALS.RETURN_VALUE:
return # TODO: should this return res?
assert status == INTERPRETER_SIGNALS.YIELD_VALUE
Expand All @@ -6284,14 +6300,24 @@ async def thunder_interpreter_async_generator():
res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value)
except Exception as e:
msg = f"Encountered exception {type(e).__name__}: {e}"
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise
if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
e = runtimectx.curexc
assert isinstance(e, BaseException)
runtimectx.curexc = None
if isinstance(e, StopIteration):
return
raise e
# nested try except to delete e from locals
try:
raise e
except BaseException:
del e
raise
if status == INTERPRETER_SIGNALS.RETURN_VALUE:
return # TODO: should this return res?
assert status == INTERPRETER_SIGNALS.YIELD_VALUE
Expand All @@ -6314,14 +6340,24 @@ async def thunder_interpreter_coroutine():
res, status = _run_frame(frame, compilectx, runtimectx, send_value=send_value)
except Exception as e:
msg = f"Encountered exception {type(e).__name__}: {e}"
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise
if status is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
e = runtimectx.curexc
assert isinstance(e, BaseException)
runtimectx.curexc = None
if isinstance(e, StopIteration):
return unwrap(e.value)
raise e
# nested try except to delete e from locals
try:
raise e
except BaseException:
del e
raise
if status == INTERPRETER_SIGNALS.RETURN_VALUE:
return unwrap(res)
assert status == INTERPRETER_SIGNALS.YIELD_VALUE
Expand Down Expand Up @@ -7134,7 +7170,12 @@ def fn_2(args, kwargs):
msg = (
f"Encountered exception {type(e).__name__}: {e} while tracing {fn}:{os.linesep}" f"{traceback_str}"
)
raise InterpreterError(msg) from e
# nested try ... raise to delete e from locals
try:
raise InterpreterError(msg) from e
except InterpreterError:
del e
raise

# NOTE: Wrapped functions are valid to assign new attributes to.
fn_._last_interpreter_log = runtimectx.interp_log # type: ignore
Expand All @@ -7143,7 +7184,12 @@ def fn_2(args, kwargs):
e = runtimectx.curexc
assert isinstance(e, BaseException), e
runtimectx.curexc = None
raise e
# The below is "raise e" but deleting e from the scope
try:
raise e
except Exception:
del e
raise

return interpretation_result

Expand Down
23 changes: 23 additions & 0 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,29 @@ def foo():
assert weak_x() is None


def test_uncaught_exception_no_leak():

class Identity(torch.nn.Module):
def forward(self, x):
raise RuntimeError("FOOBAR")
return x

def main():
with torch.device("cpu"):
model = thunder.jit(Identity())
x = torch.randn(16, 16)

try:
model(x)
except:
pass
return weakref.ref(x)

weak_x = main()

assert weak_x() is None


def test_walrus_operator(jit):
def foo(a, b):
c = (a := b)
Expand Down

0 comments on commit a92ed64

Please sign in to comment.