diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index 5ddc80852a..7533733526 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -311,6 +311,25 @@ def build_optimizers(
         "Adam": torch.optim.Adam,
         "AdamW": torch.optim.AdamW,
     }
+
+    if name == "TE_FusedAdamW":
+        try:
+            from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam
+        except (ImportError, ModuleNotFoundError) as e:
+            raise ImportError(
+                "FusedAdam optimizer could not be imported from transformer_engine. "
+                "Install transformer_engine with: "
+                "NVTE_FRAMEWORK=pytorch pip install --no-build-isolation "
+                "git+https://github.com/NVIDIA/TransformerEngine.git@stable "
+                "or use another optimizer."
+            ) from e
+        optimizer_classes["TE_FusedAdamW"] = FusedAdam
+        # TE's FusedAdam does not accept torch.optim's `fused`/`foreach` flags.
+        del optimizer_kwargs["fused"], optimizer_kwargs["foreach"]
+        # Keep the Adam moment buffers in bf16 to halve that part of optimizer state.
+        optimizer_kwargs["exp_avg_dtype"] = torch.bfloat16
+        optimizer_kwargs["exp_avg_sq_dtype"] = torch.bfloat16
+
     if name not in optimizer_classes:
         raise NotImplementedError(f"Optimizer {name} not added.")
     optimizer_cls = optimizer_classes[name]
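
For context, a minimal standalone sketch of what this branch wires up: constructing transformer_engine's FusedAdam directly with bf16 moment buffers, mirroring the kwargs the patch injects. The model, learning rate, and weight decay below are illustrative placeholders, not part of the patch, and the snippet assumes a CUDA-capable build of transformer_engine (TE's fused optimizer runs on GPU):

```python
import torch
from transformer_engine.pytorch.optimizers.fused_adam import FusedAdam

# Placeholder model; any set of CUDA parameters works.
model = torch.nn.Linear(1024, 1024, device="cuda")

# exp_avg_dtype / exp_avg_sq_dtype keep the Adam moment buffers in bf16,
# halving that portion of optimizer state relative to fp32 -- this is what
# the patch configures via optimizer_kwargs.
optimizer = FusedAdam(
    model.parameters(),
    lr=3e-4,            # illustrative hyperparameters, not from the patch
    weight_decay=0.1,
    exp_avg_dtype=torch.bfloat16,
    exp_avg_sq_dtype=torch.bfloat16,
)

loss = model(torch.randn(8, 1024, device="cuda")).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

The `del optimizer_kwargs["fused"], optimizer_kwargs["foreach"]` line in the patch exists because those are torch.optim implementation toggles that TE's FusedAdam does not take; passing them through would fail at construction.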