diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py
index 44822039a..5e3f59738 100644
--- a/torchtitan/distributed/parallel_dims.py
+++ b/torchtitan/distributed/parallel_dims.py
@@ -151,7 +151,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 159b6229d..08c93556a 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -217,7 +217,7 @@ def context(cp_context: Generator[None, None, None] | None = None):
 def maybe_enable_amp(
     parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device
 ) -> Generator[None, None, None]:
-    if parallel_dims.fsdp_enabled:
+    if parallel_dims.fsdp_enabled or parallel_dims.dp_replicate_enabled:
         # FSDP handles mixed precision internally
         logger.info("Mixed precision training is handled by fully_shard")
         return contextlib.nullcontext()
@@ -432,9 +432,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type

     if pp_mesh is not None:
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 7efc04b78..702e31996 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -28,7 +28,7 @@
     TensorParallel,
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger
@@ -169,14 +169,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            dp_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py
index 8367eb445..12493636a 100644
--- a/torchtitan/experiments/qwen3/infra/parallelize.py
+++ b/torchtitan/experiments/qwen3/infra/parallelize.py
@@ -28,7 +28,7 @@
     apply_fsdp,
     apply_moe_ep_tp,
 )
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger
@@ -164,13 +164,12 @@ def parallelize_qwen3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     # Enable weight tying after applying parallelisms
diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py
index 977ab04ae..9bff23c88 100644
--- a/torchtitan/experiments/vlm/infra/parallelize.py
+++ b/torchtitan/experiments/vlm/infra/parallelize.py
@@ -20,7 +20,7 @@
 from torchtitan.models.llama3.infra.parallelize import (
     _save_list as sac_save_list,
     apply_compile,
-    apply_ddp,
+    apply_replicate,
 )
 from torchtitan.tools.logging import logger
@@ -101,13 +101,12 @@ def parallelize_vlm(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=job_config.compile.enable,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 7182b1fca..29b6c5aaf 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -25,7 +25,7 @@
     apply_fsdp,
     apply_moe_ep_tp,
 )
-from torchtitan.models.llama3.infra.parallelize import apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_replicate
 from torchtitan.tools.logging import logger
@@ -162,14 +162,12 @@ def parallelize_deepseekv3(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        dp_mesh = world_mesh
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            dp_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 7d8aa76f0..f3292eebe 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,13 +135,12 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
-        apply_ddp(
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
+        apply_replicate(
             model,
-            world_mesh,
-            enable_compile=model_compile_enabled,
-            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
+            world_mesh[tuple(dp_mesh_dim_names)],
+            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
         )

     return model
@@ -314,20 +313,31 @@ def apply_fsdp(
     fully_shard(model, **fsdp_config)


-def apply_ddp(
+def apply_replicate(
     model: nn.Module,
     dp_mesh: DeviceMesh,
-    enable_compile: bool,
-    enable_compiled_autograd: bool,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
 ):
-    if enable_compile:
-        if enable_compiled_autograd:
-            torch._dynamo.config.optimize_ddp = (
-                "python_reducer_without_compiled_forward"
-            )
-        else:
-            torch._dynamo.config.optimize_ddp = "ddp_optimizer"
+    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
+    replicate_config = {"device_mesh": dp_mesh, "mp_policy": mp_policy}
+
+    if model.tok_embeddings is not None:
+        replicate(
+            model.tok_embeddings,
+            **replicate_config,
+        )
+    for layer_id, transformer_block in model.layers.items():
+        replicate(
+            transformer_block,
+            **replicate_config,
+        )

-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    if model.norm is not None and model.output is not None:
+        replicate(
+            [model.norm, model.output],
+            **replicate_config,
+        )
+    replicate(model, **replicate_config)

     logger.info("Applied DDP to the model")
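
A minimal, self-contained sketch (not part of the patch) of the dimension-selection rule added to _build_mesh_without_ep. The mesh sizes below are hypothetical; the point is that dp_shard is kept even at size 1 whenever dp_replicate > 1, so world_mesh[("dp_replicate", "dp_shard")] is always a valid slice for the apply_replicate calls above.

# Illustrative only: sizes (4 replicas, no sharding) are made up.
pp, dp_replicate, dp_shard, cp, tp = 1, 4, 1, 1, 1

dims, names = [], []
for d, name in zip(
    [pp, dp_replicate, dp_shard, cp, tp],
    ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
    # Same rule as the parallel_dims.py change: keep dp_shard (even at size 1)
    # whenever dp_replicate > 1, so the ("dp_replicate", "dp_shard") slice exists.
    if d > 1 or (name == "dp_shard" and dp_replicate > 1):
        dims.append(d)
        names.append(name)

assert names == ["dp_replicate", "dp_shard"]
assert dims == [4, 1]

apply_replicate then receives that two-dimensional slice together with the mixed-precision dtypes and wraps the token embeddings, each transformer block, the norm/output pair, and finally the root module with replicate.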