From 4d61ff785c815eeadc3b0bcba189955b35f9ad50 Mon Sep 17 00:00:00 2001
From: Vivek Goel
Date: Wed, 28 Jan 2026 10:13:40 +0530
Subject: [PATCH 1/2] Add TE_FusedAdamW optimizer support in build_optimizers
 function

This commit introduces the TE_FusedAdamW optimizer to the optimizer building
function. It includes error handling for the import of the FusedAdam optimizer
from the transformer_engine package, providing installation instructions if
the import fails. Additionally, it updates the optimizer_kwargs to set the
appropriate data types for the exponential moving averages. No existing
functionality was altered, and the new optimizer is integrated into the
existing optimizer framework.
---
 torchtitan/components/optimizer.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 5ddc80852a..4b1f23db5f 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -311,6 +311,23 @@ def build_optimizers(
         "Adam": torch.optim.Adam,
         "AdamW": torch.optim.AdamW,
     }
+
+    if name == 'TE_FusedAdamW':
+        try:
+            from transformer_engine.pytorch.optimizers.fused_adam import TE_FusedAdam
+        except (ImportError, ModuleNotFoundError) as e:
+            raise ImportError(
+                "FusedAdam optimizer could not be imported from transformer_engine. "
+                "Install transformer_engine with: "
+                "NVTE_FRAMEWORK=pytorch pip install --no-build-isolation "
+                "git+https://github.com/NVIDIA/TransformerEngine.git@stable "
+                "or use another optimizer"
+            ) from e
+        optimizer_classes["TE_FusedAdamW"] = TE_FusedAdam
+        del optimizer_kwargs['fused'], optimizer_kwargs['foreach']
+        optimizer_kwargs['exp_avg_dtype'] = torch.bfloat16
+        optimizer_kwargs['exp_avg_sq_dtype'] = torch.bfloat16
+
     if name not in optimizer_classes:
         raise NotImplementedError(f"Optimizer {name} not added.")
     optimizer_cls = optimizer_classes[name]

From fd871508da92d3fa3be04e0e4079f7d0ac1f43ef Mon Sep 17 00:00:00 2001
From: Vivek Goel
Date: Thu, 29 Jan 2026 09:55:36 +0530
Subject: [PATCH 2/2] Fix typo in importing TE FusedAdam

---
 torchtitan/components/optimizer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 4b1f23db5f..7533733526 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -314,7 +314,7 @@ def build_optimizers(
 
     if name == 'TE_FusedAdamW':
         try:
-            from transformer_engine.pytorch.optimizers.fused_adam import TE_FusedAdam
+            from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam
         except (ImportError, ModuleNotFoundError) as e:
             raise ImportError(
                 "FusedAdam optimizer could not be imported from transformer_engine. "
@@ -323,7 +323,7 @@ def build_optimizers(
                 "git+https://github.com/NVIDIA/TransformerEngine.git@stable "
                 "or use another optimizer"
             ) from e
-        optimizer_classes["TE_FusedAdamW"] = TE_FusedAdam
+        optimizer_classes["TE_FusedAdamW"] = FusedAdam
         del optimizer_kwargs['fused'], optimizer_kwargs['foreach']
         optimizer_kwargs['exp_avg_dtype'] = torch.bfloat16
         optimizer_kwargs['exp_avg_sq_dtype'] = torch.bfloat16
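
For reference, a minimal usage sketch (not part of the patch series) of the code
path these patches add: it constructs Transformer Engine's FusedAdam directly with
the same low-precision moment dtypes the patch sets. It assumes a CUDA device and a
Transformer Engine version whose FusedAdam accepts exp_avg_dtype/exp_avg_sq_dtype
(which the patch relies on); the hyperparameter values are illustrative
placeholders, not taken from the patches.

    import torch
    from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam

    model = torch.nn.Linear(16, 16).cuda()  # FusedAdam operates on CUDA parameters

    # Mirrors what build_optimizers passes after the patch: the torch-only
    # 'fused'/'foreach' flags are dropped, and both Adam moments are kept in bf16.
    optimizer_kwargs = {
        "lr": 3e-4,                          # placeholder hyperparameters
        "betas": (0.9, 0.95),
        "eps": 1e-8,
        "weight_decay": 0.1,
        "exp_avg_dtype": torch.bfloat16,     # first-moment dtype, as in the patch
        "exp_avg_sq_dtype": torch.bfloat16,  # second-moment dtype, as in the patch
    }
    optimizer = FusedAdam(model.parameters(), **optimizer_kwargs)

    loss = model(torch.randn(4, 16, device="cuda")).square().mean()
    loss.backward()
    optimizer.step()

In torchtitan itself, the same path would be selected through the job config's
optimizer name (e.g. setting it to "TE_FusedAdamW"), which routes through the
branch added in PATCH 1/2.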