TransformerEngine + cudagraphs #1000

Closed
mattteochen opened this issue Aug 20, 2024 · 2 comments
Labels: cudagraphs, program-coverage (Requests for model and program coverage), thunderfx (for things that could be applicable to the dynamo+thunder frontend), TransformerEngine


mattteochen commented Aug 20, 2024

🐛 Bug

Compiling a model with thunder.jit using the Transformer Engine executor with CUDA graphs enabled (use_cudagraphs=True) fails with KeyError: 'scaling_fwd'.

To Reproduce

Code sample

import torch
import thunder

class Module(torch.nn.Module):
    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        return self.linear(x)

# Create the parameters and the input directly on the GPU
with torch.device('cuda'):
    m = 1  # size multiplier for the layer dimensions
    in_features = 4096 * m
    out_features = 4096 * m
    model = Module(in_features, out_features)
    x = torch.randn(768, in_features, requires_grad=True)

    # Combining the Transformer Engine executor with CUDA graphs triggers the error
    jmodel_def = thunder.jit(model, executors=['transformer_engine'], use_cudagraphs=True)

    y = jmodel_def(x)  # raises KeyError: 'scaling_fwd'

Expected behaviour

The compiled model should run. Instead, the call fails with the following traceback:

Traceback (most recent call last):
  File "/workspace/workdir/examples/dev/te.py", line 32, in <module>
    y = jmodel_def(x)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/workdir/thunder/core/module.py", line 63, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/workspace/workdir/thunder/__init__.py", line 781, in fn_
    result = cache_entry.computation_fn(*inps)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "thunder.augmented_forward_fn_3", line 12, in augmented_forward_fn
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/workdir/thunder/executors/transformer_engineex.py", line 212, in forward
    weight_fp8, weight_t_fp8 = self.get_fp8_weight_version_compat(
  File "/workspace/workdir/thunder/executors/transformer_engineex.py", line 293, in get_fp8_weight_version_compat
    weight_fp8 = self.get_fp8_workspace(
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 965, in get_fp8_workspace
    out.cast_transpose_(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/float8_tensor.py", line 732, in cast_transpose_
    fp8_meta = self._fp8_meta[fp8_meta_key]
KeyError: 'scaling_fwd'

Environment

  • PyTorch version: 2.5.0a0+gitb0fc6aa
  • Thunder: f9dbf9c
  • OS: Linux
  • Python version: 3.10.12
  • CUDA/cuDNN version: 12.6
  • GPU models and configuration: RTX ADA 6000
  • Any other relevant information: Tested on NVIDIA internal docker containers

t-vi added the cudagraphs and program-coverage labels on Aug 20, 2024
kshitij12345 self-assigned this on Aug 21, 2024

mattteochen (Author) commented

I remember hitting this KeyError: 'scaling_fwd' once before, when running a trace (TraceCtx) in which the Transformer Engine decorator @transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe) was not present.

This info may help.

tfogal changed the title from "TE + cudagraphs" to "TransformerEngine + cudagraphs" on Aug 23, 2024
tfogal added the thunderfx label (for things that could be applicable to the dynamo+thunder frontend) on Aug 23, 2024
kshitij12345 (Collaborator) commented

Fixed in #1021
