181 changes: 57 additions & 124 deletions autoparallel/cost_models/collective_runtime_estimation.py
@@ -3,151 +3,84 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast
"""
Communication cost estimation for DTensor redistribution.

NOTE: This module depends on a PyTorch PR that adds `all_to_all_cost` and
`include_compute_cost` to `torch.distributed.tensor._collective_utils`.
See: pytorch/pytorch branch `cost-model-consolidation`
"""

import torch.distributed.tensor._dtensor_spec as dtensor_spec
from torch._prims_common import check_contiguous_sizes_strides
from torch.distributed.tensor._collective_utils import (
MeshTopoInfo,
all_to_all_cost,
allgather_cost,
allreduce_cost,
)
from torch.distributed.tensor._collective_utils import (
redistribute_cost as _pytorch_redistribute_cost,
)
from torch.distributed.tensor._collective_utils import (
reduce_scatter_cost,
spec_to_bytes,
)
from torch.distributed.tensor.placement_types import Partial, Shard

from .compute_estimation import compute_read_write_time


def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float:
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim]
mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim]
num_hops = num_devices_on_mesh_dim - 1
# base latency + comm latency
latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us
bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s
total_time = latency + bw * 1e6 # rescale to us
# FIXME: this is a hack, we need to spend some more effort on the cost model
total_time *= 5
return total_time
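# Worked example of the formula above (illustrative numbers only, not measured):
# bytes_gb = 1.0 on a mesh dim with 4 devices, 200 GB/s bandwidth, 1 us link latency:
#   num_hops   = 4 - 1 = 3
#   latency    = 6.6 + 3 * 1 = 9.6 us
#   bw         = (1.0 * 3 / 4) / 200 = 0.00375 s  -> 3750 us
#   total_time = (9.6 + 3750) * 5 = 18798 us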


# this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py
# with iteration order introduced
def redistribute_cost(
current_spec: "dtensor_spec.DTensorSpec",
target_spec: "dtensor_spec.DTensorSpec",
order: list[int],
order: list[int] | None = None,
) -> float:
"""
This function returns the cost of redistribute from current to target DTensorSpec.
Estimate the cost of redistributing from current to target DTensorSpec.

NOTE:
1. Only consider communication cost here, since computation costs for redistribute
are quite trivial (i.e. we only need to narrow or simple division)
2. Only consider redistribute cost on same mesh, cross mesh communication cost is
not quite needed for operator strategy estimation/selection.
"""
if current_spec.mesh != target_spec.mesh:
# make infinite cost if meshes are not same
# TODO: see if we want to support this once there's cross mesh communication
return float("inf")
This is a thin wrapper around PyTorch's redistribute_cost that enables
compute cost estimation by default (for accurate sharding strategy selection).

if current_spec.is_replicated():
# short-cut:
# comm cost is 0 if current spec is already full replication
# except if output is partial, which doesn't make sense for us
if any(p.is_partial() for p in target_spec.placements):
return float("inf")
return 0.0
Args:
current_spec: The current DTensorSpec.
target_spec: The target DTensorSpec.
order: Deprecated. Previously used for custom iteration order.
PyTorch now uses _gen_transform_infos for optimal ordering.

mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh)
cost = 0.0
comm_bytes_gb = (
spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024
)
# Transformation that considered for redistribute cost:
# 1. allgather 2. alltoall
# 3. allreduce 4. reduce_scatter
curr_placements = [current_spec.placements[i] for i in order]
tgt_placements = [target_spec.placements[i] for i in order]
is_contiguous: bool = check_contiguous_sizes_strides(
current_spec.shape, current_spec.stride
Returns:
The estimated cost of redistribution in microseconds.
"""
# Use PyTorch's upstream redistribute_cost with compute costs enabled
# This accounts for reshuffle overhead on non-dim-0 shards
return _pytorch_redistribute_cost(
current_spec,
target_spec,
include_compute_cost=True,
)
for i, current, target in zip(order, curr_placements, tgt_placements):
if current == target:
continue
num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i]
if not is_contiguous:
cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
if current.is_shard() and target.is_replicate():
current = cast(Shard, current)
# allgather gives larger comm bytes
comm_bytes_gb *= num_devices_on_mesh_dim
# add up allgather comm cost
cost += allgather_cost(comm_bytes_gb, mesh_topo, i)
if current.dim != 0:
# penalize cases like S(1) -> R as there are additional compute cost
# which corresponds to reshuffling the whole output tensor
# we multiply the cost by 2 because we need to count input and output
# reads for the reshuffle
compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
cost += compute_cost
elif current.is_shard() and target.is_shard():
current = cast(Shard, current)
target = cast(Shard, target)
# should be alltoall comm, since we haven't implement it yet, add penalty
# to favor allgather instead
cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i) # us

num_copies = 0
if current.dim != 0:
num_copies += 1

if target.dim != 0:
num_copies += 1

compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
cost += num_copies * compute_cost

elif current.is_partial() and target.is_replicate():
# add up allreduce comm cost
cost += allreduce_cost(comm_bytes_gb, mesh_topo, i)
elif current.is_partial() and target.is_shard():
target = cast(Shard, target)
# add up reduce_scatter comm cost
cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i)
if target.dim != 0:
# penalize cases like P -> S(1) as there are additional compute cost
# which corresponds to reshuffling the whole input tensor
# we multiply the cost by 2 because we need to count input and output
# reads for the reshuffle
compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3)
cost += compute_cost
# after reduce_scatter the comm bytes for further collectives halved.
comm_bytes_gb /= num_devices_on_mesh_dim
elif current.is_shard() and target.is_partial():
# ban shard -> partial as it does not make sense to perform
# this redistribute
return float("inf")
elif current.is_replicate() and target.is_partial():
# ban replicate -> partial as it does not make sense to perform
# this redistribute in our case
return float("inf")

# once we redistribute across one mesh dim, assume the output
# is now contiguous. This is generally the case for most operations,
# except when we fuse nd collectives into a 1d collective.
is_contiguous = True

return cost
def estimate_strategy_comms_cost(
src_spec: "dtensor_spec.DTensorSpec",
tgt_spec: "dtensor_spec.DTensorSpec",
) -> float:
"""
Estimate communication cost for a sharding strategy transition.

Args:
src_spec: Source DTensorSpec (current sharding).
tgt_spec: Target DTensorSpec (desired sharding).

def estimate_strategy_comms_cost(src_spec, tgt_spec):
order = list(range(src_spec.mesh.ndim))
if src_spec.placements == (Partial(), Partial()) and all(
p.is_shard() for p in tgt_spec.placements
):
order = [1, 0]
comms_cost = redistribute_cost(src_spec, tgt_spec, order)
return comms_cost
Returns:
Estimated communication cost in microseconds.
"""
return redistribute_cost(src_spec, tgt_spec)


# Re-export for convenience
__all__ = [
"redistribute_cost",
"estimate_strategy_comms_cost",
"all_to_all_cost",
"allgather_cost",
"allreduce_cost",
"reduce_scatter_cost",
"MeshTopoInfo",
"spec_to_bytes",
]
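
# Illustrative usage sketch (not part of this diff; `_example_usage` is a
# hypothetical helper). It mirrors the fake 4-rank setup used in
# tests/test_cost_models.py to estimate a Shard(1) -> Replicate transition,
# which includes the reshuffle compute overhead enabled by this module.
def _example_usage() -> None:
    import torch
    from torch.distributed.tensor import DeviceMesh, Replicate, Shard
    from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
    from torch.testing._internal.distributed.fake_pg import FakeStore

    torch.distributed.init_process_group(
        backend="fake", rank=0, world_size=4, store=FakeStore()
    )
    mesh = DeviceMesh("cuda", list(range(4)))

    t = torch.randn(1024, 1024)
    meta = TensorMeta(t.shape, t.stride(), t.dtype)
    src = DTensorSpec(mesh, (Shard(1),), meta)  # sharded on a non-0 dim
    tgt = DTensorSpec(mesh, (Replicate(),), meta)

    print("estimated comms cost (us):", estimate_strategy_comms_cost(src, tgt))

    torch.distributed.destroy_process_group()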
151 changes: 151 additions & 0 deletions tests/test_cost_models.py
@@ -0,0 +1,151 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for communication cost estimation module."""

import unittest

import torch
from torch.distributed.tensor import DeviceMesh, Replicate, Shard
from torch.distributed.tensor._collective_utils import MeshTopoInfo
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel.cost_models.collective_runtime_estimation import (
all_to_all_cost,
estimate_strategy_comms_cost,
redistribute_cost,
)


def extract_tensor_meta(t: torch.Tensor) -> TensorMeta:
return TensorMeta(t.shape, t.stride(), t.dtype)


class TestCollectiveRuntimeEstimation(unittest.TestCase):
"""Test communication cost estimation functions."""

@classmethod
def setUpClass(cls):
"""Set up fake distributed environment."""
cls.world_size = 4
store = FakeStore()
torch.distributed.init_process_group(
backend="fake",
rank=0,
world_size=cls.world_size,
store=store,
)
cls.mesh = DeviceMesh("cuda", list(range(cls.world_size)))

@classmethod
def tearDownClass(cls):
"""Tear down distributed environment."""
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

def test_redistribute_cost_basic(self):
"""Test basic redistribute cost estimation."""
shard_placement = (Shard(0),)
replica_placement = (Replicate(),)

tensor = torch.randn(8, 8)
tensor_meta = extract_tensor_meta(tensor)

shard_spec = DTensorSpec(self.mesh, shard_placement, tensor_meta)
replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta)

# Same spec should have zero cost
cost_same = redistribute_cost(shard_spec, shard_spec)
self.assertEqual(cost_same, 0)

# Shard -> Replicate should have positive cost
cost_allgather = redistribute_cost(shard_spec, replica_spec)
self.assertGreater(cost_allgather, 0)

def test_redistribute_cost_with_compute_overhead(self):
"""Test that non-dim-0 shards include compute overhead."""
shard0_placement = (Shard(0),)
shard1_placement = (Shard(1),)
replica_placement = (Replicate(),)

tensor = torch.randn(8, 8)
tensor_meta = extract_tensor_meta(tensor)

shard0_spec = DTensorSpec(self.mesh, shard0_placement, tensor_meta)
shard1_spec = DTensorSpec(self.mesh, shard1_placement, tensor_meta)
replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta)

# Shard(0) -> Replicate (no reshuffle needed)
cost_dim0 = redistribute_cost(shard0_spec, replica_spec)
# Shard(1) -> Replicate (reshuffle needed)
cost_dim1 = redistribute_cost(shard1_spec, replica_spec)

# Shard(1) -> Replicate should be more expensive due to reshuffle
self.assertGreater(cost_dim1, cost_dim0)

def test_all_to_all_cost(self):
"""Test all_to_all_cost function."""
mesh_topo = MeshTopoInfo.build_from_mesh(self.mesh)

# Test with 1MB
cost = all_to_all_cost(0.001, mesh_topo, 0)
self.assertGreater(cost, 0)

# Larger tensor should have higher cost
cost_larger = all_to_all_cost(0.01, mesh_topo, 0)
self.assertGreater(cost_larger, cost)

def test_shard_to_shard_uses_all_to_all(self):
"""Test that shard->shard transitions have reasonable cost."""
shard0_placement = (Shard(0),)
shard1_placement = (Shard(1),)

tensor = torch.randn(8, 8)
tensor_meta = extract_tensor_meta(tensor)

shard0_spec = DTensorSpec(self.mesh, shard0_placement, tensor_meta)
shard1_spec = DTensorSpec(self.mesh, shard1_placement, tensor_meta)

# Shard(0) -> Shard(1) should use all_to_all
cost = redistribute_cost(shard0_spec, shard1_spec)
self.assertGreater(cost, 0)
self.assertNotEqual(cost, float("inf"))

def test_estimate_strategy_comms_cost(self):
"""Test estimate_strategy_comms_cost wrapper."""
shard_placement = (Shard(0),)
replica_placement = (Replicate(),)

tensor = torch.randn(8, 8)
tensor_meta = extract_tensor_meta(tensor)

shard_spec = DTensorSpec(self.mesh, shard_placement, tensor_meta)
replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta)

cost = estimate_strategy_comms_cost(shard_spec, replica_spec)
expected_cost = redistribute_cost(shard_spec, replica_spec)
self.assertEqual(cost, expected_cost)

def test_order_parameter_deprecated(self):
"""Test that order parameter is accepted but ignored."""
shard_placement = (Shard(0),)
replica_placement = (Replicate(),)

tensor = torch.randn(8, 8)
tensor_meta = extract_tensor_meta(tensor)

shard_spec = DTensorSpec(self.mesh, shard_placement, tensor_meta)
replica_spec = DTensorSpec(self.mesh, replica_placement, tensor_meta)

# Should accept order parameter without error
cost_with_order = redistribute_cost(shard_spec, replica_spec, order=[0])
cost_without_order = redistribute_cost(shard_spec, replica_spec)
# Results should be the same (order is ignored)
self.assertEqual(cost_with_order, cost_without_order)
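
# Hypothetical follow-up check, sketched here but not added by this PR: it assumes
# the partial -> shard path is costed via reduce_scatter and therefore stays finite
# and positive. Like the methods above, it relies on this class's fake process group.
def test_partial_to_shard_cost_is_finite(self):
    """Sketch: partial -> shard redistribution should have a finite, positive cost."""
    from torch.distributed.tensor.placement_types import Partial

    tensor_meta = extract_tensor_meta(torch.randn(8, 8))
    partial_spec = DTensorSpec(self.mesh, (Partial(),), tensor_meta)
    shard_spec = DTensorSpec(self.mesh, (Shard(0),), tensor_meta)

    cost = redistribute_cost(partial_spec, shard_spec)
    self.assertGreater(cost, 0)
    self.assertNotEqual(cost, float("inf"))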


if __name__ == "__main__":
unittest.main()