From 61e2f8480b39a6ca64067d846c6110ecfa09bd25 Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Mon, 26 Aug 2024 14:44:35 +0200
Subject: [PATCH] TE - fix propagate metadata for fp8_autocast in `from_trace` (#1021)

---
 thunder/core/trace.py                         |  2 ++
 .../tests/test_transformer_engine_executor.py | 32 +++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/thunder/core/trace.py b/thunder/core/trace.py
index da1fd27b23..be3a287091 100644
--- a/thunder/core/trace.py
+++ b/thunder/core/trace.py
@@ -482,6 +482,8 @@ def from_trace(trace: TraceCtx) -> TraceCtx:
     t.name_ctr = trace.name_ctr
     t.obj_name_ctr = trace.obj_name_ctr
     t.names = trace.names
+    # Propagate the flag that enables transformer_engine's `fp8_autocast` manager to the new trace.
+    t._include_te_fp8_autocast = trace._include_te_fp8_autocast
 
     t._siginfo = trace._siginfo
     return t
diff --git a/thunder/tests/test_transformer_engine_executor.py b/thunder/tests/test_transformer_engine_executor.py
index efcf213209..002b41956c 100644
--- a/thunder/tests/test_transformer_engine_executor.py
+++ b/thunder/tests/test_transformer_engine_executor.py
@@ -209,3 +209,35 @@ def foo(x, w):
     # https://github.com/NVIDIA/TransformerEngine/issues/990
     out.backward(torch.randn_like(out), retain_graph=True)
     out.backward(torch.randn_like(out))
+
+
+@requiresCUDA
+def test_te_trace_metadata_propagation():
+    # Check that `from_trace` copies the `_include_te_fp8_autocast` metadata onto the trace it returns.
+    # `_include_te_fp8_autocast` is used to enable wrapping the forward trace with `fp8_autocast`.
+    def foo(x, w):
+        return torch.nn.functional.linear(x, w)
+
+    device = "cuda"
+    x = torch.randn(16, 16, device=device, requires_grad=True)
+    w = torch.randn(16, 16, device=device, requires_grad=True)
+
+    class MyNoopTransform(thunder.core.transforms.Transform):
+        def transform_trace_post_optimization(self, computation_trace, **kwargs):
+            new_trace = thunder.core.trace.from_trace(computation_trace)
+            new_trace.bound_symbols = computation_trace.bound_symbols
+            return new_trace
+
+    cfunc = thunder.jit(
+        foo,
+        executors=[transformer_engine_ex],
+        transforms=[
+            MyNoopTransform(),
+        ],
+    )
+    out = cfunc(x, w)
+
+    fwd_traces = thunder.last_traces(cfunc)
+
+    # Verify that the last trace contains a symbol whose name starts with `te_linear`.
+    assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols)