Skip to content

Commit

Permalink
A context manager to optimize communication (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddharth9820 authored Nov 29, 2023
1 parent d1144ee commit 3ebc34c
Show file tree
Hide file tree
Showing 5 changed files with 329 additions and 57 deletions.
40 changes: 32 additions & 8 deletions axonn/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
from mpi4py import MPI

try:
from mpi4py import MPI

MPI4PY = True
except ImportError:
MPI4PY = False
import torch
import numpy as np

Expand Down Expand Up @@ -35,8 +41,15 @@ def __init__(
G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
"""
self.world_rank = MPI.COMM_WORLD.Get_rank()
self.world_size = MPI.COMM_WORLD.Get_size()
if not torch.distributed.is_initialized():
assert MPI4PY, "either install mpi4py and launch via mpirun/srun"
"or initialize torch.distributed outside axonn"
self.world_rank = MPI.COMM_WORLD.Get_rank()
self.world_size = MPI.COMM_WORLD.Get_size()
else:
self.world_rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()

G_intra = G_intra_r * G_intra_c * G_intra_d
assert (
G_inter * G_data * G_intra == self.world_size
Expand Down Expand Up @@ -71,9 +84,22 @@ def __init__(

# create communicator for point-to-point(MPI) communication
colour = self.intra_layer_parallel_rank + G_intra * self.data_parallel_rank
# this needs to be checked
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter

if G_inter > 1:
# this needs to be checked
if MPI4PY:
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter
else:
self.p2p_mpi_comm = None
print(
"Warning: AxoNN's implementation of inter-layer"
"parallelism (pipelining) requires mpi4py, which wasn't found."
"You will have to use an external implementation"
"of pipeline parallelism."
)
else:
self.p2p_mpi_comm = None

# create communicator for collective (NCCL) communication
if not torch.distributed.is_initialized():
Expand Down Expand Up @@ -103,13 +129,11 @@ def __init__(
self.coll_nccl_comm = ith_jth_data_parallel_group

# create communicators for intra-layer parallelism
print(G_data, G_inter, G_intra)
for i_ in range(G_data):
for j_ in range(G_inter):
ranks_in_ith_jth_intra_layer_group = [
i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
]
print(ranks_in_ith_jth_intra_layer_group)
ith_jth_intra_layer_group = torch.distributed.new_group(
ranks=ranks_in_ith_jth_intra_layer_group, backend="nccl"
)
Expand Down
130 changes: 130 additions & 0 deletions axonn/intra_layer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from contextlib import contextmanager
from .fully_connected import Linear # noqa: F401
from .conv import Conv2d as Tensor_Parallel_Conv2d # noqa: F401

from .communication import Drop, Gather
from .gradient_normalization import clip_grad_norm_ # noqa: F401

from axonn import axonn as ax
import torch
import torch.distributed as dist


def drop(x, transpose=False, dim=-1, batch_dim=0):
Expand All @@ -27,3 +30,130 @@ def gather(x, transpose=False, dim=-1, batch_dim=0):
x = Gather.apply(x, group, dim)
x = Gather.apply(x, ax.comm_handle.depth_intra_layer_parallel_group, batch_dim)
return x


OVERLAP_REDUCE_SCATTER = False
OVERLAP_ALL_REDUCE = False
CACHE_WEIGHTS = False
ALL_GATHER_ITERATOR = None
handles = []
pending_grad_accumulations = []
weights_cache = {}


def register_handle(handle):
# ToDo: This might be unnecesary since
# we are calling synchronize in clear_handles
global handles
handles.append(handle)


def clear_handles():
global handles
torch.cuda.synchronize()
handles = []


def accumulate_later(param, grad):
global pending_grad_accumulations
pending_grad_accumulations.append([param, grad])


@torch.no_grad()
def accumulate():
global pending_grad_accumulations
for param, grad in pending_grad_accumulations:
if param.grad is None:
param.grad = grad
else:
param.grad.add_(grad)

pending_grad_accumulations = []


def clear_weights_cache():
global weights_cache
weights_cache = {}


def trigger_async_all_gathers(model):
global weights_cache
for module in model.modules():
if isinstance(module, Linear):
weight = module.weight
if weight not in weights_cache:
# only trigger all gathers if not in cache
process_group = module.depth_group
world_size = dist.get_world_size(process_group)
if world_size == 1:
all_gathered_weight = weight
handle = None
else:
assert weight.ndim == 1
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight, weight, group=process_group, async_op=True
)
weights_cache[weight] = [all_gathered_weight, handle]
yield


def enqueue_next_all_gather():
global ALL_GATHER_ITERATOR
assert ALL_GATHER_ITERATOR is not None
try:
next(ALL_GATHER_ITERATOR)
except StopIteration:
pass


def retrieve_all_gathered_weight(weight):
global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
assert weight in weights_cache
all_gathered_weight, handle = weights_cache[weight]
if ALL_GATHER_ITERATOR is not None:
enqueue_next_all_gather()
return all_gathered_weight, handle


@contextmanager
def optimize_communication(
overlap_all_reduce=True,
overlap_reduce_scatter=False,
cache_weights=False,
overlap_all_gather=False,
model=None,
*args,
**kwargs
):
global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS
global ALL_GATHER_ITERATOR
OVERLAP_ALL_REDUCE = overlap_all_reduce
OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter

CACHE_WEIGHTS = cache_weights

if overlap_all_gather:
if model is None:
raise ValueError(
"You need to pass your model as an argument - "
"optimize_communication(...,model=model, ...)"
"if overlap_all_gather is True"
)
assert (
cache_weights
), "all gathers can only be overlapped if cache_weights is True"
ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
enqueue_next_all_gather()

try:
yield None
finally:
clear_handles()
accumulate()
OVERLAP_ALL_REDUCE = False
OVERLAP_REDUCE_SCATTER = False
ALL_GATHER_ITERATOR = None
107 changes: 82 additions & 25 deletions axonn/intra_layer/communication.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import torch.distributed as dist
import torch
import axonn


def _all_reduce(input_, process_group=None):
def _all_reduce(input_, process_group=None, overlap_comm=False):
input_ = input_.contiguous()
if dist.get_world_size(process_group) > 1:
dist.all_reduce(input_.contiguous(), group=process_group)
handle = dist.all_reduce(
input_.contiguous(), group=process_group, async_op=overlap_comm
)
if overlap_comm:
axonn.intra_layer.register_handle(handle)
return input_


Expand All @@ -21,28 +27,37 @@ def _drop(input_, dim, process_group=None):
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)


def _gather(input_, dim, process_group=None):
def _gather(input_, dim, process_group=None, cache=False):
"""Gather tensors and concatenate them along a dimension"""
if dist.get_world_size(process_group) == 1:
return input_

input_ = input_.contiguous()
# Size and dimension.
rank = dist.get_rank(process_group)
if input_ in axonn.intra_layer.weights_cache:
output, handle = axonn.intra_layer.retrieve_all_gathered_weight(input_)
if handle is not None:
handle.wait()
axonn.intra_layer.weights_cache[input_][1] = None
else:
input_ = input_.contiguous()
# Size and dimension.
rank = dist.get_rank(process_group)

tensor_list = [
torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
]
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=process_group)

tensor_list = [
torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
]
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=process_group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
if cache:
axonn.intra_layer.weights_cache[input_] = output, None

return output


def _reduce_scatter(input_, dim, process_group=None):
def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
assert dim == 0, "reduce scatter only implemented for dim=0"

if dist.get_world_size(process_group) == 1:
Expand All @@ -55,7 +70,18 @@ def _reduce_scatter(input_, dim, process_group=None):
output = torch.empty(
tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
)
torch.distributed.reduce_scatter_tensor(output, input_, group=process_group)

if hasattr(torch.distributed, "reduce_scatter_tensor"):
handle = torch.distributed.reduce_scatter_tensor(
output, input_, group=process_group, async_op=overlap_comm
)
else:
handle = torch.distributed._reduce_scatter_base(
output, input_, group=process_group, async_op=overlap_comm
)

if overlap_comm:
axonn.intra_layer.register_handle(handle)
return output


Expand All @@ -75,17 +101,24 @@ def backward(ctx, grad_output):

class BackwardAllReduce(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, process_group=None):
def symbolic(graph, input_, process_group=None, overlap_comm=False):
return input_

@staticmethod
def forward(ctx, input_, process_group=None):
def forward(ctx, input_, process_group=None, overlap_comm=False):
ctx.process_group = process_group
ctx.overlap_comm = overlap_comm
ctx.input = input_
return input_

@staticmethod
def backward(ctx, grad_output):
return _all_reduce(grad_output, ctx.process_group), None
grad_input = _all_reduce(grad_output, ctx.process_group, ctx.overlap_comm)
if not ctx.overlap_comm:
return grad_input, None, None
else:
axonn.intra_layer.accumulate_later(ctx.input, grad_input)
return None, None, None


class Drop(torch.autograd.Function):
Expand Down Expand Up @@ -130,21 +163,45 @@ def backward(ctx, grad_output):

class ForwardGather_BackwardReduceScatter(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, process_group=None, dim=0):
def symbolic(
graph,
input_,
process_group=None,
dim=0,
overlap_comm=False,
cache_all_gather=False,
):
return _gather(input_, dim=dim, process_group=process_group)

@staticmethod
def forward(ctx, input_, process_group=None, dim=0):
def forward(
ctx,
input_,
process_group=None,
dim=0,
overlap_comm=False,
cache_all_gather=False,
):
assert dim == 0
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim=dim, process_group=process_group)
ctx.overlap_comm = overlap_comm
ctx.input = input_
return _gather(
input_, dim=dim, process_group=process_group, cache=cache_all_gather
)

@staticmethod
def backward(ctx, grad_output):
assert ctx.dim == 0
return (
_reduce_scatter(grad_output, dim=ctx.dim, process_group=ctx.process_group),
None,
None,
grad_input = _reduce_scatter(
grad_output,
dim=ctx.dim,
process_group=ctx.process_group,
overlap_comm=ctx.overlap_comm,
)
if not ctx.overlap_comm:
return (grad_input, None, None, None, None)
else:
axonn.intra_layer.accumulate_later(ctx.input, grad_input)
return None, None, None, None, None
Loading

0 comments on commit 3ebc34c

Please sign in to comment.