From f7becc606ef7e586d6c21e83e1835dec51f59804 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 26 Jan 2026 13:16:34 -0800 Subject: [PATCH] Use upstream PyTorch cost models instead of local copy DEPENDENCY: Requires pytorch/pytorch branch `cost-model-consolidation` which adds `all_to_all_cost` and `include_compute_cost` to `torch.distributed.tensor._collective_utils`. Changes: - Remove local `all_to_all_cost` function (now imported from PyTorch) - Remove copy-pasted `redistribute_cost` function - Use PyTorch's `redistribute_cost` with `include_compute_cost=True` - Deprecate `order` parameter (PyTorch now uses _gen_transform_infos for optimal ordering) - Add comprehensive tests for cost model functions This consolidation: - Reduces code duplication between PyTorch and Autoparallel - Ensures cost models stay in sync across projects - Removes the 5x hack in all_to_all_cost (calibration moved to PyTorch) Authored with Claude. --- .../collective_runtime_estimation.py | 181 ++++++------------ tests/test_cost_models.py | 151 +++++++++++++++ 2 files changed, 208 insertions(+), 124 deletions(-) create mode 100644 tests/test_cost_models.py diff --git a/autoparallel/cost_models/collective_runtime_estimation.py b/autoparallel/cost_models/collective_runtime_estimation.py index cbffcd13..ef2c57e6 100644 --- a/autoparallel/cost_models/collective_runtime_estimation.py +++ b/autoparallel/cost_models/collective_runtime_estimation.py @@ -3,151 +3,84 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. -from typing import cast +""" +Communication cost estimation for DTensor redistribution. + +NOTE: This module depends on PyTorch PR that adds `all_to_all_cost` and +`include_compute_cost` to `torch.distributed.tensor._collective_utils`. +See: pytorch/pytorch branch `cost-model-consolidation` +""" import torch.distributed.tensor._dtensor_spec as dtensor_spec -from torch._prims_common import check_contiguous_sizes_strides from torch.distributed.tensor._collective_utils import ( MeshTopoInfo, + all_to_all_cost, allgather_cost, allreduce_cost, +) +from torch.distributed.tensor._collective_utils import ( + redistribute_cost as _pytorch_redistribute_cost, +) +from torch.distributed.tensor._collective_utils import ( reduce_scatter_cost, spec_to_bytes, ) -from torch.distributed.tensor.placement_types import Partial, Shard -from .compute_estimation import compute_read_write_time - -def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: - num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] - mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] - num_hops = num_devices_on_mesh_dim - 1 - # base latency + comm latency - latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us - bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s - total_time = latency + bw * 1e6 # rescale to us - # FIXME: this is a hack, we need to spend some more effort on the cost model - total_time *= 5 - return total_time - - -# this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py -# with iteration order introduced def redistribute_cost( current_spec: "dtensor_spec.DTensorSpec", target_spec: "dtensor_spec.DTensorSpec", - order: list[int], + order: list[int] | None = None, ) -> float: """ - This function returns the cost of redistribute from current to target DTensorSpec. 
+ Estimate the cost of redistributing from current to target DTensorSpec. - NOTE: - 1. Only consider communication cost here, since computation costs for redistribute - are quite trivial (i.e. we only need to narrow or simple division) - 2. Only consider redistribute cost on same mesh, cross mesh communication cost is - not quite needed for operator strategy estimation/selection. - """ - if current_spec.mesh != target_spec.mesh: - # make infinite cost if meshes are not same - # TODO: see if we want to support this once there's cross mesh communication - return float("inf") + This is a thin wrapper around PyTorch's redistribute_cost that enables + compute cost estimation by default (for accurate sharding strategy selection). - if current_spec.is_replicated(): - # short-cut: - # comm cost is 0 if current spec is already full replication - # except if output is partial, which doesn't make sense for us - if any(p.is_partial() for p in target_spec.placements): - return float("inf") - return 0.0 + Args: + current_spec: The current DTensorSpec. + target_spec: The target DTensorSpec. + order: Deprecated. Previously used for custom iteration order. + PyTorch now uses _gen_transform_infos for optimal ordering. - mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) - cost = 0.0 - comm_bytes_gb = ( - spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 - ) - # Transformation that considered for redistribute cost: - # 1. allgather 2. alltoall - # 3. allreduce 4. reduce_scatter - curr_placements = [current_spec.placements[i] for i in order] - tgt_placements = [target_spec.placements[i] for i in order] - is_contiguous: bool = check_contiguous_sizes_strides( - current_spec.shape, current_spec.stride + Returns: + The estimated cost of redistribution in microseconds. 
+ """ + # Use PyTorch's upstream redistribute_cost with compute costs enabled + # This accounts for reshuffle overhead on non-dim-0 shards + return _pytorch_redistribute_cost( + current_spec, + target_spec, + include_compute_cost=True, ) - for i, current, target in zip(order, curr_placements, tgt_placements): - if current == target: - continue - num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] - if not is_contiguous: - cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3) - if current.is_shard() and target.is_replicate(): - current = cast(Shard, current) - # allgather gives larger comm bytes - comm_bytes_gb *= num_devices_on_mesh_dim - # add up allgather comm cost - cost += allgather_cost(comm_bytes_gb, mesh_topo, i) - if current.dim != 0: - # penalize cases like S(1) -> R as there are additional compute cost - # which corresponds to reshuffling the whole output tensor - # we multiply the cost by 2 because we need to count input and output - # reads for the reshuffle - compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3) - cost += compute_cost - elif current.is_shard() and target.is_shard(): - current = cast(Shard, current) - target = cast(Shard, target) - # should be alltoall comm, since we haven't implement it yet, add penalty - # to favor allgather instead - cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i) # us - - num_copies = 0 - if current.dim != 0: - num_copies += 1 - - if target.dim != 0: - num_copies += 1 - - compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3) - cost += num_copies * compute_cost - elif current.is_partial() and target.is_replicate(): - # add up allreduce comm cost - cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) - elif current.is_partial() and target.is_shard(): - target = cast(Shard, target) - # add up reduce_scatter comm cost - cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) - if target.dim != 0: - # penalize cases like P -> S(1) as there are additional compute cost - # which corresponds to reshuffling the whole input tensor - # we multiply the cost by 2 because we need to count input and output - # reads for the reshuffle - compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3) - cost += compute_cost - # after reduce_scatter the comm bytes for further collectives halved. - comm_bytes_gb /= num_devices_on_mesh_dim - elif current.is_shard() and target.is_partial(): - # ban shard -> partial as it does not make sense to perform - # this redistribute - return float("inf") - elif current.is_replicate() and target.is_partial(): - # ban replicate -> partial as it does not make sense to perform - # this redistribute in our case - return float("inf") - # once we redistribute across one mesh dim, assume the output - # is now contiguous. This is generally the case for most operations, - # except when we fuse nd collectives into a 1d collective. - is_contiguous = True - - return cost +def estimate_strategy_comms_cost( + src_spec: "dtensor_spec.DTensorSpec", + tgt_spec: "dtensor_spec.DTensorSpec", +) -> float: + """ + Estimate communication cost for a sharding strategy transition. + Args: + src_spec: Source DTensorSpec (current sharding). + tgt_spec: Target DTensorSpec (desired sharding). 
-def estimate_strategy_comms_cost(src_spec, tgt_spec): - order = list(range(src_spec.mesh.ndim)) - if src_spec.placements == (Partial(), Partial()) and all( - p.is_shard() for p in tgt_spec.placements - ): - order = [1, 0] - comms_cost = redistribute_cost(src_spec, tgt_spec, order) - return comms_cost + Returns: + Estimated communication cost in microseconds. + """ + return redistribute_cost(src_spec, tgt_spec) + + +# Re-export for convenience +__all__ = [ + "redistribute_cost", + "estimate_strategy_comms_cost", + "all_to_all_cost", + "allgather_cost", + "allreduce_cost", + "reduce_scatter_cost", + "MeshTopoInfo", + "spec_to_bytes", +] diff --git a/tests/test_cost_models.py b/tests/test_cost_models.py new file mode 100644 index 00000000..2799aa76 --- /dev/null +++ b/tests/test_cost_models.py @@ -0,0 +1,151 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for communication cost estimation module.""" + +import unittest + +import torch +from torch.distributed.tensor import DeviceMesh, Replicate, Shard +from torch.distributed.tensor._collective_utils import MeshTopoInfo +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.testing._internal.distributed.fake_pg import FakeStore + +from autoparallel.cost_models.collective_runtime_estimation import ( + all_to_all_cost, + estimate_strategy_comms_cost, + redistribute_cost, +) + + +def extract_tensor_meta(t: torch.Tensor) -> TensorMeta: + return TensorMeta(t.shape, t.stride(), t.dtype) + + +class TestCollectiveRuntimeEstimation(unittest.TestCase): + """Test communication cost estimation functions.""" + + @classmethod + def setUpClass(cls): + """Set up fake distributed environment.""" + cls.world_size = 4 + store = FakeStore() + torch.distributed.init_process_group( + backend="fake", + rank=0, + world_size=cls.world_size, + store=store, + ) + cls.mesh = DeviceMesh("cuda", list(range(cls.world_size))) + + @classmethod + def tearDownClass(cls): + """Tear down distributed environment.""" + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + def test_redistribute_cost_basic(self): + """Test basic redistribute cost estimation.""" + shard_placement = (Shard(0),) + replica_placement = (Replicate(),) + + tensor = torch.randn(8, 8) + tensor_meta = extract_tensor_meta(tensor) + + shard_spec = DTensorSpec(self.mesh, shard_placement, tensor_meta) + replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta) + + # Same spec should have zero cost + cost_same = redistribute_cost(shard_spec, shard_spec) + self.assertEqual(cost_same, 0) + + # Shard -> Replicate should have positive cost + cost_allgather = redistribute_cost(shard_spec, replica_spec) + self.assertGreater(cost_allgather, 0) + + def test_redistribute_cost_with_compute_overhead(self): + """Test that non-dim-0 shards include compute overhead.""" + shard0_placement = (Shard(0),) + shard1_placement = (Shard(1),) + replica_placement = (Replicate(),) + + tensor = torch.randn(8, 8) + tensor_meta = extract_tensor_meta(tensor) + + shard0_spec = DTensorSpec(self.mesh, shard0_placement, tensor_meta) + shard1_spec = DTensorSpec(self.mesh, shard1_placement, tensor_meta) + replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta) + + # Shard(0) -> Replicate (no reshuffle needed) + cost_dim0 = redistribute_cost(shard0_spec, replica_spec) + # Shard(1) -> Replicate 
+        cost_dim1 = redistribute_cost(shard1_spec, replica_spec)
+
+        # Shard(1) -> Replicate should be more expensive due to reshuffle
+        self.assertGreater(cost_dim1, cost_dim0)
+
+    def test_all_to_all_cost(self):
+        """Test all_to_all_cost function."""
+        mesh_topo = MeshTopoInfo.build_from_mesh(self.mesh)
+
+        # Test with 1MB
+        cost = all_to_all_cost(0.001, mesh_topo, 0)
+        self.assertGreater(cost, 0)
+
+        # Larger tensor should have higher cost
+        cost_larger = all_to_all_cost(0.01, mesh_topo, 0)
+        self.assertGreater(cost_larger, cost)
+
+    def test_shard_to_shard_uses_all_to_all(self):
+        """Test that shard->shard transitions have reasonable cost."""
+        shard0_placement = (Shard(0),)
+        shard1_placement = (Shard(1),)
+
+        tensor = torch.randn(8, 8)
+        tensor_meta = extract_tensor_meta(tensor)
+
+        shard0_spec = DTensorSpec(self.mesh, shard0_placement, tensor_meta)
+        shard1_spec = DTensorSpec(self.mesh, shard1_placement, tensor_meta)
+
+        # Shard(0) -> Shard(1) should use all_to_all
+        cost = redistribute_cost(shard0_spec, shard1_spec)
+        self.assertGreater(cost, 0)
+        self.assertNotEqual(cost, float("inf"))
+
+    def test_estimate_strategy_comms_cost(self):
+        """Test estimate_strategy_comms_cost wrapper."""
+        shard_placement = (Shard(0),)
+        replica_placement = (Replicate(),)
+
+        tensor = torch.randn(8, 8)
+        tensor_meta = extract_tensor_meta(tensor)
+
+        shard_spec = DTensorSpec(self.mesh, shard_placement, tensor_meta)
+        replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta)
+
+        cost = estimate_strategy_comms_cost(shard_spec, replica_spec)
+        expected_cost = redistribute_cost(shard_spec, replica_spec)
+        self.assertEqual(cost, expected_cost)
+
+    def test_order_parameter_deprecated(self):
+        """Test that order parameter is accepted but ignored."""
+        shard_placement = (Shard(0),)
+        replica_placement = (Replicate(),)
+
+        tensor = torch.randn(8, 8)
+        tensor_meta = extract_tensor_meta(tensor)
+
+        shard_spec = DTensorSpec(self.mesh, shard_placement, tensor_meta)
+        replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta)
+
+        # Should accept order parameter without error
+        cost_with_order = redistribute_cost(shard_spec, replica_spec, order=[0])
+        cost_without_order = redistribute_cost(shard_spec, replica_spec)
+        # Results should be the same (order is ignored)
+        self.assertEqual(cost_with_order, cost_without_order)
+
+
+if __name__ == "__main__":
+    unittest.main()