diff --git a/torchtitan/experiments/torchcomms/parallel_dims.py b/torchtitan/experiments/torchcomms/parallel_dims.py index fd8b860e34..fcfb242896 100644 --- a/torchtitan/experiments/torchcomms/parallel_dims.py +++ b/torchtitan/experiments/torchcomms/parallel_dims.py @@ -5,8 +5,8 @@ # LICENSE file in the root directory of this source tree. import os -from dataclasses import dataclass -from typing import Dict, List +from dataclasses import dataclass, field +from typing import Any import torch import torchcomms @@ -21,11 +21,11 @@ def _calculate_ranks_per_dimension( - meshes: List[torch.Tensor], - dim_names: List[str], - dim_sizes: List[int], + meshes: list[torch.Tensor], + dim_names: list[str], + dim_sizes: list[int], cur_rank: int, -) -> Dict[str, List[int]]: +) -> dict[str, list[int]]: """Util function to calculate global ranks mapping for each mesh dimension. Args: @@ -49,79 +49,12 @@ def _calculate_ranks_per_dimension( return ranks_per_dim -def _create_device_mesh( - world_size: int, - mesh_shape: tuple, - mesh_dim_names: List[str], -) -> Dict: - """Util function to create device mesh with communicators for each dimension. - - Args: - world_size: Total number of ranks in the world - mesh_shape: Shape of the device mesh - mesh_dim_names: List of dimension names for the mesh - - Returns: - Dictionary containing: - - comm: Root communicator - - device_mesh: Initialized DeviceMesh object - - mesh: Tensor representation of the mesh - - comm_per_dim: Communicators for each dimension - Returns empty dict if initialization fails - """ - backend = os.environ["TEST_BACKEND"] - device = torch.device("cuda") - mesh = torch.arange(world_size, dtype=torch.int, device="cpu").view(mesh_shape) - comm = torchcomms.new_comm( - backend, - device, - name="comms_test_n_d_parallel", - ) - - cur_rank = comm.get_rank() - - mesh_sizes = [mesh.size(idx) for idx in range(len(mesh_dim_names))] - meshes = [mesh] * len(mesh_dim_names) - ranks_per_dim = _calculate_ranks_per_dimension( - meshes, mesh_dim_names, mesh_sizes, cur_rank - ) - - # Create sub-communicators for each dimension - comm_per_dim = {} - for dim_name, ranks in ranks_per_dim.items(): - comm_per_dim[dim_name] = comm.split(ranks, dim_name) - - # Initialize device mesh with communicators - mesh_dim_comms = tuple(comm_per_dim[name] for name in mesh_dim_names) - try: - device_mesh = init_device_mesh( - mesh_dim_comms=mesh_dim_comms, - mesh_dim_names=tuple(mesh_dim_names), - _global_comm=comm, - ) - except TypeError as e: - # TODO: remove this once PT 2.10 is released - if "_rank" in str(e): - for sub_comm in comm_per_dim.values(): - sub_comm.finalize() - comm.finalize() - return {} - raise - - return { - "comm": comm, - "device_mesh": device_mesh, - "mesh": mesh, - "comm_per_dim": comm_per_dim, - } - - def _flatten_comms( - flatten_ranks_per_dim: Dict[str, List[int]], - comm, - flatten_mesh_dim_names: Dict[str, List[str]], + flatten_ranks_per_dim: dict[str, list[int]], + comm: Any, + flatten_mesh_dim_names: dict[str, list[str]], device_mesh: DeviceMesh, - comm_per_dim: Dict[str, any], + comm_per_dim: dict[str, Any], ) -> None: """Util function to flatten mesh dimensions and create corresponding communicators. 
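For readers skimming this hunk, the sketch below illustrates the rank-mapping contract that `_calculate_ranks_per_dimension` is expected to satisfy: for every mesh dimension, return the global ranks that share the current rank's group along that dimension. This is a simplified, hypothetical illustration written for this note (the helper name `ranks_per_dim_sketch` and the example mesh are not part of the diff, and the real helper's body is unchanged here); the real helper additionally accepts one, possibly reshaped, mesh per dimension, which is how the flattened dims later in this file reuse it.

```python
# Hypothetical sketch of the ranks-per-dimension mapping (not the helper in this file):
# for each mesh dimension, list the global ranks that share cur_rank's group along it.
import torch


def ranks_per_dim_sketch(
    mesh: torch.Tensor, dim_names: list[str], cur_rank: int
) -> dict[str, list[int]]:
    out: dict[str, list[int]] = {}
    for dim, name in enumerate(dim_names):
        # Move the dimension of interest last, then enumerate each group of ranks.
        groups = mesh.movedim(dim, -1).reshape(-1, mesh.size(dim))
        for group in groups.tolist():
            if cur_rank in group:
                out[name] = group
                break
    return out


mesh = torch.arange(8).view(2, 4)  # 8 ranks laid out on a 2 x 4 mesh
print(ranks_per_dim_sketch(mesh, ["dp", "tp"], cur_rank=5))
# {'dp': [1, 5], 'tp': [4, 5, 6, 7]}
```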
@@ -152,173 +85,184 @@ def _flatten_comms( @dataclass class TorchCommsParallelDims(ParallelDims): - def _build_mesh_without_ep(self) -> DeviceMesh: - mesh_shape = (self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp) - mesh_dim_names = ["pp", "dp_replicate", "dp_shard", "cp", "tp"] - - dims = [d for d in mesh_shape if d > 1] - names = [name for d, name in zip(mesh_shape, mesh_dim_names) if d > 1] - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - - result = _create_device_mesh(self.world_size, mesh_shape, mesh_dim_names) - comm = result.get("comm", None) - device_mesh = result.get("device_mesh", None) - mesh = result.get("mesh", None) - comm_per_dim = result.get("comm_per_dim", None) - assert ( - comm is not None - and device_mesh is not None - and mesh is not None - and comm_per_dim is not None - ), "fail to init device mesh" - - cur_rank = comm.get_rank() - - flatten_mesh = [ - mesh.view(self.pp, self.dp_replicate * self.dp_shard, self.cp, self.tp), - mesh.view(self.pp, self.dp_replicate, self.dp_shard * self.cp, self.tp), - mesh.view(self.pp, self.dp_replicate * self.dp_shard * self.cp, self.tp), - ] - flattened_mesh_dim_names = ["dp", "dp_shard_cp", "dp_cp"] - flatten_mesh_dim_names = { - "dp": ["dp_replicate", "dp_shard"], - "dp_shard_cp": ["dp_shard", "cp"], - "dp_cp": ["dp_replicate", "dp_shard", "cp"], - } - reshape_size = [ - self.dp_replicate * self.dp_shard, - self.dp_shard * self.cp, - self.dp_replicate * self.dp_shard * self.cp, - ] - - flatten_ranks_per_dim = _calculate_ranks_per_dimension( - flatten_mesh, flattened_mesh_dim_names, reshape_size, cur_rank + """ParallelDims implementation using torchcomms for device mesh initialization.""" + + # Store communicators for cleanup + comms: list[Any] = field(default_factory=list) + + def build_mesh(self) -> DeviceMesh: + """ + This method follows the same mesh structure as the base ParallelDims but uses + torchcomms for communicator initialization instead of torch.distributed. 
+ """ + # Calculate derived dimensions + batch = self.dp_replicate * self.dp_shard + fsdp = self.dp_shard * self.cp + efsdp = fsdp * self.tp // (self.etp * self.ep) + + # Build mesh shape and names based on EP configuration + if self.ep > 1: + # With EP, we need to split dp_shard for expert parallelism + if self.etp == self.tp: + dp_shard_mod_ep = self.dp_shard * self.cp // self.ep + dp_shard_in_ep = self.ep // self.cp + else: + assert self.etp == 1 + dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep + dp_shard_in_ep = self.ep // (self.cp * self.tp) + + mesh_shape = ( + self.pp, + self.dp_replicate, + dp_shard_mod_ep, + dp_shard_in_ep, + self.cp, + self.tp, + ) + mesh_dim_names = [ + "pp", + "dp_replicate", + "dp_shard_mod_ep", + "dp_shard_in_ep", + "cp", + "tp", + ] + else: + mesh_shape = (self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp) + mesh_dim_names = ["pp", "dp_replicate", "dp_shard", "cp", "tp"] + + # Log active dimensions + active_dims = [d for d in mesh_shape if d > 1] + active_names = [name for d, name in zip(mesh_shape, mesh_dim_names) if d > 1] + logger.info( + f"Building {len(active_dims)}-D device mesh with {active_names}, {active_dims}" ) - _flatten_comms( - flatten_ranks_per_dim, - comm, - flatten_mesh_dim_names, - device_mesh, - comm_per_dim, + # Initialize torchcomms communicators + backend = os.environ["TEST_BACKEND"] + device = torch.device("cuda") + mesh = torch.arange(self.world_size, dtype=torch.int, device="cpu").view( + mesh_shape ) - # Call .finalize() in train.py after training but before destroying the process group - # to release sub-communicators before the root communicator. - self.comms = [*comm_per_dim.values(), comm] - return device_mesh - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - mesh_shape = ( - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, + comm = torchcomms.new_comm( + backend, + device, + name="torchcomms_parallel_dims", ) - mesh_dim_names = [ - "pp", - "dp_replicate", - "dp_shard_mod_ep", - "dp_shard_in_ep", - "cp", - "tp", - ] - - dims = [ - d - for d, name in zip(mesh_shape, mesh_dim_names) - if d > 1 or name == "dp_shard_mod_ep" - ] - names = [ - name - for d, name in zip(mesh_shape, mesh_dim_names) - if d > 1 or name == "dp_shard_mod_ep" - ] - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - - result = _create_device_mesh(self.world_size, mesh_shape, mesh_dim_names) - comm = result.get("comm", None) - device_mesh = result.get("device_mesh", None) - mesh = result.get("mesh", None) - comm_per_dim = result.get("comm_per_dim", None) - assert ( - comm is not None - and device_mesh is not None - and mesh is not None - and comm_per_dim is not None - ), "fail to init device mesh" - cur_rank = comm.get_rank() - flatten_mesh = [ - mesh.view( - self.pp, - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep, - self.cp, - self.tp, - ), - mesh.view( - self.pp, - self.dp_replicate, - dp_shard_mod_ep * dp_shard_in_ep * self.cp, - self.tp, - ), - mesh.view( - self.pp, - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * self.cp, - 
self.tp, - ), - mesh.view( - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep * self.cp * self.tp, - ), - ] - - flattened_mesh_dim_names = ["dp", "dp_shard_cp", "dp_cp", "ep"] - flatten_mesh_dim_names = { - "dp": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep"], - "dp_shard_cp": ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"], - "dp_cp": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp"], - "ep": ["dp_shard_in_ep", "cp", "tp"], - } + # Calculate ranks per dimension + mesh_sizes = [mesh.size(idx) for idx in range(len(mesh_dim_names))] + meshes = [mesh] * len(mesh_dim_names) + ranks_per_dim = _calculate_ranks_per_dimension( + meshes, mesh_dim_names, mesh_sizes, cur_rank + ) - reshape_size = [ - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep, - dp_shard_mod_ep * dp_shard_in_ep * self.cp, - self.dp_replicate * dp_shard_mod_ep * dp_shard_in_ep * self.cp, - dp_shard_in_ep * self.cp * self.tp, - ] + # Create sub-communicators for each dimension + comm_per_dim: dict[str, Any] = {} + for dim_name, ranks in ranks_per_dim.items(): + comm_per_dim[dim_name] = comm.split(ranks, dim_name) + + # Initialize device mesh with torchcomms + mesh_dim_comms = tuple(comm_per_dim[name] for name in mesh_dim_names) + try: + device_mesh = init_device_mesh( + mesh_dim_comms=mesh_dim_comms, + mesh_dim_names=tuple(mesh_dim_names), + _global_comm=comm, + ) + except TypeError as e: + if "_rank" in str(e): + for sub_comm in comm_per_dim.values(): + sub_comm.finalize() + comm.finalize() + raise RuntimeError( + "Failed to init device mesh due to torchcomms API mismatch" + ) from e + raise + + # Create flattened mesh dimensions for compatibility with ParallelDims API + if self.ep > 1: + flatten_mesh = [ + mesh.view(self.pp, batch, self.cp, self.tp), + mesh.view(self.pp, self.dp_replicate, fsdp, self.tp), + mesh.view(self.pp, batch * self.cp, self.tp), + mesh.view(self.pp, self.dp_replicate, efsdp, self.ep, self.etp), + ] + flattened_mesh_dim_names = ["batch", "fsdp", "loss", "ep"] + flatten_mesh_dim_names_map = { + "batch": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep"], + "fsdp": ["dp_shard_mod_ep", "dp_shard_in_ep", "cp"], + "loss": ["dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp"], + "ep": ["dp_shard_in_ep", "cp"] + if self.etp == self.tp + else ["dp_shard_in_ep", "cp", "tp"], + } + reshape_sizes = [batch, fsdp, batch * self.cp, self.ep] + else: + flatten_mesh = [ + mesh.view(self.pp, batch, self.cp, self.tp), + mesh.view(self.pp, self.dp_replicate, fsdp, self.tp), + mesh.view(self.pp, batch * self.cp, self.tp), + ] + flattened_mesh_dim_names = ["batch", "fsdp", "loss"] + flatten_mesh_dim_names_map = { + "batch": ["dp_replicate", "dp_shard"], + "fsdp": ["dp_shard", "cp"], + "loss": ["dp_replicate", "dp_shard", "cp"], + } + reshape_sizes = [batch, fsdp, batch * self.cp] flatten_ranks_per_dim = _calculate_ranks_per_dimension( - flatten_mesh, flattened_mesh_dim_names, reshape_size, cur_rank + flatten_mesh, flattened_mesh_dim_names, reshape_sizes, cur_rank ) _flatten_comms( flatten_ranks_per_dim, comm, - flatten_mesh_dim_names, + flatten_mesh_dim_names_map, device_mesh, comm_per_dim, ) + # Store world mesh + self._world_mesh = device_mesh + + # Build internal mesh references following ParallelDims convention + self._meshes = { + "pp": device_mesh["pp"], + "batch": device_mesh["batch"], + "loss": device_mesh["loss"], + "dp_replicate": device_mesh["dp_replicate"], + "fsdp": device_mesh["fsdp"], + "cp": device_mesh["cp"], + "tp": device_mesh["tp"], + } + + if self.ep > 1: + 
self._meshes["ep"] = device_mesh["ep"] + self._meshes["efsdp"] = ( + device_mesh["efsdp"] + if "efsdp" in device_mesh.mesh_dim_names + else device_mesh["dp_shard_mod_ep"] + ) + self._meshes["etp"] = ( + device_mesh["etp"] + if "etp" in device_mesh.mesh_dim_names + else device_mesh["tp"] + ) + else: + # Create fake meshes for EP-related dimensions when EP is not enabled + self._meshes["ep"] = device_mesh["pp"] # placeholder + self._meshes["efsdp"] = device_mesh["fsdp"] + self._meshes["etp"] = device_mesh["tp"] + + logger.info( + f"Successfully created torchcomms meshes with dimensions: " + f"{list(comm_per_dim.keys())}" + ) + # Call .finalize() in train.py after training but before destroying the process group # to release sub-communicators before the root communicator. self.comms = [*comm_per_dim.values(), comm]