Fix TransformerEngine and Activation checkpointing interaction for thunderFX #1473

Draft · wants to merge 5 commits into base: main
Changes from all commits
10 changes: 2 additions & 8 deletions thunder/core/trace.py
@@ -126,10 +126,6 @@ def __init__(self, fn: None | Callable = None, *, prologue: TraceCtx | None = None

        self._any_future_tensors = False

-        # This is a detail for enabling transformer_engine's autocast manager.
-        # We only want the forward function to be called with ctx manager.
-        self._include_te_fp8_autocast = False
-
    @property
    def tags(self):
        return self._tags
@@ -422,9 +418,9 @@ def keyfn(class_or_module: type | ModuleType) -> str:
        # NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating
        # the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that
        # it is in grad enabled part).
-        from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string
+        from thunder.executors.transformer_engineex import _is_te_linear_fwd_present, _get_te_wrapper_string

-        if self._include_te_fp8_autocast and _is_te_linear_enabled(import_ctx, object_ctx):
+        if _is_te_linear_fwd_present(self.bound_symbols):
            program.append(_get_te_wrapper_string())

        # Disable gradients since Thunder takes care of this (for when calling torch operations)
@@ -520,8 +516,6 @@ def from_trace(trace: TraceCtx) -> TraceCtx:
    t.name_ctr = trace.name_ctr
    t.obj_name_ctr = trace.obj_name_ctr
    t.names = trace.names
-    # This is a detail for enabling transformer_engine's autocast manager.
-    t._include_te_fp8_autocast = trace._include_te_fp8_autocast
    t._tags = trace._tags.copy()

    t._siginfo = trace._siginfo
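With this change, whether the emitted computation is wrapped in `fp8_autocast` is decided by inspecting the trace's own bound symbols (looking for a `te_linear_*` forward symbol) instead of the `_include_te_fp8_autocast` flag, which presumably lets traces derived by transforms such as activation checkpointing keep the wrapper. Below is a rough, hypothetical sketch of the shape of the generated program when such a symbol is present, and of why the decorator order in the NOTE matters; the function and argument names and the recipe placeholder are illustrative, not Thunder's literal output.

```python
# Hypothetical sketch of an emitted computation trace containing a TE forward symbol.
# `fp8_autocast` is emitted as the outer decorator, so by the time it exits and updates
# the FP8 scales/amaxes, `torch.no_grad()` has already been exited and grads are enabled again.
import torch
import transformer_engine.pytorch as transformer_engine  # assumed binding for the "transformer_engine" import-ctx key

te_fp8_recipe = None  # placeholder; the real recipe comes from the `te_fp8_recipe` compile option


@transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe)
@torch.no_grad()
def computation(a, w, b):
    # te_linear_0 stands for the stateful forward symbol created by _create_fp8_linear_bound_symbol
    t0 = te_linear_0(a, w, b)  # noqa: F821 -- provided by the trace's object ctx at runtime
    return t0
```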
36 changes: 28 additions & 8 deletions thunder/executors/transformer_engineex.py
@@ -24,6 +24,7 @@
from thunder.extend import OperatorExecutor, register_executor
from thunder.core.compile_data import get_compile_option, get_compile_data
from thunder.distributed import FSDPType
+from thunder import torch as ltorch
from thunder.executors.utils import Context, set_saved_tensors


@@ -415,6 +416,7 @@ def _te_functional_linear_backward_meta(

IMPORT_CTX_TE_KEY = "transformer_engine"
FP8_RECIPE_KEY = "te_fp8_recipe"
+TE_FWD_SYMBOL_PREFIX = "te_linear_"


# Creates a new stateful operator for each invocation of `linear`.
@@ -423,7 +425,7 @@ def _create_fp8_linear_bound_symbol(
) -> tuple[torch.Tensor, AnyProxy | None]:
    linear_fn = partial(TELinear(w.shape[1], w.shape[0]), is_grad_enabled=is_grad_enabled)
    global LINEAR_CALLS_COUNTER
-    name = f"te_linear_{LINEAR_CALLS_COUNTER}"
+    name = f"{TE_FWD_SYMBOL_PREFIX}{LINEAR_CALLS_COUNTER}"

    desc = "transformer_engine_ex: Optional fp8_recipe for `fp8_autocast` context manager."
    if (fp8_recipe := get_compile_option(FP8_RECIPE_KEY, desc)) is None:
@@ -438,7 +440,12 @@ def bind_postprocess(bsym: BoundSymbol) -> None:

    meta_fn = make_te_linear_meta(is_grad_enabled=is_grad_enabled)
    sym = Symbol(
-        name=name, meta=meta_fn, is_prim=True, executor=transformer_engine_ex, _bind_postprocess=bind_postprocess
+        name=name,
+        id=name,
+        meta=meta_fn,
+        is_prim=True,
+        executor=transformer_engine_ex,
+        _bind_postprocess=bind_postprocess,
    )
    bsym = sym.bind(a, w, b, output=meta_fn(a, w, b))

@@ -515,20 +522,33 @@ def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy:
    return out


-# Registers the implementation for torch.nn.functional.linear
+# Registers the implementation for both prims.linear and torch.nn.functional.linear
transformer_engine_ex.register_implementation(
    prims.linear,
    checker=_linear_checker,
    execution_transform=_linear_transform,
    grad_transform=_linear_grad,
)

+transformer_engine_ex.register_implementation(
+    ltorch.linear,
+    checker=_linear_checker,
+    execution_transform=_linear_transform,
+    grad_transform=_linear_grad,
+)

-def _is_te_linear_enabled(import_ctx, object_ctx):
-    # These keys are present in `import_ctx` and `object_ctx` only if
-    # we actually replaced a linear call with a new TE operator.
-    is_te_exec_enabled = IMPORT_CTX_TE_KEY in import_ctx and FP8_RECIPE_KEY in object_ctx
-    return is_te_exec_enabled

+def _is_te_linear_fwd_present(bsyms):
+    """
+    Returns True if any bound symbol is a forward TE linear symbol.
+    Useful for determining whether the trace should be wrapped with `fp8_autocast`.
+    """
+    for bsym in bsyms:
+        sym_id = bsym.sym.id
+        if isinstance(sym_id, str) and sym_id.startswith(TE_FWD_SYMBOL_PREFIX):
+            return True
+
+    return False


TE_CTX_STR = f"@{IMPORT_CTX_TE_KEY}.fp8_autocast(fp8_recipe={FP8_RECIPE_KEY})"
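The `id=name` argument added to `Symbol` above appears to be what makes `bsym.sym.id` a plain string that `_is_te_linear_fwd_present` can prefix-match, while prim symbols typically carry non-string (enum-like) ids and are skipped. A tiny self-contained sketch of that detection logic, using stand-in classes rather than Thunder's real `Symbol`/`BoundSymbol`:

```python
# Minimal stand-in sketch (not Thunder's actual classes) of the prefix-based detection
# that replaces the import_ctx/object_ctx key check.
from dataclasses import dataclass
from enum import Enum

TE_FWD_SYMBOL_PREFIX = "te_linear_"


class PrimIDs(Enum):  # stand-in for enum-valued prim ids
    LINEAR = "linear"


@dataclass
class FakeSymbol:
    id: object


@dataclass
class FakeBoundSymbol:
    sym: FakeSymbol


def is_te_linear_fwd_present(bsyms) -> bool:
    # A TE forward symbol is recognized purely by its string id prefix, so any trace
    # derived from the original one can be checked the same way, without compile data.
    return any(
        isinstance(bsym.sym.id, str) and bsym.sym.id.startswith(TE_FWD_SYMBOL_PREFIX) for bsym in bsyms
    )


trace_bsyms = [FakeBoundSymbol(FakeSymbol(PrimIDs.LINEAR)), FakeBoundSymbol(FakeSymbol("te_linear_0"))]
assert is_te_linear_fwd_present(trace_bsyms)
assert not is_te_linear_fwd_present([FakeBoundSymbol(FakeSymbol(PrimIDs.LINEAR))])
```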
40 changes: 40 additions & 0 deletions thunder/tests/test_transformer_engine_executor.py
@@ -243,3 +243,43 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs):

    # Verify that we have `te_linear` in the trace.
    assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols)


+@requiresCUDA
+def test_checkpointing_with_transformer_engine():
+    def fn(x, w):
+        return torch.utils.checkpoint.checkpoint(lambda x, w: torch.nn.functional.linear(x, w), x, w)
+
+    te_linear_ckpt = te.Linear(16, 16)
+
+    def fn_checkpoint_eager(x):
+        return torch.utils.checkpoint.checkpoint(lambda x: te_linear_ckpt(x), x, use_reentrant=True)
+
+    # NOTE: Currently, checkpointing is supported via the thunderfx path.
+    from thunder.dynamo import thunderfx
+
+    x = torch.rand(16, 16, device="cuda", requires_grad=True)
+    w = te_linear_ckpt.weight.detach().clone()
+    w.requires_grad = True
+
+    actual = thunderfx(
+        fn,
+        executors=[
+            transformer_engine_ex,
+        ],
+    )(x, w)
+
+    x_ref = x.detach().clone()
+    x_ref.requires_grad = True
+    with te.fp8_autocast():
+        expected = fn_checkpoint_eager(x_ref)
+
+    torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1e-1)
+
+    grad_output = torch.rand_like(actual)
+
+    actual.backward(grad_output)
+    expected.backward(grad_output)
+
+    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(w.grad, te_linear_ckpt.weight.grad, rtol=1e-1, atol=1e-1)
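From a user's perspective, the scenario the new test covers looks roughly like the sketch below; the sizes, names, and `use_reentrant` choice are illustrative assumptions rather than a prescribed recipe, and the `ltorch.linear` registration added above is presumably what lets the executor claim the linear call inside the checkpointed region.

```python
# Illustrative user-level usage (assumes CUDA and transformer_engine are available;
# module names and tensor sizes are made up for the example).
import torch
import transformer_engine.pytorch as te

from thunder.dynamo import thunderfx
from thunder.executors.transformer_engineex import transformer_engine_ex

layer = te.Linear(1024, 1024)

def checkpointed_forward(x):
    # Activation checkpointing wraps the TE linear; the executor still needs to
    # recognize the linear call that lives inside the checkpointed region.
    return torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)

compiled = thunderfx(checkpointed_forward, executors=[transformer_engine_ex])

x = torch.randn(32, 1024, device="cuda", requires_grad=True)
out = compiled(x)
out.sum().backward()
```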