Adding axonn_strategy.py
Anish Bhupalam committed May 29, 2024
1 parent 8f2c98c commit ca8deb2
Showing 1 changed file with 171 additions and 0 deletions: axonn/axonn_strategy.py
@@ -0,0 +1,171 @@
from contextlib import nullcontext
from datetime import timedelta
from typing import Any, ContextManager, Dict, List, Literal, Optional, Union

import torch
import torch.distributed
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
from torch import Tensor
from torch.nn import Module
from typing_extensions import override

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from lightning.fabric.strategies.parallel import ParallelStrategy
from lightning.fabric.strategies.registry import _StrategyRegistry
from lightning.fabric.strategies.strategy import TBroadcast
from lightning.fabric.utilities.distributed import (
ReduceOp,
_distributed_is_initialized,
_get_default_process_group_backend_for_device,
_init_dist_connection,
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.rank_zero import rank_zero_only
from axonn import axonn as ax


class AxonnStrategy(ParallelStrategy):
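    """Lightning Fabric strategy that sets up AxoNN's parallel process groups via ``ax.init``.

    The ``G_*`` arguments follow AxoNN's naming: ``G_data`` (data parallelism),
    ``G_inter`` (inter-layer parallelism), and ``G_intra_r`` / ``G_intra_c`` /
    ``G_intra_d`` (intra-layer row / column / depth parallelism).
    """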

def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
G_data: int = 1,
G_inter: int = 1,
G_intra_r: int = 1,
G_intra_c: int = 1,
G_intra_d: int = 1,
**kwargs: Any,
) -> None:
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision=precision,
)

self._num_nodes = 1

self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self.G_data = G_data
self.G_inter = G_inter
self.G_intra_r = G_intra_r
self.G_intra_c = G_intra_c
self.G_intra_d = G_intra_d
self._kwargs = kwargs

@property
@override
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
return self.parallel_devices[self.local_rank]

@property
def num_nodes(self) -> int:
return self._num_nodes

@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
self._num_nodes = num_nodes

@property
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

@property
@override
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        # Both the data-parallel and depth-parallel groups shard the data, so
        # both determine the sampler's replica count and rank.
        num_replicas = ax.config.G_intra_d * ax.config.G_data
        rank = ax.config.G_intra_d * ax.config.data_parallel_rank + ax.config.intra_layer_depth_parallel_rank
        return {"num_replicas": num_replicas, "rank": rank}

@property
def process_group_backend(self) -> Optional[str]:
return self._process_group_backend

@override
def _configure_launcher(self) -> None:
assert self.cluster_environment is not None
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)

@override
def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()

@override
    def setup_module(self, module: Module) -> Module:
        return module  # use autoparallelize later

@override
def module_to_device(self, module: Module) -> None:
module.to(self.root_device)

@override
    def all_reduce(
        self,
        tensor: Tensor,
        group: Optional[Any] = None,
        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
    ) -> Tensor:
        # Reduce the tensor across all ranks (averaged by default); non-tensor
        # inputs are returned unchanged.
        if isinstance(tensor, Tensor):
            return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

@override
def barrier(self, *args: Any, **kwargs: Any) -> None:
if not _distributed_is_initialized():
return
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_device_ids())
else:
torch.distributed.barrier()

@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
if not _distributed_is_initialized():
return obj

obj = [obj]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]

@classmethod
@override
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
pass

def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

        # Create AxoNN's parallel process groups on top of the default
        # torch.distributed process group initialized above.
        ax.init(
G_data=self.G_data,
G_inter=self.G_inter,
G_intra_r=self.G_intra_r,
G_intra_c=self.G_intra_c,
G_intra_d=self.G_intra_d,
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

def _set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank

def _determine_device_ids(self) -> Optional[List[int]]:
return None if self.root_device.type == "cpu" else [self.root_device.index]
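
For context, here is a minimal usage sketch (not part of this commit) showing how the strategy could be handed to Lightning Fabric. The group sizes, model, and training step are hypothetical placeholders; only the `AxonnStrategy` constructor arguments come from the file above.

# Usage sketch (assumed, not part of this commit): the strategy is passed to
# Fabric like any other ParallelStrategy. Group sizes and the model are placeholders.
import torch
from torch import nn
from lightning.fabric import Fabric
from axonn.axonn_strategy import AxonnStrategy

# 4 GPUs split into a 2-way intra-layer row group and a 2-way depth group (assumed).
strategy = AxonnStrategy(G_intra_r=2, G_intra_d=2)
fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()

model = nn.Linear(1024, 1024)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

batch = torch.randn(8, 1024, device=fabric.device)
loss = model(batch).sum()
fabric.backward(loss)
optimizer.step()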
