Skip to content

Commit

Permalink
TE - fix propagate metadata for fp8_autocast in from_trace (#1021)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitij12345 authored Aug 26, 2024
1 parent 91ff7b7 commit 61e2f84
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
2 changes: 2 additions & 0 deletions thunder/core/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,8 @@ def from_trace(trace: TraceCtx) -> TraceCtx:
t.name_ctr = trace.name_ctr
t.obj_name_ctr = trace.obj_name_ctr
t.names = trace.names
# This is a detail for enabling transformer_engine's autocast manager.
t._include_te_fp8_autocast = trace._include_te_fp8_autocast

t._siginfo = trace._siginfo
return t
Expand Down
32 changes: 32 additions & 0 deletions thunder/tests/test_transformer_engine_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,35 @@ def foo(x, w):
# https://github.com/NVIDIA/TransformerEngine/issues/990
out.backward(torch.randn_like(out), retain_graph=True)
out.backward(torch.randn_like(out))


@requiresCUDA
def test_te_trace_metadata_propagation():
    """Check that `from_trace` carries `_include_te_fp8_autocast` over to the new trace.

    `_include_te_fp8_autocast` is the trace metadata used to enable wrapping the
    forward trace with `fp8_autocast`. A user transform that rebuilds the trace
    through `from_trace` is inserted here, and the test then checks that the
    final trace still contains a `te_linear` symbol.
    """

    def foo(x, w):
        return torch.nn.functional.linear(x, w)

    tensor_options = dict(device="cuda", requires_grad=True)
    x = torch.randn(16, 16, **tensor_options)
    w = torch.randn(16, 16, **tensor_options)

    class MyNoopTransform(thunder.core.transforms.Transform):
        """Rebuild the computation trace via `from_trace` without changing its symbols."""

        def transform_trace_post_optimization(self, computation_trace, **kwargs):
            # Copy the trace's metadata into a fresh trace, then reuse the very
            # same bound symbols so the computation itself is unchanged.
            rebuilt = thunder.core.trace.from_trace(computation_trace)
            rebuilt.bound_symbols = computation_trace.bound_symbols
            return rebuilt

    jitted = thunder.jit(
        foo,
        executors=[transformer_engine_ex],
        transforms=[MyNoopTransform()],
    )
    # The call is what triggers tracing/compilation; its result is not inspected.
    jitted(x, w)

    final_trace = thunder.last_traces(jitted)[-1]

    # Verify that we have `te_linear` in the trace.
    symbol_names = (bsym.sym.name for bsym in final_trace.bound_symbols)
    assert any(name.startswith("te_linear") for name in symbol_names)

0 comments on commit 61e2f84

Please sign in to comment.