-
Notifications
You must be signed in to change notification settings - Fork 699
Add Transformer-Engine Fused_Adam Optimizer Support #2293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This commit introduces the TE_FusedAdamW optimizer to the optimizer building function. It includes error handling for the import of the FusedAdam optimizer from the transformer_engine package, providing installation instructions if the import fails. Additionally, it updates the optimizer_kwargs to set the appropriate data types for the exponential moving averages. No existing functionality was altered, and the new optimizer is integrated seamlessly into the existing optimizer framework.
|
@tianyu-l Please help review this PR. I have kept it as draft for now because this needs a bug fix in Transformer-Engine (TE) related to DTensor handling which will be available in next TE release (fix already available in a pre-release branch). If it looks good to you I can take it out of draft after next TE release. Thanks. |
| optimizer_kwargs['exp_avg_dtype'] = torch.bfloat16 | ||
| optimizer_kwargs['exp_avg_sq_dtype'] = torch.bfloat16 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @vivekgoe
My impression is that low-precision optimizer states and TE dependency are two separate topics, each worth its own discussions.
For low-precision optimizer states
- @vivekgoe Could you show evidences that this is becoming the default for training?
- @janeyx99 Does pytorch optimizer plan to support this feature?
In terms of dependency, I know that TE is an important and popular optimization library. However, torchtitan has been about pytorch native, prioritizing simplicity & maintainability, and platform neutral.
- In the past, we have asked other hardware partners to create their own fork to host special optimizations. We try to organize the codebase to support extensibility (while not losing too much readbiility) so that modification on top is simple.
- On the other hand, we recommend shipping critical features to pytorch before they are integrated in torchtitan.
Would love to hear your thoughts on this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tianyu-l Thanks for detailed review comments. I agree It makes sense to separate out adopting low-precision optimizer states from creating TE dependency.
Regarding your question "Could you show evidences that this is becoming the default for training?", as far as I know DeepSeek pioneered this approach and I have not seen this used elsewhere, perhaps because support for this is not available in off the shelf optimizers.
I did a quick check on Qwen3-32B and LLama4-17Bx16e to confirm if models other than DeepSeek maintain quality (along with smaller memory footprint) with this feature, results look promising (see plots below). I can create a issue in pytorch repo to check if there is interest in adding this feature to torch.optim.adamw directly.


Regarding adding TE dependency, you make very valid arguments. As a short to medium term solution, will it be ok to move TE dependency to a new folder within experiment area so that users who wish to benefit from this feature can use it? Longer term we can add feature to native pytorch optimizer (assuming maintainers agree) and remove it from experiment area.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a short to medium term solution, will it be ok to move TE dependency to a new folder within experiment area so that users who wish to benefit from this feature can use it? Longer term we can add feature to native pytorch optimizer (assuming maintainers agree) and remove it from experiment area.
The long-term plan sounds good to me.
For the short-term plan, the feedback we heard is that the current experiments folder is too messy, and we plan to clean things up (by deleting as much as possible). That doesn't mean we won't consider TE as an experiment, but we would like to see a clear vision and integration/maintenance plan, so that we can discuss and review together if it makes sense to the community. If it's just a feature like this fused adam optimizer, maybe showcasing it in a PR is good enough.
Hope it is not nonsense to you.
|
I had looked into adding a low prec AdamW into pytorch/pytorch last year, see pytorch/pytorch#146542 (comment) but ended up giving up because of lack of interest/no discussion for what a good frontend would look like. I'm noticing that in this PR the frontend is to set kwargs specifying the dtype for each state; I would want a more encompassing/scalable design that'd let us unlock low precision for more optimizers. I'm happy to kickstart this discussion again if we have users/people interested in co-designing a reasonable frontend and bring the PR to completion! |
Summary
For some models such as DeepSeek-V3, optimizer states (moments) can be stored in lower precision (torch.bfloat16) without loosing training accuracy. Storing optimizer states in lower precisions translates into significant memory saving for large models such as DeepSeek-V3-671B and allows to increase device utilization by increasing Batch-Size.
This PR adds Transformer-Engine(TE) Fused_Adam optimizer to list of optimizers supported by TorchTitan. This optimizer is a drop-in replacement for torch.optim.AdamW with additional features such as lower precision optimizer states storage support.
Tests
Accuracy
Tested on DeepSeek-V4-671B (EP=64, PP=4, GBS=1024).

Memory
Tested on DeepSeek-V4-671B (EP=64, PP=4, GBS=1024).
