From a0a61518a70d80ab5ce4d1b6e44bcae159d81bb0 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Thu, 7 Nov 2024 22:04:11 +0900
Subject: [PATCH 1/4] remove print (#1408)

---
 thunder/core/transforms.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 116e4094ed..d59fad498b 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2530,7 +2530,6 @@ def vjp_impl_const(symbol, *args, **kwargs):
             # It could be a torch.dropout with 0.0 probability, so we skip it
             if symbol.sym.id == "torch.nn.functional.dropout":
                 return None
-            print(f"VJP for {symbol} is not implemented")
             raise NotImplementedError(f"VJP for {symbol.sym.id} is not implemented")
 
     def _vjp_impl(*args, **kwargs):

From db84e15eaace0aaa6b55385d5389352d5a35a38f Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Mon, 11 Nov 2024 17:10:33 +0900
Subject: [PATCH 2/4] [docstring] `compile` -> `jit` in example (#1409)

---
 thunder/core/transforms.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index d59fad498b..2bf91372fe 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2969,12 +2969,12 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa
 
     Example:
         >>> import torch
-        >>> from thunder import compile, last_traces
+        >>> from thunder import jit, last_traces
        >>> from thunder.core.transforms import forward_and_backward_from_trace
         >>> def f(x):
         ...     return torch.sin(x)
         >>> x = torch.tensor(3.0)
-        >>> cf = compile(f)
+        >>> cf = jit(f)
         >>> out = cf(x)
         >>> trace = last_traces(cf)[0]
         >>> forward_and_backward_from_trace(trace)

From a13a09991e482c70e069d42efbb6a013430914c8 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Mon, 11 Nov 2024 18:01:33 +0900
Subject: [PATCH 3/4] Check `_interpret_call` output in the lookaside of `torch.autograd.Function`. (#1411)

---
 thunder/core/jit_ext.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 7b1cf3eb87..0c6b21e55e 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -672,6 +672,8 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
     trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace(
         custom_forward, True, wrapped_ctx, *args, **kwargs
     )
+    if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return trace_of_fwd
 
     # Forward.
     unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)

From f7b2e1590fd0d98985131c54653b2a95ba4729ed Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 12 Nov 2024 04:47:38 +0900
Subject: [PATCH 4/4] [jit_ext] Inserts unpack bsyms to traces generated inside `torch.autograd.Function` (#1414)

---
 thunder/core/jit_ext.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 0c6b21e55e..c7ec43ad45 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -682,6 +682,11 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
         unwrapped_custom_forward_args,
     )
     trace_of_fwd.args = unwrapped_custom_forward_args
+    unpack_bsyms = [
+        prims.unpack_trivial.bind(a, name=a.name, output=a)
+        for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
+    ]
+    trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols
 
     @wraps(trace_of_fwd.python_callable())
     def core_of_forward(*args, **kwargs):
@@ -723,6 +728,11 @@ def core_of_forward(*args, **kwargs):
         ctx_proxy.saved_tensors + grads,
     )
     trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads)
+    bwd_unpack_bsyms = [
+        prims.unpack_trivial.bind(a, name=a.name, output=a)
+        for a in filter(lambda a: isinstance(a, Proxy), trace_of_backward.args)
+    ]
+    trace_of_backward.bound_symbols = bwd_unpack_bsyms + trace_of_backward.bound_symbols
 
     bwd_trace_impl = TraceCtx()
     bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)
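
Patches 3 and 4 change the path Thunder takes when tracing a user-defined torch.autograd.Function through thunder.jit. For reference, here is a minimal, hypothetical sketch of user code that exercises that path; the Sin class, the input tensor, and the variable names are illustrative only and do not appear in the patches.

import torch
import thunder

# A user-defined autograd function of the kind intercepted by the
# torch.autograd.Function.apply lookaside in thunder/core/jit_ext.py:
# forward saves a tensor on ctx, backward reads it back.
class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)


def f(x):
    return Sin.apply(x)


jf = thunder.jit(f)
x = torch.randn(3, requires_grad=True)
out = jf(x)            # forward is traced via the apply lookaside
out.sum().backward()   # backward uses the trace built from Sin.backward
print(x.grad)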