
TE - fix propagate metadata for fp8_autocast in from_trace #1021

Merged: 1 commit merged into Lightning-AI:main on Aug 26, 2024

Conversation

@kshitij12345 (Collaborator) commented on Aug 22, 2024

Fixes #1000
Smaller repro (without CUDA graph):

import torch
import thunder
from thunder.core.transform_common import Transform

class Module(torch.nn.Module):
    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        return self.linear(x)

class MyNoopTransform(Transform):
    def transform_trace_post_optimization(self, computation_trace, **kwargs):
        new_trace = thunder.core.trace.from_trace(computation_trace)
        new_trace.bound_symbols = computation_trace.bound_symbols
        print(new_trace._include_te_fp8_autocast)  # False (from_trace doesn't propagate this).
        return new_trace

with torch.device('cuda'):
    in_features = 16
    out_features = 16
    model = Module(in_features, out_features)
    for p in model.parameters():
        p.requires_grad = True

    x = torch.randn(16, in_features, requires_grad=True)

    jmodel_def = thunder.jit(model, executors=['transformer_engine',], transforms=[MyNoopTransform(),])

    y = jmodel_def(x)

Error

  File "/home/kkalambarkar/git/TransformerEngine/transformer_engine/pytorch/module/base.py", line 964, in get_fp8_workspace
    out.cast_transpose_(
  File "/home/kkalambarkar/git/pytorch/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/kkalambarkar/git/TransformerEngine/transformer_engine/pytorch/float8_tensor.py", line 732, in cast_transpose_
    fp8_meta = self._fp8_meta[fp8_meta_key]
KeyError: 'scaling_fwd'

Problem -

In the trace for the forward, we set _include_te_fp8_autocast on the trace object, which is then used to wrap the Python representation of the trace in fp8_autocast (this is required for TELinear to actually do the computation in FP8).

# NOTE: For TransformerEngine executor, we want to wrap the generated
# forward function in fp8_autocast ctx manager.
# In the future, if other executor has similar requirements, we should
# add a new extension point for executors
# NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating
# the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that
# it is in grad enabled part).
from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string
if self._include_te_fp8_autocast and _is_te_linear_enabled(import_ctx, object_ctx):
    program.append(_get_te_wrapper_string())

# Enable wrapping with `te.fp8_autocast`.
fw_extrace._include_te_fp8_autocast = True
# We only want the forward function to be called with `te.fp8_autocast` manager.
bw_extrace._include_te_fp8_autocast = False
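
Concretely, the generated forward then ends up shaped roughly like the following sketch (illustrative only, not thunder's exact codegen; the function name and signature here are made up):

import torch
import transformer_engine.pytorch as te

# fp8_autocast is the outer decorator so that its FP8 scale updates run while
# grad is still enabled (see the NOTE in the snippet above).
@te.fp8_autocast(enabled=True)
@torch.no_grad()
def augmented_forward_fn(x, w, b):  # hypothetical name and signature
    # The TELinear computation happens here in FP8; a plain linear stands in for it.
    return torch.nn.functional.linear(x, w, b)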

However, from_trace doesn't propagate this metadata. So if an additional post-optimization transform uses from_trace (as in the repro above), the flag is silently reset on the new trace, and we end up with a TELinear that runs without fp8_autocast.
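
The loss can also be seen in isolation with a minimal sketch (assuming TraceCtx can be constructed standalone; the flag's default is False, as the repro's print shows):

import thunder.core.trace as trace_mod

t = trace_mod.TraceCtx()
t._include_te_fp8_autocast = True      # set by the TE executor on the forward trace
new_t = trace_mod.from_trace(t)        # copies names, counters, etc., but not this flag
print(new_t._include_te_fp8_autocast)  # False before this PR: the flag falls back to its default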

Solution -
Update from_trace to propagate this metadata (see the diff in the review below).

Test -
Tested locally with the newly added test. TE tests don't run on CI; however, we do run them on nightlies.
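
For reference, the newly added test is roughly of this shape (a hedged sketch; the actual test name and body in the PR may differ, and MyNoopTransform is the transform from the repro above):

import pytest
import torch
import thunder

@pytest.mark.skipif(not torch.cuda.is_available(), reason="TE requires CUDA")
def test_te_fp8_autocast_survives_noop_transform():  # hypothetical name
    model = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(16, 16, device="cuda", requires_grad=True)
    jmodel = thunder.jit(
        model,
        executors=["transformer_engine"],
        transforms=[MyNoopTransform()],  # the no-op transform defined in the repro
    )
    # Before the fix, this call raised KeyError: 'scaling_fwd' inside TE.
    y = jmodel(x)
    y.sum().backward()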

NOTE - TE Linear is only enabled for training runs (i.e. when inputs have requires_grad=True).

Thanks @mattteochen for pointing out that he had seen the same error when fp8_autocast was missing.

@kshitij12345 changed the title from "TE - fix propagate metadata for fp8_autocast in" to "TE - fix propagate metadata for fp8_autocast in from_trace" on Aug 22, 2024
@kshitij12345 marked this pull request as ready for review on August 22, 2024 11:18
@t-vi (Collaborator) left a comment:

Thank you @kshitij12345

@@ -482,6 +482,8 @@ def from_trace(trace: TraceCtx) -> TraceCtx:
     t.name_ctr = trace.name_ctr
     t.obj_name_ctr = trace.obj_name_ctr
     t.names = trace.names
+    # This is a detail for enabling transformer_engine's autocast manager.
+    t._include_te_fp8_autocast = trace._include_te_fp8_autocast
@t-vi (Collaborator) commented on Aug 23, 2024:

Makes me wonder whether this could / should be a compile data information ("option") rather than on the trace?

@t-vi (Collaborator):

(this is to revisit later)

@kshitij12345 (Collaborator, Author):

Have filed an issue - #1050

@t-vi enabled auto-merge (squash) on August 23, 2024 15:30
@t-vi merged commit 61e2f84 into Lightning-AI:main on Aug 26, 2024
40 checks passed
@github-actions bot deleted the te-decorator-meta-propagate branch on November 22, 2024 00:50