8 changes: 7 additions & 1 deletion autoparallel/__init__.py
@@ -4,5 +4,11 @@
# LICENSE file in the root directory of this source tree.

from autoparallel.api import AutoParallel, AutoParallelPP, auto_parallel
from autoparallel.collectives import with_sharding_constraint

__all__ = ["auto_parallel", "AutoParallel", "AutoParallelPP"]
__all__ = [
    "auto_parallel",
    "AutoParallel",
    "AutoParallelPP",
    "with_sharding_constraint",
]
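
With the new package-level export, the helper can be imported straight from autoparallel. A minimal usage sketch; the forward_block wrapper and the assumption that an enclosing local_map region supplies the device mesh are illustrative, not part of this change:

# Hypothetical usage of the re-exported helper.
from autoparallel import with_sharding_constraint
from torch.distributed.tensor.placement_types import Replicate, Shard

def forward_block(x):
    # Pin an intermediate to be sharded on dim 0 along the first mesh
    # dimension and replicated along the second; the mesh is assumed to
    # come from the enclosing local_map region.
    return with_sharding_constraint(x, (Shard(0), Replicate()))
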
47 changes: 46 additions & 1 deletion autoparallel/collectives.py
@@ -3,18 +3,63 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional
from typing import Any, Optional, Tuple

import torch
import torch.distributed.distributed_c10d as c10d
from torch.distributed._tensor.experimental import local_map as _local_map
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Placement

# Alias c10d.GroupName for type annotations
GroupName = c10d.GroupName

_local_map_device_mesh = None


def with_sharding_constraint(
    x: torch.Tensor,
    shardings: Tuple[Placement, ...],
    device_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    """Constrain the sharding of an intermediate tensor.

    Similar to JAX's with_sharding_constraint, this constrains the sharding
    of a tensor to a specific placement. This is useful for controlling
    intermediate tensor shardings within a computation.

    Args:
        x: The tensor to constrain.
        shardings: Tuple of placements specifying how the tensor should be
            sharded across each mesh dimension.
        device_mesh: The device mesh to use. If None, uses the mesh from
            the enclosing local_map region.

    Returns:
        The tensor with the specified sharding constraint applied.

    Example:
        >>> from torch.distributed.tensor.placement_types import Shard, Replicate
        >>> # Inside a local_map region or with explicit mesh:
        >>> x = with_sharding_constraint(x, (Shard(0), Replicate()))
    """
    if device_mesh is None:
        device_mesh = get_mesh_from_global()

    @_local_map(
        out_placements=(shardings,),
        in_placements=(shardings,),
        redistribute_inputs=True,
        device_mesh=device_mesh,
    )
    def identity(t):
        # clone() is required because local_map HOP doesn't support
        # input-to-output aliasing during dynamo tracing
        return t.clone()

    return identity(x)


def local_map(*args, **kwargs):
# TODO: ideally after we get out of the local map region we should
# just reset the global device mesh to None. For now we just keep it
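
For completeness, a sketch of the explicit-mesh path described in the docstring. The 2x2 mesh layout, dim names, and tensor shape are illustrative assumptions; init_device_mesh is the stock torch.distributed.device_mesh constructor and requires the default process group to already be initialized on every rank:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel import with_sharding_constraint

# Illustrative 2D mesh over 4 ranks: 2-way "dp" x 2-way "tp".
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

x = torch.randn(8, 16, device="cuda")

# Shard the intermediate on dim 0 across "dp" and replicate across "tp".
# Passing device_mesh explicitly bypasses the enclosing-region lookup.
y = with_sharding_constraint(x, (Shard(0), Replicate()), device_mesh=mesh)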