From 6d42f9aeeab46abaa6e74c60690acee418250be3 Mon Sep 17 00:00:00 2001
From: Sanket Jayant Purandare
Date: Thu, 22 Jan 2026 12:53:36 -0800
Subject: [PATCH] Fix grad norm clipping for AutoP and dsv3 model init

stack-info: PR: https://github.com/pytorch/torchtitan/pull/2270, branch: sanketpurandare/stack/1
---
 torchtitan/distributed/utils.py                  | 15 +++++++++++----
 .../autoparallel/local_map_deepseek_v3/model.py  |  7 ++++++-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 2ba9c08422..5f9bf3108c 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -494,18 +494,25 @@ def _clip_grad_norm_with_ep(
         else:
             non_ep_params.append(p)
             non_ep_grads.append(p.grad)
+
+    # Either list can be empty depending on the parallelization strategy:
+    # - In torchtitan with separate dense/sparse meshes, both lists are typically non-empty
+    # - In autoparallel, all params may live on a single sparse mesh with "ep" dimension,
+    #   so non_ep_grads would be empty
+    # - In PP + EP setups, certain PP ranks may only own EP or non-EP layers
     ep_grads_total_norm = torch.nn.utils.get_total_norm(
         ep_grads, norm_type, error_if_nonfinite, foreach
     )
-    # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor
-    # This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance.
+    # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
     if isinstance(ep_grads_total_norm, DTensor):
         ep_grads_total_norm = ep_grads_total_norm.full_tensor()
 
-    # pyrefly: ignore [missing-attribute]
     non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
         non_ep_grads, norm_type, error_if_nonfinite, foreach
-    ).full_tensor()
+    )
+    # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
+    if isinstance(non_ep_grads_total_norm, DTensor):
+        non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor()
 
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py
index f4915fb708..a620066330 100644
--- a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py
+++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py
@@ -15,4 +15,9 @@
 # Need to share same base class with torchtitan models
 class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol):
     def __init__(self, model_args: DeepSeekV3ModelArgs):
-        super().__init__(model_args)
+        # Call _DeepSeekV3Model.__init__ which calls nn.Module.__init__
+        # Note: We don't call ModelProtocol.__init__ separately because:
+        # 1. nn.Module.__init__() is already called by _DeepSeekV3Model.__init__
+        # 2. Calling ModelProtocol.__init__ after would reset all module state
+        #    (nn.Module.__init__ clears _modules, _parameters, etc.)
+        _DeepSeekV3Model.__init__(self, model_args)
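
Note on the utils.py hunk (illustration only, not part of the patch): a minimal standalone sketch of the guard pattern above, assuming PyTorch with torch.nn.utils.get_total_norm and DTensor available. get_total_norm returns a plain tensor(0.) when handed an empty list, so the DTensor-to-local conversion must be guarded by an isinstance check. The helper name _combined_grad_norm and the finite-norm combination are illustrative; only the inf-norm branch appears in the hunk.

import math

import torch
from torch.distributed.tensor import DTensor


def _combined_grad_norm(ep_grads, non_ep_grads, norm_type=2.0):
    # get_total_norm returns tensor(0.) (a plain tensor) for an empty list,
    # so only call full_tensor() when we actually got a DTensor back.
    ep_norm = torch.nn.utils.get_total_norm(ep_grads, norm_type)
    if isinstance(ep_norm, DTensor):
        ep_norm = ep_norm.full_tensor()
    non_ep_norm = torch.nn.utils.get_total_norm(non_ep_grads, norm_type)
    if isinstance(non_ep_norm, DTensor):
        non_ep_norm = non_ep_norm.full_tensor()
    if math.isinf(norm_type):
        # inf-norm: the combined norm is the max of the two partial norms
        return torch.maximum(ep_norm, non_ep_norm)
    # standard p-norm combination of the two partial norms (not shown in the hunk)
    return (ep_norm**norm_type + non_ep_norm**norm_type) ** (1.0 / norm_type)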
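
Note on the model.py hunk (illustration only, not part of the patch): a minimal sketch of why re-running nn.Module.__init__ on an already constructed module is destructive, which is what point 2 of the new comment guards against. The class name Toy is made up for this example.

import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


m = Toy()
print(len(list(m.parameters())))  # 2: weight and bias of the registered Linear
nn.Module.__init__(m)  # re-creates _parameters/_modules/_buffers from scratch
print(len(list(m.parameters())))  # 0: the previously registered submodule is gone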