4 changes: 3 additions & 1 deletion torchtitan/distributed/parallel_dims.py
@@ -151,7 +151,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)

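With this change, a run that only sets `dp_replicate > 1` still yields a 2D `("dp_replicate", "dp_shard")` sub-mesh with a trailing size-1 shard dimension, which is the shape the `apply_replicate` call sites below slice out of `world_mesh`. A tiny standalone sketch of the filtering rule, using hypothetical degrees for 8 ranks doing pure replication (the real code path is `ParallelDims._build_mesh_without_ep`):

```python
# Hypothetical parallelism degrees: 8 ranks, replication only.
degrees = {"pp": 1, "dp_replicate": 8, "dp_shard": 1, "cp": 1, "tp": 1}

dims, names = [], []
for name, d in degrees.items():
    # Keep dp_shard even at size 1 so the mesh stays compatible with replicate().
    if d > 1 or (name == "dp_shard" and degrees["dp_replicate"] > 1):
        dims.append(d)
        names.append(name)

print(dims, names)  # [8, 1] ['dp_replicate', 'dp_shard'] -> a 2D mesh instead of 1D
```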
6 changes: 2 additions & 4 deletions torchtitan/distributed/utils.py
@@ -217,7 +217,7 @@ def context(cp_context: Generator[None, None, None] | None = None):
 def maybe_enable_amp(
     parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device
 ) -> Generator[None, None, None]:
-    if parallel_dims.fsdp_enabled:
+    if parallel_dims.fsdp_enabled or parallel_dims.dp_replicate_enabled:
         # FSDP handles mixed precision internally
         logger.info("Mixed precision training is handled by fully_shard")
         return contextlib.nullcontext()
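The effect of this change: once `replicate` also carries a `MixedPrecisionPolicy`, autocast must be skipped for replicate-only runs just as it is for FSDP runs. A hedged sketch of that decision; the autocast fallback is an assumption about the branch not shown in this diff, and `_amp_context` is a hypothetical helper, not the torchtitan function:

```python
import contextlib
import torch

def _amp_context(managed_by_fsdp_or_replicate: bool, param_dtype: torch.dtype, device_type: str):
    if managed_by_fsdp_or_replicate:
        # fully_shard / replicate already cast params via MixedPrecisionPolicy,
        # so wrapping the forward in autocast would be redundant.
        return contextlib.nullcontext()
    # Assumed fallback: let autocast handle mixed precision otherwise.
    return torch.autocast(device_type, dtype=param_dtype)
```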
@@ -432,9 +432,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type

     if pp_mesh is not None:
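This is a pure formatting simplification: before and after, the else branch combines the two partial norms as `(ep_norm**p + non_ep_norm**p) ** (1/p)`. A standalone sketch with hypothetical scalar inputs (not the torchtitan helper itself):

```python
import math
import torch

def combine_grad_norms(ep_norm: torch.Tensor, non_ep_norm: torch.Tensor, norm_type: float) -> torch.Tensor:
    # Infinity norm: the max of the two partial norms.
    if math.isinf(norm_type):
        return torch.maximum(ep_norm, non_ep_norm)
    # Finite p-norm: (a**p + b**p) ** (1/p).
    total = ep_norm**norm_type + non_ep_norm**norm_type
    return total ** (1.0 / norm_type)

print(combine_grad_norms(torch.tensor(3.0), torch.tensor(4.0), 2.0))  # tensor(5.)
```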
14 changes: 6 additions & 8 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -28,7 +28,7 @@
     TensorParallel,
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger


@@ -169,14 +169,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            dp_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
13 changes: 6 additions & 7 deletions torchtitan/experiments/qwen3/infra/parallelize.py
@@ -28,7 +28,7 @@
     apply_fsdp,
     apply_moe_ep_tp,
 )
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger


@@ -164,13 +164,12 @@ def parallelize_qwen3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     # Enable weight tying after applying parallelisms
13 changes: 6 additions & 7 deletions torchtitan/experiments/vlm/infra/parallelize.py
@@ -20,7 +20,7 @@
 from torchtitan.models.llama3.infra.parallelize import (
     _save_list as sac_save_list,
     apply_compile,
-    apply_ddp,
+    apply_replicate,
 )
 from torchtitan.tools.logging import logger

@@ -101,13 +101,12 @@ def parallelize_vlm(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=job_config.compile.enable,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
14 changes: 6 additions & 8 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -25,7 +25,7 @@
     apply_fsdp,
     apply_moe_ep_tp,
 )
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger


@@ -162,14 +162,12 @@ def parallelize_deepseekv3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            dp_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
46 changes: 28 additions & 18 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -9,7 +9,7 @@

 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate

 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model

Review thread on the `world_mesh[tuple(dp_mesh_dim_names)]` argument:

Contributor: Why is this dp_mesh 2D? For user-facing APIs it should be a 1D mesh or the default 1D world mesh; inside the replicate API we can do 2D.

Author (@anshul-si, Sep 24, 2025): Currently the replicate API doesn't have a method to convert a 1D mesh to 2D. I can leave a TODO to change this once the unflatten method for DeviceMesh is complete, but until then I think it makes more sense to leave it like this. Also, the user technically believes they are creating a 1D mesh because they are only changing the replicate dimension; from their perspective it is still a 1D mesh.

Contributor: I remember @fduwjj mentioned a mesh API to convert 1D to 2D. This needs to be done before landing; it's a user contract. If we change 2D back to 1D later, that becomes BC-breaking.
@@ -314,20 +313,31 @@ def apply_fsdp(
     fully_shard(model, **fsdp_config)


-def apply_ddp(
+def apply_replicate(
     model: nn.Module,
     dp_mesh: DeviceMesh,
-    enable_compile: bool,
-    enable_compiled_autograd: bool,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}

+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )
+
-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)

     logger.info("Applied DDP to the model")
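For reference, a usage sketch of the new entry point. Assumptions not in the diff: the script runs under torchrun with 8 ranks, `model` is a llama3 `Transformer` exposing `tok_embeddings` / `layers` / `norm` / `output`, and the dtypes mirror the `TORCH_DTYPE_MAP` lookups at the call sites above.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# 2D data-parallel mesh with a size-1 shard dim, matching ParallelDims above.
world_mesh = init_device_mesh(
    "cuda", (8, 1), mesh_dim_names=("dp_replicate", "dp_shard")
)

# `model` is assumed to be a llama3 Transformer built elsewhere.
apply_replicate(
    model,
    world_mesh["dp_replicate", "dp_shard"],
    param_dtype=torch.bfloat16,  # e.g. mixed_precision_param = "bfloat16"
    reduce_dtype=torch.float32,  # e.g. mixed_precision_reduce = "float32"
)
```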