Merged
15 changes: 11 additions & 4 deletions torchtitan/distributed/utils.py
@@ -494,18 +494,25 @@ def _clip_grad_norm_with_ep(
         else:
             non_ep_params.append(p)
             non_ep_grads.append(p.grad)
+
+    # Either list can be empty depending on the parallelization strategy:
+    # - In torchtitan with separate dense/sparse meshes, both lists are typically non-empty
+    # - In autoparallel, all params may live on a single sparse mesh with "ep" dimension,
+    #   so non_ep_grads would be empty
+    # - In PP + EP setups, certain PP ranks may only own EP or non-EP layers
     ep_grads_total_norm = torch.nn.utils.get_total_norm(
         ep_grads, norm_type, error_if_nonfinite, foreach
     )
-    # ep_grads may be an empty list, in which case get_total_norm returns tensor(0.), a non-DTensor
-    # This can occur in PP + EP setups where certain PP ranks only own non-EP layers, for instance.
+    # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
     if isinstance(ep_grads_total_norm, DTensor):
         ep_grads_total_norm = ep_grads_total_norm.full_tensor()

-    # pyrefly: ignore [missing-attribute]
     non_ep_grads_total_norm = torch.nn.utils.get_total_norm(
         non_ep_grads, norm_type, error_if_nonfinite, foreach
-    ).full_tensor()
+    )
+    # get_total_norm returns tensor(0.) for empty list, which is a non-DTensor
+    if isinstance(non_ep_grads_total_norm, DTensor):
+        non_ep_grads_total_norm = non_ep_grads_total_norm.full_tensor()

     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
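For context, a minimal sketch (not part of the PR; the helper name is made up) of why the isinstance(..., DTensor) guard matters: torch.nn.utils.get_total_norm returns a plain tensor(0.) when given an empty list, and a plain tensor has no full_tensor() method, so calling .full_tensor() unconditionally would fail on ranks whose EP or non-EP gradient list happens to be empty.

import torch
from torch.distributed.tensor import DTensor


def _to_full_norm(total_norm: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the guard in _clip_grad_norm_with_ep:
    # only DTensor norms need (and support) full_tensor().
    if isinstance(total_norm, DTensor):
        return total_norm.full_tensor()
    return total_norm  # already a plain tensor, e.g. tensor(0.) from an empty list


# Empty gradient list, as on a PP rank that owns no EP layers.
empty_norm = torch.nn.utils.get_total_norm([], norm_type=2.0)
print(_to_full_norm(empty_norm))  # tensor(0.), no AttributeError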
@@ -15,4 +15,9 @@
 # Need to share same base class with torchtitan models
 class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol):
     def __init__(self, model_args: DeepSeekV3ModelArgs):
-        super().__init__(model_args)
+        # Call _DeepSeekV3Model.__init__ which calls nn.Module.__init__
+        # Note: We don't call ModelProtocol.__init__ separately because:
+        # 1. nn.Module.__init__() is already called by _DeepSeekV3Model.__init__
+        # 2. Calling ModelProtocol.__init__ after would reset all module state
+        #    (nn.Module.__init__ clears _modules, _parameters, etc.)
+        _DeepSeekV3Model.__init__(self, model_args)
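A minimal, self-contained repro (not from the PR) of the failure mode the inline comment describes: running nn.Module.__init__ a second time, after submodules have already been registered, replaces _modules and _parameters with fresh empty containers, so any later init call that reaches nn.Module.__init__ would wipe the model.

import torch.nn as nn


class Toy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)


toy = Toy()
print(sum(p.numel() for p in toy.parameters()))  # 20 (4*4 weight + 4 bias)
nn.Module.__init__(toy)  # simulates a late re-init reaching nn.Module.__init__
print(sum(p.numel() for p in toy.parameters()))  # 0 -- registered state was reset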
Member
will break when ModelProtocol.__init__ does subclass init things
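A toy sketch of this concern (placeholder class names, not from the PR): calling one base's __init__ directly bypasses the MRO, so any initialization the protocol/base class might add later is silently skipped.

class Base:
    def __init__(self) -> None:
        self.base_ready = True


class Protocolish:
    def __init__(self) -> None:
        self.protocol_ready = True  # hypothetical future "subclass init things"


class Model(Base, Protocolish):
    def __init__(self) -> None:
        Base.__init__(self)  # mirrors _DeepSeekV3Model.__init__(self, model_args)


m = Model()
print(hasattr(m, "base_ready"))      # True
print(hasattr(m, "protocol_ready"))  # False -- Protocolish.__init__ never ran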

Contributor Author
Do you have a better solution in mind?

Member
we could use torchtitan as the model repository, instead of autoparallel

Contributor Author
That's a good idea. We'll do it after we have tested CP in autoparallel as well, since that would also require local map, and it's faster to iterate when we have the model definition in AutoP.
