Merge branch 'develop' into chore/bfloat16
siddharth9820 authored Jan 1, 2024
2 parents c0aa674 + 3ebc34c commit ceb19d2
Showing 6 changed files with 496 additions and 110 deletions.
6 changes: 4 additions & 2 deletions axonn/axonn.py
@@ -97,6 +97,7 @@ def init(
G_data: int,
G_intra_r: int = 1,
G_intra_c: int = 1,
G_intra_d: int = 1,
gpus_per_node: Optional[int] = None,
mixed_precision=False,
float16_allreduce=True,
@@ -128,13 +129,14 @@ def init(
global comm_handle, is_initialized, computation_dtype, _float16_all_reduce
global _cpu_offload, _use_bf16, _mixed_precision, loss_scale
comm_handle = communication_handle(
G_inter, G_data, G_intra_r, G_intra_c, gpus_per_node
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
)
config.G_inter = G_inter
config.G_data = G_data
config.G_intra = G_intra_r * G_intra_c
config.G_intra = G_intra_r * G_intra_c * G_intra_d
config.G_intra_r = G_intra_r
config.G_intra_c = G_intra_c
config.G_intra_d = G_intra_d
config.inter_layer_parallel_rank = comm_handle.inter_layer_parallel_rank
config.data_parallel_rank = comm_handle.data_parallel_rank
config.intra_layer_parallel_rank = comm_handle.intra_layer_parallel_rank
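
For context, a minimal sketch of how the new G_intra_d argument would be passed to axonn.init. The parallelism degrees below are hypothetical (they assume a 16-GPU job); the only hard requirement, per the assert in communication.py below, is that their product equals the world size.

from axonn import axonn as ax

# hypothetical 16-GPU configuration: 1 (pipeline) x 2 (data) x 2 x 2 x 2 (intra-layer)
ax.init(
    G_inter=1,    # pipeline (inter-layer) parallel degree
    G_data=2,     # data parallel degree
    G_intra_r=2,  # row dimension of intra-layer (tensor) parallelism
    G_intra_c=2,  # column dimension of intra-layer (tensor) parallelism
    G_intra_d=2,  # new depth dimension of intra-layer parallelism
)
# G_inter * G_data * G_intra_r * G_intra_c * G_intra_d must equal the GPU count (16 here)

With these values, config.G_intra becomes 2 * 2 * 2 = 8, matching the updated product in axonn.py above.
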
138 changes: 102 additions & 36 deletions axonn/communication.py
@@ -4,8 +4,15 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
from mpi4py import MPI

try:
from mpi4py import MPI

MPI4PY = True
except ImportError:
MPI4PY = False
import torch
import numpy as np


class communication_handle:
@@ -15,7 +22,13 @@ class communication_handle:
"""

def __init__(
self, G_inter: int, G_data: int, G_intra_r=1, G_intra_c=1, gpus_per_node=None
self,
G_inter: int,
G_data: int,
G_intra_r=1,
G_intra_c=1,
G_intra_d=1,
gpus_per_node=None,
):
"""Constructor for the communication handle
@@ -24,22 +37,30 @@ def __init__(
G_data (int): number of GPUs used for data parallelism
gpus_per_node (int, optional): number of GPUs per node, if not
provided this is inferred using pytorch
G_intra (int): degree of intra-layer parallelism. Note that
the user is supposed to implement their intra-layer parallel
kernels. AxoNN will just create communication groups for
intra-layer parallelism
G_intra_r (int): number of GPUs in the row intra-layer parallel dimension
G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
"""
self.world_rank = MPI.COMM_WORLD.Get_rank()
self.world_size = MPI.COMM_WORLD.Get_size()
G_intra = G_intra_r * G_intra_c
if not torch.distributed.is_initialized():
assert MPI4PY, (
"either install mpi4py and launch via mpirun/srun "
"or initialize torch.distributed outside axonn"
)
self.world_rank = MPI.COMM_WORLD.Get_rank()
self.world_size = MPI.COMM_WORLD.Get_size()
else:
self.world_rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()

G_intra = G_intra_r * G_intra_c * G_intra_d
assert (
G_inter * G_data * G_intra == self.world_size
), "The product of G_inter and G_data should be equal to the number of GPUs"
), "The product of G_inter, G_intra_r, G_intra_c, G_intra_d,"
f"G_data should be equal to the number of GPUs - {self.world_size}"
self.G_intra = G_intra
self.G_inter = G_inter
self.G_data = G_data
self.G_intra_r = G_intra_r
self.G_intra_c = G_intra_c
self.G_intra_d = G_intra_d

# infer gpus per node if not provided
self.gpus_per_node = (
@@ -51,15 +72,34 @@ def __init__(
self.intra_layer_column_parallel_rank = (
self.intra_layer_parallel_rank % G_intra_c
)
self.intra_layer_row_parallel_rank = self.intra_layer_parallel_rank // G_intra_c
self.intra_layer_row_parallel_rank = (
self.intra_layer_parallel_rank // G_intra_c
) % G_intra_r
self.intra_layer_depth_parallel_rank = self.intra_layer_parallel_rank // (
G_intra_c * G_intra_r
)

self.inter_layer_parallel_rank = (self.world_rank // G_intra) % G_inter
self.data_parallel_rank = self.world_rank // (G_inter * G_intra)

# create communicator for point-to-point (MPI) communication
colour = self.intra_layer_parallel_rank + G_intra * self.data_parallel_rank
# this needs to be checked
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter

if G_inter > 1:
# this needs to be checked
if MPI4PY:
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter
else:
self.p2p_mpi_comm = None
print(
"Warning: AxoNN's implementation of inter-layer"
"parallelism (pipelining) requires mpi4py, which wasn't found."
"You will have to use an external implementation"
"of pipeline parallelism."
)
else:
self.p2p_mpi_comm = None

# create communicator for collective (NCCL) communication
if not torch.distributed.is_initialized():
@@ -89,37 +129,63 @@ def __init__(
self.coll_nccl_comm = ith_jth_data_parallel_group

# create communicators for intra-layer parallelism
for i in range(G_data):
for j in range(G_inter):
for i_ in range(G_data):
for j_ in range(G_inter):
ranks_in_ith_jth_intra_layer_group = [
i * G_inter * G_intra + j * G_intra + k for k in range(G_intra)
i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
]

ith_jth_intra_layer_group = torch.distributed.new_group(
ranks=ranks_in_ith_jth_intra_layer_group, backend="nccl"
)
if self.world_rank in ranks_in_ith_jth_intra_layer_group:
self.intra_layer_group = ith_jth_intra_layer_group

assert (
len(ranks_in_ith_jth_intra_layer_group)
== G_intra_r * G_intra_c * G_intra_d
)

ranks_in_ith_jth_intra_layer_group = np.array(
ranks_in_ith_jth_intra_layer_group
).reshape(G_intra_d, G_intra_r, G_intra_c)
# form row and column tensor parallel groups
# G_intra_r x G_intra_c
assert len(ranks_in_ith_jth_intra_layer_group) == G_intra_r * G_intra_c
intra_layer_ranks = ranks_in_ith_jth_intra_layer_group
# G_intra_d x G_intra_r x G_intra_c

# inner
for i in range(G_intra_d):
for j in range(G_intra_r):
group_members = list(
ranks_in_ith_jth_intra_layer_group[i, j, :]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.inner_intra_layer_parallel_group = group

# outer
for i in range(G_intra_d):
for j in range(G_intra_c):
group_members = list(
ranks_in_ith_jth_intra_layer_group[i, :, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group

# depth
for i in range(G_intra_r):
offset = i * G_intra_c
group_members = intra_layer_ranks[offset : offset + G_intra_c]
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.inner_intra_layer_parallel_group = group

for i in range(G_intra_c):
group_members = intra_layer_ranks[i::G_intra_c]
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group
for j in range(G_intra_c):
group_members = list(
ranks_in_ith_jth_intra_layer_group[:, i, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
)
if self.world_rank in group_members:
self.depth_intra_layer_parallel_group = group

def _torch_to_mpi(self, tensor: torch.Tensor):
"""Converts a PyTorch tensor into an mpi4py compatible array using its
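
As a sanity check on the group construction above, here is a self-contained sketch (NumPy only, no torch.distributed) of how one 8-rank intra-layer block is reshaped into G_intra_d x G_intra_r x G_intra_c = 2 x 2 x 2 and sliced into inner, outer, and depth groups; the rank numbers and configuration are illustrative, not taken from the commit.

import numpy as np

G_intra_d, G_intra_r, G_intra_c = 2, 2, 2
ranks = np.arange(G_intra_d * G_intra_r * G_intra_c).reshape(
    G_intra_d, G_intra_r, G_intra_c
)

# inner groups: fix depth index i and row index j, vary the column index (size G_intra_c)
inner = [ranks[i, j, :].tolist() for i in range(G_intra_d) for j in range(G_intra_r)]
# outer groups: fix depth index i and column index j, vary the row index (size G_intra_r)
outer = [ranks[i, :, j].tolist() for i in range(G_intra_d) for j in range(G_intra_c)]
# depth groups: fix row index i and column index j, vary the depth index (size G_intra_d)
depth = [ranks[:, i, j].tolist() for i in range(G_intra_r) for j in range(G_intra_c)]

print(inner)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(outer)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(depth)  # [[0, 4], [1, 5], [2, 6], [3, 7]]

Each rank lands in exactly one group per dimension, which is what the constructor relies on when it assigns self.inner_intra_layer_parallel_group, self.outer_intra_layer_parallel_group, and self.depth_intra_layer_parallel_group.
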
142 changes: 138 additions & 4 deletions axonn/intra_layer/__init__.py
@@ -1,25 +1,159 @@
from contextlib import contextmanager
from .fully_connected import Linear # noqa: F401
from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401

from .communication import Drop, Gather
from .gradient_normalization import clip_grad_norm_ # noqa: F401

from axonn import axonn as ax
import torch
import torch.distributed as dist


def drop(x, transpose=False, dim=-1):
def drop(x, transpose=False, dim=-1, batch_dim=0):
if not transpose:
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

return Drop.apply(x, group, dim)
x = Drop.apply(x, group, dim)
x = Drop.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
return x


def gather(x, transpose=False, dim=-1):
def gather(x, transpose=False, dim=-1, batch_dim=0):
if not transpose:
group = ax.comm_handle.inner_intra_layer_parallel_group
else:
group = ax.comm_handle.outer_intra_layer_parallel_group

return Gather.apply(x, group, dim)
x = Gather.apply(x, group, dim)
x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
return x


OVERLAP_REDUCE_SCATTER = False
OVERLAP_ALL_REDUCE = False
CACHE_WEIGHTS = False
ALL_GATHER_ITERATOR = None
handles = []
pending_grad_accumulations = []
weights_cache = {}


def register_handle(handle):
# ToDo: This might be unnecessary since
# we are calling synchronize in clear_handles
global handles
handles.append(handle)


def clear_handles():
global handles
torch.cuda.synchronize()
handles = []


def accumulate_later(param, grad):
global pending_grad_accumulations
pending_grad_accumulations.append([param, grad])


@torch.no_grad()
def accumulate():
global pending_grad_accumulations
for param, grad in pending_grad_accumulations:
if param.grad is None:
param.grad = grad
else:
param.grad.add_(grad)

pending_grad_accumulations = []


def clear_weights_cache():
global weights_cache
weights_cache = {}


def trigger_async_all_gathers(model):
global weights_cache
for module in model.modules():
if isinstance(module, Linear):
weight = module.weight
if weight not in weights_cache:
# only trigger all gathers if not in cache
process_group = module.depth_group
world_size = dist.get_world_size(process_group)
if world_size == 1:
all_gathered_weight = weight
handle = None
else:
assert weight.ndim == 1
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight, weight, group=process_group, async_op=True
)
weights_cache[weight] = [all_gathered_weight, handle]
yield


def enqueue_next_all_gather():
global ALL_GATHER_ITERATOR
assert ALL_GATHER_ITERATOR is not None
try:
next(ALL_GATHER_ITERATOR)
except StopIteration:
pass


def retrieve_all_gathered_weight(weight):
global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
assert weight in weights_cache
all_gathered_weight, handle = weights_cache[weight]
if ALL_GATHER_ITERATOR is not None:
enqueue_next_all_gather()
return all_gathered_weight, handle


@contextmanager
def optimize_communication(
overlap_all_reduce=True,
overlap_reduce_scatter=False,
cache_weights=False,
overlap_all_gather=False,
model=None,
*args,
**kwargs
):
global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS
global ALL_GATHER_ITERATOR
OVERLAP_ALL_REDUCE = overlap_all_reduce
OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter

CACHE_WEIGHTS = cache_weights

if overlap_all_gather:
if model is None:
raise ValueError(
"You need to pass your model as an argument - "
"optimize_communication(...,model=model, ...)"
"if overlap_all_gather is True"
)
assert (
cache_weights
), "all gathers can only be overlapped if cache_weights is True"
ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
enqueue_next_all_gather()

try:
yield None
finally:
clear_handles()
accumulate()
OVERLAP_ALL_REDUCE = False
OVERLAP_REDUCE_SCATTER = False
ALL_GATHER_ITERATOR = None
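
A rough usage sketch of the optimize_communication context manager defined above. The layer sizes, batch size, and parallel configuration are hypothetical, and the tensor-parallel Linear is assumed to take in_features/out_features like torch.nn.Linear; treat this as an illustration of the intended call pattern (running backward inside the context so that clear_handles and accumulate run on exit), not as the library's documented API.

import torch
from axonn import axonn as ax
from axonn.intra_layer import Linear, optimize_communication, clear_weights_cache

ax.init(G_inter=1, G_data=1, G_intra_r=2, G_intra_c=2, G_intra_d=2)  # assumes 8 GPUs

model = Linear(in_features=1024, out_features=1024).cuda()  # tensor-parallel layer
optimizer = torch.optim.Adam(model.parameters())

x = torch.randn(16, 1024, device="cuda")  # hypothetical local batch
with optimize_communication(
    overlap_all_reduce=True,
    overlap_reduce_scatter=True,
    cache_weights=True,
    overlap_all_gather=True,
    model=model,  # required when overlap_all_gather=True
):
    loss = model(x).sum()
    loss.backward()  # handles are cleared and pending grads accumulated on exit

optimizer.step()
clear_weights_cache()  # drop cached all-gathered weights once the step is done
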