From cbb335b7a5aeecda23f3f8fd8a11e42afd1cdf3e Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Feb 2026 18:01:00 -0800 Subject: [PATCH 1/6] fix TorchCommsParallelDims --- .../experiments/torchcomms/parallel_dims.py | 411 ++++++++---------- torchtitan/experiments/torchcomms/train.py | 4 +- 2 files changed, 185 insertions(+), 230 deletions(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index fd8b860e34..284988cb48 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -5,8 +5,8 @@ # LICENSE file in the root directory of this source tree. import os -from dataclasses import dataclass -from typing import Dict, List +from dataclasses import dataclass, field +from typing import Any import torch import torchcomms @@ -21,11 +21,11 @@ def _calculate_ranks_per_dimension( - meshes: List[torch.Tensor], - dim_names: List[str], - dim_sizes: List[int], + meshes: list[torch.Tensor], + dim_names: list[str], + dim_sizes: list[int], cur_rank: int, -) -> Dict[str, List[int]]: +) -> dict[str, list[int]]: """Util function to calculate global ranks mapping for each mesh dimension. Args: @@ -49,79 +49,12 @@ def _calculate_ranks_per_dimension( return ranks_per_dim -def _create_device_mesh( - world_size: int, - mesh_shape: tuple, - mesh_dim_names: List[str], -) -> Dict: - """Util function to create device mesh with communicators for each dimension. - - Args: - world_size: Total number of ranks in the world - mesh_shape: Shape of the device mesh - mesh_dim_names: List of dimension names for the mesh - - Returns: - Dictionary containing: - - comm: Root communicator - - device_mesh: Initialized DeviceMesh object - - mesh: Tensor representation of the mesh - - comm_per_dim: Communicators for each dimension - Returns empty dict if initialization fails - """ - backend = os.environ["TEST_BACKEND"] - device = torch.device("cuda") - mesh = torch.arange(world_size, dtype=torch.int, device="cpu").view(mesh_shape) - comm = torchcomms.new_comm( - backend, - device, - name="comms_test_n_d_parallel", - ) - - cur_rank = comm.get_rank() - - mesh_sizes = [mesh.size(idx) for idx in range(len(mesh_dim_names))] - meshes = [mesh] * len(mesh_dim_names) - ranks_per_dim = _calculate_ranks_per_dimension( - meshes, mesh_dim_names, mesh_sizes, cur_rank - ) - - # Create sub-communicators for each dimension - comm_per_dim = {} - for dim_name, ranks in ranks_per_dim.items(): - comm_per_dim[dim_name] = comm.split(ranks, dim_name) - - # Initialize device mesh with communicators - mesh_dim_comms = tuple(comm_per_dim[name] for name in mesh_dim_names) - try: - device_mesh = init_device_mesh( - mesh_dim_comms=mesh_dim_comms, - mesh_dim_names=tuple(mesh_dim_names), - _global_comm=comm, - ) - except TypeError as e: - # TODO: remove this once PT 2.10 is released - if "_rank" in str(e): - for sub_comm in comm_per_dim.values(): - sub_comm.finalize() - comm.finalize() - return {} - raise - - return { - "comm": comm, - "device_mesh": device_mesh, - "mesh": mesh, - "comm_per_dim": comm_per_dim, - } - - def _flatten_comms( - flatten_ranks_per_dim: Dict[str, List[int]], - comm, - flatten_mesh_dim_names: Dict[str, List[str]], + flatten_ranks_per_dim: dict[str, list[int]], + comm: Any, + flatten_mesh_dim_names: dict[str, list[str]], device_mesh: DeviceMesh, - comm_per_dim: Dict[str, any], + comm_per_dim: dict[str, Any], ) -> None: """Util function to flatten mesh dimensions and create corresponding communicators. 
@@ -152,174 +85,198 @@ def _flatten_comms( @dataclass class TorchCommsParallelDims(ParallelDims): - def _build_mesh_without_ep(self) -> DeviceMesh: - mesh_shape = (self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp) - mesh_dim_names = ["pp", "dp_replicate", "dp_shard", "cp", "tp"] - - dims = [d for d in mesh_shape if d > 1] - names = [name for d, name in zip(mesh_shape, mesh_dim_names) if d > 1] - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - - result = _create_device_mesh(self.world_size, mesh_shape, mesh_dim_names) - comm = result.get("comm", None) - device_mesh = result.get("device_mesh", None) - mesh = result.get("mesh", None) - comm_per_dim = result.get("comm_per_dim", None) - assert ( - comm is not None - and device_mesh is not None - and mesh is not None - and comm_per_dim is not None - ), "fail to init device mesh" + """ParallelDims implementation using torchcomms for device mesh initialization.""" + + # Store communicators for cleanup + comms: list[Any] = field(default_factory=list) + + def build_mesh(self) -> DeviceMesh: + """Build the device mesh using torchcomms. + + This method follows the same mesh structure as the base ParallelDims but uses + torchcomms for communicator initialization instead of torch.distributed. + + The following mesh dimensions will be created: + pp: Pipeline Parallelism (PP). + batch: Used by data loading (dp_replicate * dp_shard). + loss: Used by all-reduce when computing the loss (dp_replicate * dp_shard * cp). + dp_replicate: For DDP or HSDP replicate dimension. + fsdp: For FSDP dimension (dp_shard * cp). + cp: Context Parallelism (CP). + tp: Tensor Parallelism (TP). + ep: Expert Parallelism (EP). + efsdp: FSDP in the EP region. + etp: TP in the EP region. + """ + logger.info( + f"Building torchcomms device mesh with parallelism: " + f"pp={self.pp}, dp_replicate={self.dp_replicate}, dp_shard={self.dp_shard}, " + f"cp={self.cp}, tp={self.tp}, ep={self.ep}, etp={self.etp}" + ) + # Calculate derived dimensions + batch = self.dp_replicate * self.dp_shard + fsdp = self.dp_shard * self.cp + efsdp = fsdp * self.tp // (self.etp * self.ep) + + # Build mesh shape and names based on EP configuration + if self.ep > 1: + # With EP, we need to split dp_shard for expert parallelism + if self.etp == self.tp: + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + else: + assert self.etp == 1 + dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep + dp_shard_in_ep = self.ep // (self.cp * self.tp) + + mesh_shape = ( + self.pp, + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.cp, + self.tp, + ) + mesh_dim_names = [ + "pp", + "dp_replicate", + "dp_shard_mod_ep", + "dp_shard_in_ep", + "cp", + "tp", + ] + else: + mesh_shape = (self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp) + mesh_dim_names = ["pp", "dp_replicate", "dp_shard", "cp", "tp"] + + # Log active dimensions + active_dims = [d for d in mesh_shape if d > 1] + active_names = [name for d, name in zip(mesh_shape, mesh_dim_names) if d > 1] + logger.info(f"Building {len(active_dims)}-D device mesh with {active_names}, {active_dims}") + + # Initialize torchcomms communicators + backend = os.environ["TEST_BACKEND"] + device = torch.device("cuda") + mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view(mesh_shape) + + comm = torchcomms.new_comm( + backend, + device, + name="torchcomms_parallel_dims", + ) cur_rank = comm.get_rank() - flatten_mesh = [ - mesh.view(self.pp, self.dp_replicate * 
self.dp_shard, self.cp, self.tp), - mesh.view(self.pp, self.dp_replicate, self.dp_shard * self.cp, self.tp), - mesh.view(self.pp, self.dp_replicate * self.dp_shard * self.cp, self.tp), - ] - flattened_mesh_dim_names = ["dp", "dp_shard_cp", "dp_cp"] - flatten_mesh_dim_names = { - "dp": ["dp_replicate", "dp_shard"], - "dp_shard_cp": ["dp_shard", "cp"], - "dp_cp": ["dp_replicate", "dp_shard", "cp"], - } - reshape_size = [ - self.dp_replicate * self.dp_shard, - self.dp_shard * self.cp, - self.dp_replicate * self.dp_shard * self.cp, - ] + # Calculate ranks per dimension + mesh_sizes = [mesh.size(idx) for idx in range(len(mesh_dim_names))] + meshes = [mesh] * len(mesh_dim_names) + ranks_per_dim = _calculate_ranks_per_dimension( + meshes, mesh_dim_names, mesh_sizes, cur_rank + ) + + # Create sub-communicators for each dimension + comm_per_dim: dict[str, Any] = {} + for dim_name, ranks in ranks_per_dim.items(): + comm_per_dim[dim_name] = comm.split(ranks, dim_name) + + # Initialize device mesh with torchcomms + mesh_dim_comms = tuple(comm_per_dim[name] for name in mesh_dim_names) + try: + device_mesh = init_device_mesh( + mesh_dim_comms=mesh_dim_comms, + mesh_dim_names=tuple(mesh_dim_names), + _global_comm=comm, + ) + except TypeError as e: + if "_rank" in str(e): + for sub_comm in comm_per_dim.values(): + sub_comm.finalize() + comm.finalize() + raise RuntimeError("Failed to init device mesh due to torchcomms API mismatch") from e + raise + + # Create flattened mesh dimensions for compatibility with ParallelDims API + if self.ep > 1: + flatten_mesh = [ + mesh.view(self.pp, batch, self.cp, self.tp), + mesh.view(self.pp, self.dp_replicate, fsdp, self.tp), + mesh.view(self.pp, batch * self.cp, self.tp), + mesh.view(self.pp, self.dp_replicate, efsdp, self.ep, self.etp), + ] + flattened_mesh_dim_names = ["batch", "fsdp", "loss", "ep"] + flatten_mesh_dim_names_map = { + "batch": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep"], + "fsdp": ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"], + "loss": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp"], + "ep": ["dp_shard_in_ep", "cp"] if self.etp == self.tp else ["dp_shard_in_ep", "cp", "tp"], + } + reshape_sizes = [batch, fsdp, batch * self.cp, self.ep] + else: + flatten_mesh = [ + mesh.view(self.pp, batch, self.cp, self.tp), + mesh.view(self.pp, self.dp_replicate, fsdp, self.tp), + mesh.view(self.pp, batch * self.cp, self.tp), + ] + flattened_mesh_dim_names = ["batch", "fsdp", "loss"] + flatten_mesh_dim_names_map = { + "batch": ["dp_replicate", "dp_shard"], + "fsdp": ["dp_shard", "cp"], + "loss": ["dp_replicate", "dp_shard", "cp"], + } + reshape_sizes = [batch, fsdp, batch * self.cp] flatten_ranks_per_dim = _calculate_ranks_per_dimension( - flatten_mesh, flattened_mesh_dim_names, reshape_size, cur_rank + flatten_mesh, flattened_mesh_dim_names, reshape_sizes, cur_rank ) _flatten_comms( flatten_ranks_per_dim, comm, - flatten_mesh_dim_names, + flatten_mesh_dim_names_map, device_mesh, comm_per_dim, ) - # Call .finalize() in train.py after training but before destroying the process group - # to release sub-communicators before the root communicator. 
+ # Store communicators for cleanup (sub-comms first, root comm last) self.comms = [*comm_per_dim.values(), comm] - return device_mesh - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - mesh_shape = ( - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ) - mesh_dim_names = [ - "pp", - "dp_replicate", - "dp_shard_mod_ep", - "dp_shard_in_ep", - "cp", - "tp", - ] - - dims = [ - d - for d, name in zip(mesh_shape, mesh_dim_names) - if d > 1 or name == "dp_shard_mod_ep" - ] - names = [ - name - for d, name in zip(mesh_shape, mesh_dim_names) - if d > 1 or name == "dp_shard_mod_ep" - ] - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - - result = _create_device_mesh(self.world_size, mesh_shape, mesh_dim_names) - comm = result.get("comm", None) - device_mesh = result.get("device_mesh", None) - mesh = result.get("mesh", None) - comm_per_dim = result.get("comm_per_dim", None) - assert ( - comm is not None - and device_mesh is not None - and mesh is not None - and comm_per_dim is not None - ), "fail to init device mesh" - cur_rank = comm.get_rank() - - flatten_mesh = [ - mesh.view( - self.pp, - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep, - self.cp, - self.tp, - ), - mesh.view( - self.pp, - self.dp_replicate, - dp_shard_mod_ep * dp_shard_in_ep * self.cp, - self.tp, - ), - mesh.view( - self.pp, - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * self.cp, - self.tp, - ), - mesh.view( - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep * self.cp * self.tp, - ), - ] - - flattened_mesh_dim_names = ["dp", "dp_shard_cp", "dp_cp", "ep"] - flatten_mesh_dim_names = { - "dp": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep"], - "dp_shard_cp": ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"], - "dp_cp": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp"], - "ep": ["dp_shard_in_ep", "cp", "tp"], + # Store world mesh + self._world_mesh = device_mesh + + # Build internal mesh references following ParallelDims convention + self._meshes = { + "pp": device_mesh["pp"], + "batch": device_mesh["batch"], + "loss": device_mesh["loss"], + "dp_replicate": device_mesh["dp_replicate"], + "fsdp": device_mesh["fsdp"], + "cp": device_mesh["cp"], + "tp": device_mesh["tp"], } - reshape_size = [ - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep, - dp_shard_mod_ep * dp_shard_in_ep * self.cp, - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * self.cp, - dp_shard_in_ep * self.cp * self.tp, - ] - - flatten_ranks_per_dim = _calculate_ranks_per_dimension( - flatten_mesh, flattened_mesh_dim_names, reshape_size, cur_rank - ) - - _flatten_comms( - flatten_ranks_per_dim, - comm, - flatten_mesh_dim_names, - device_mesh, - comm_per_dim, + if self.ep > 1: + self._meshes["ep"] = device_mesh["ep"] + self._meshes["efsdp"] = device_mesh["efsdp"] if "efsdp" in device_mesh.mesh_dim_names else device_mesh["dp_shard_mod_ep"] + self._meshes["etp"] = device_mesh["etp"] if "etp" in device_mesh.mesh_dim_names else device_mesh["tp"] + else: + # Create fake meshes for EP-related dimensions when EP is not enabled + 
self._meshes["ep"] = device_mesh["pp"] # placeholder + self._meshes["efsdp"] = device_mesh["fsdp"] + self._meshes["etp"] = device_mesh["tp"] + + logger.info( + f"Successfully created torchcomms meshes with dimensions: " + f"{list(comm_per_dim.keys())}" ) - # Call .finalize() in train.py after training but before destroying the process group - # to release sub-communicators before the root communicator. - self.comms = [*comm_per_dim.values(), comm] return device_mesh + + def finalize_comms(self) -> None: + """Finalize all communicators. + + Call this after training but before destroying the process group + to release sub-communicators before the root communicator. + """ + for comm in self.comms: + comm.finalize() + self.comms.clear() diff --git a/torchtitan/experiments/torchcomms/train.py b/torchtitan/experiments/torchcomms/train.py index 108bd5b2a0..487324c2c8 100644 --- a/torchtitan/experiments/torchcomms/train.py +++ b/torchtitan/experiments/torchcomms/train.py @@ -25,7 +25,6 @@ def init_distributed(self) -> ParallelDims: world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism - return TorchCommsParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, @@ -40,8 +39,7 @@ def init_distributed(self) -> ParallelDims: def close(self) -> None: # Call finalize on all comms after training and before destroying process group. if hasattr(self, "parallel_dims"): - for comm in self.parallel_dims.comms: - comm.finalize() + self.parallel_dims.finalize_comms() super().close() From 840f5e71fcb6c959bf22f22b57e2186fe220abf1 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Feb 2026 18:06:06 -0800 Subject: [PATCH 2/6] update --- torchtitan/experiments/torchcomms/parallel_dims.py | 10 ---------- torchtitan/experiments/torchcomms/train.py | 3 ++- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index 284988cb48..cfc789f8f8 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -270,13 +270,3 @@ def build_mesh(self) -> DeviceMesh: ) return device_mesh - - def finalize_comms(self) -> None: - """Finalize all communicators. - - Call this after training but before destroying the process group - to release sub-communicators before the root communicator. - """ - for comm in self.comms: - comm.finalize() - self.comms.clear() diff --git a/torchtitan/experiments/torchcomms/train.py b/torchtitan/experiments/torchcomms/train.py index 487324c2c8..a0b1dfd434 100644 --- a/torchtitan/experiments/torchcomms/train.py +++ b/torchtitan/experiments/torchcomms/train.py @@ -39,7 +39,8 @@ def init_distributed(self) -> ParallelDims: def close(self) -> None: # Call finalize on all comms after training and before destroying process group. 
if hasattr(self, "parallel_dims"): - self.parallel_dims.finalize_comms() + for comm in self.parallel_dims.comms: + comm.finalize() super().close() From b4b0031caac667871b9a54bb18cb079c179d453d Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Feb 2026 18:07:26 -0800 Subject: [PATCH 3/6] update --- .../experiments/torchcomms/parallel_dims.py | 21 +------------------ 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index cfc789f8f8..b7d0310a31 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -91,29 +91,10 @@ class TorchCommsParallelDims(ParallelDims): comms: list[Any] = field(default_factory=list) def build_mesh(self) -> DeviceMesh: - """Build the device mesh using torchcomms. - + """ This method follows the same mesh structure as the base ParallelDims but uses torchcomms for communicator initialization instead of torch.distributed. - - The following mesh dimensions will be created: - pp: Pipeline Parallelism (PP). - batch: Used by data loading (dp_replicate * dp_shard). - loss: Used by all-reduce when computing the loss (dp_replicate * dp_shard * cp). - dp_replicate: For DDP or HSDP replicate dimension. - fsdp: For FSDP dimension (dp_shard * cp). - cp: Context Parallelism (CP). - tp: Tensor Parallelism (TP). - ep: Expert Parallelism (EP). - efsdp: FSDP in the EP region. - etp: TP in the EP region. """ - logger.info( - f"Building torchcomms device mesh with parallelism: " - f"pp={self.pp}, dp_replicate={self.dp_replicate}, dp_shard={self.dp_shard}, " - f"cp={self.cp}, tp={self.tp}, ep={self.ep}, etp={self.etp}" - ) - # Calculate derived dimensions batch = self.dp_replicate * self.dp_shard fsdp = self.dp_shard * self.cp From 99a74a3a6492b76f366eb3d2856e3846cf1d54e5 Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Feb 2026 18:10:17 -0800 Subject: [PATCH 4/6] update --- torchtitan/experiments/torchcomms/parallel_dims.py | 7 ++++--- torchtitan/experiments/torchcomms/train.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index b7d0310a31..93b3b5dec2 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -218,9 +218,6 @@ def build_mesh(self) -> DeviceMesh: comm_per_dim, ) - # Store communicators for cleanup (sub-comms first, root comm last) - self.comms = [*comm_per_dim.values(), comm] - # Store world mesh self._world_mesh = device_mesh @@ -250,4 +247,8 @@ def build_mesh(self) -> DeviceMesh: f"{list(comm_per_dim.keys())}" ) + # Call .finalize() in train.py after training but before destroying the process group + # to release sub-communicators before the root communicator. 
+ self.comms = [*comm_per_dim.values(), comm] + return device_mesh diff --git a/torchtitan/experiments/torchcomms/train.py b/torchtitan/experiments/torchcomms/train.py index a0b1dfd434..108bd5b2a0 100644 --- a/torchtitan/experiments/torchcomms/train.py +++ b/torchtitan/experiments/torchcomms/train.py @@ -25,6 +25,7 @@ def init_distributed(self) -> ParallelDims: world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism + return TorchCommsParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, From 580c3c01cb0699e6f14769445c674d0be1d5081a Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Feb 2026 18:10:55 -0800 Subject: [PATCH 5/6] update --- torchtitan/experiments/torchcomms/parallel_dims.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index 93b3b5dec2..97267444ac 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -250,5 +250,4 @@ def build_mesh(self) -> DeviceMesh: # Call .finalize() in train.py after training but before destroying the process group # to release sub-communicators before the root communicator. self.comms = [*comm_per_dim.values(), comm] - return device_mesh From 73cc07a8a403be28c88d66ebcff984956400938a Mon Sep 17 00:00:00 2001 From: mori360 Date: Thu, 5 Feb 2026 18:17:18 -0800 Subject: [PATCH 6/6] lint --- .../experiments/torchcomms/parallel_dims.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index 97267444ac..fcfb242896 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -134,12 +134,16 @@ def build_mesh(self) -> DeviceMesh: # Log active dimensions active_dims = [d for d in mesh_shape if d > 1] active_names = [name for d, name in zip(mesh_shape, mesh_dim_names) if d > 1] - logger.info(f"Building {len(active_dims)}-D device mesh with {active_names}, {active_dims}") + logger.info( + f"Building {len(active_dims)}-D device mesh with {active_names}, {active_dims}" + ) # Initialize torchcomms communicators backend = os.environ["TEST_BACKEND"] device = torch.device("cuda") - mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view(mesh_shape) + mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view( + mesh_shape + ) comm = torchcomms.new_comm( backend, @@ -173,7 +177,9 @@ def build_mesh(self) -> DeviceMesh: for sub_comm in comm_per_dim.values(): sub_comm.finalize() comm.finalize() - raise RuntimeError("Failed to init device mesh due to torchcomms API mismatch") from e + raise RuntimeError( + "Failed to init device mesh due to torchcomms API mismatch" + ) from e raise # Create flattened mesh dimensions for compatibility with ParallelDims API @@ -189,7 +195,9 @@ def build_mesh(self) -> DeviceMesh: "batch": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep"], "fsdp": ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"], "loss": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp"], - "ep": ["dp_shard_in_ep", "cp"] if self.etp == self.tp else ["dp_shard_in_ep", "cp", "tp"], + "ep": ["dp_shard_in_ep", "cp"] + if self.etp == self.tp + else ["dp_shard_in_ep", "cp", "tp"], } reshape_sizes = [batch, fsdp, batch * self.cp, self.ep] else: @@ -234,8 +242,16 @@ 
def build_mesh(self) -> DeviceMesh: if self.ep > 1: self._meshes["ep"] = device_mesh["ep"] - self._meshes["efsdp"] = device_mesh["efsdp"] if "efsdp" in device_mesh.mesh_dim_names else device_mesh["dp_shard_mod_ep"] - self._meshes["etp"] = device_mesh["etp"] if "etp" in device_mesh.mesh_dim_names else device_mesh["tp"] + self._meshes["efsdp"] = ( + device_mesh["efsdp"] + if "efsdp" in device_mesh.mesh_dim_names + else device_mesh["dp_shard_mod_ep"] + ) + self._meshes["etp"] = ( + device_mesh["etp"] + if "etp" in device_mesh.mesh_dim_names + else device_mesh["tp"] + ) else: # Create fake meshes for EP-related dimensions when EP is not enabled self._meshes["ep"] = device_mesh["pp"] # placeholder
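
---
For reference, a minimal usage sketch of TorchCommsParallelDims as it stands after patch 6/6. This is illustrative only: the parallelism degrees, the TEST_BACKEND value, and calling build_mesh() directly are assumptions made here for the example; in torchtitan the Trainer drives this through init_distributed() and close() as shown in the train.py hunks above.

    import os

    from torchtitan.experiments.torchcomms.parallel_dims import TorchCommsParallelDims

    # Backend name is an assumption; build_mesh() reads it from TEST_BACKEND.
    os.environ.setdefault("TEST_BACKEND", "ncclx")

    # Degrees below are placeholders (their product is expected to equal world_size,
    # as the base ParallelDims validates); real values come from the job config.
    parallel_dims = TorchCommsParallelDims(
        dp_shard=8,
        dp_replicate=1,
        cp=1,
        tp=1,
        pp=1,
        ep=1,
        etp=1,
        world_size=int(os.environ["WORLD_SIZE"]),
    )

    # Creates the root communicator, one sub-communicator per mesh dimension, and
    # the flattened batch/fsdp/loss (and ep) dimensions, then stores every
    # communicator on parallel_dims.comms for later cleanup.
    world_mesh = parallel_dims.build_mesh()

    # ... training loop ...

    # Finalize sub-communicators before the root communicator, after training but
    # before destroying the process group (mirrors close() in train.py above).
    for comm in parallel_dims.comms:
        comm.finalize()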