separate args for overlap_all_reduce and overlap_reduce_scatter
siddharth9820 committed Nov 26, 2023
1 parent ba8a47f commit 9cf6bb6
Showing 3 changed files with 110 additions and 17 deletions.
85 changes: 78 additions & 7 deletions axonn/intra_layer/__init__.py
@@ -7,6 +7,7 @@
 
 from axonn import axonn as ax
 import torch
+import torch.distributed as dist
 
 
 def drop(x, transpose=False, dim=-1, batch_dim=0):
@@ -31,12 +32,15 @@ def gather(x, transpose=False, dim=-1, batch_dim=0):
     return x
 
 
-OVERLAP_COMM = False
+OVERLAP_REDUCE_SCATTER = False
+OVERLAP_ALL_REDUCE = False
 CACHE_WEIGHTS = False
+ALL_GATHER_ITERATOR = None
 handles = []
 pending_grad_accumulations = []
 weights_cache = {}
 
+
 def register_handle(handle):
     # ToDo: This might be unnecesary since
     # we are calling synchronize in clear_handles
@@ -66,21 +70,88 @@ def accumulate():
 
     pending_grad_accumulations = []
 
+
 def clear_weights_cache():
     global weights_cache
     weights_cache = {}
 
+
+def trigger_async_all_gathers(model):
+    global weights_cache
+    for module in model.modules():
+        if isinstance(module, Linear):
+            weight = module.weight
+            if weight not in weights_cache:
+                # only trigger all gathers if not in cache
+                process_group = module.depth_group
+                world_size = dist.get_world_size(process_group)
+                if world_size == 1:
+                    all_gathered_weight = weight
+                    handle = None
+                else:
+                    assert weight.ndim == 1
+                    output_shape = weight.shape[0] * world_size
+                    all_gathered_weight = torch.empty(
+                        output_shape, dtype=weight.dtype, device=weight.device
+                    )
+                    handle = dist.all_gather_into_tensor(
+                        all_gathered_weight, weight, group=process_group, async_op=True
+                    )
+                weights_cache[weight] = [all_gathered_weight, handle]
+            yield
+
+
+def enqueue_next_all_gather():
+    global ALL_GATHER_ITERATOR
+    assert ALL_GATHER_ITERATOR is not None
+    try:
+        next(ALL_GATHER_ITERATOR)
+    except StopIteration:
+        pass
+
+
+def retrieve_all_gathered_weight(weight):
+    global CACHE_WEIGHTS, ALL_GATHER_ITERATOR
+    assert weight in weights_cache
+    all_gathered_weight, handle = weights_cache[weight]
+    if ALL_GATHER_ITERATOR is not None:
+        enqueue_next_all_gather()
+    return all_gathered_weight, handle
+
+
 @contextmanager
-def optimize_communication(cache_weights=False, *args, **kwargs):
-    global OVERLAP_COMM, CACHE_WEIGHTS
-    OVERLAP_COMM = True
+def optimize_communication(
+    overlap_all_reduce=True, overlap_reduce_scatter=False, cache_weights=False, overlap_all_gather=False, model=None, *args, **kwargs
+):
+    global OVERLAP_ALL_REDUCE, OVERLAP_REDUCE_SCATTER, CACHE_WEIGHTS, ALL_GATHER_ITERATOR
+    OVERLAP_ALL_REDUCE = overlap_all_reduce
+    OVERLAP_REDUCE_SCATTER = overlap_reduce_scatter
 
     if (not cache_weights) and (CACHE_WEIGHTS):
-        raise ValueError("Attempting to set cache_weights to False, when it was earlier set to True. This can lead to erroneous behaviours. Either always use cache_weights=False or cache_weights=True")
-    CACHE_WEIGHTS=cache_weights
+        raise ValueError(
+            "Attempting to set cache_weights to False, when it was earlier set to True. "
+            "This can lead to erroneous behaviour. Either always use cache_weights=False or cache_weights=True"
+        )
+    CACHE_WEIGHTS = cache_weights
+
+    if overlap_all_gather:
+        if model is None:
+            raise ValueError(
+                "You need to pass your model as an argument - "
+                "optimize_communication(...,model=model, ...) "
+                "if overlap_all_gather is True"
+            )
+        assert (
+            cache_weights
+        ), "all gathers can only be overlapped if cache_weights is True"
+        ALL_GATHER_ITERATOR = trigger_async_all_gathers(model)
+        enqueue_next_all_gather()
 
     try:
         yield None
     finally:
         clear_handles()
         accumulate()
-        OVERLAP_COMM = False
+        OVERLAP_ALL_REDUCE = False
+        OVERLAP_REDUCE_SCATTER = False
+        ALL_GATHER_ITERATOR = None
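
For reference, a minimal sketch of how a caller might drive the reworked context manager with the flags this commit separates. The training loop, `model`, `optimizer`, and `dataloader` are placeholders, not part of this change:

# Hypothetical training step; assumes AxoNN and torch.distributed are already
# initialized and that `model` contains axonn.intra_layer.Linear layers.
from axonn.intra_layer import optimize_communication, clear_weights_cache

for batch in dataloader:
    with optimize_communication(
        overlap_all_reduce=True,      # overlap all-reduces in the row/column groups
        overlap_reduce_scatter=True,  # overlap reduce-scatters in the depth group
        cache_weights=True,           # reuse all-gathered weights across iterations
        overlap_all_gather=True,      # prefetch weight all-gathers (needs model=...)
        model=model,
    ):
        loss = model(batch).mean()
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Drop the cached full weights once they are no longer valid.
clear_weights_cache()

Per the checks added above, overlap_all_gather=True requires both cache_weights=True and a model argument.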
32 changes: 25 additions & 7 deletions axonn/intra_layer/communication.py
@@ -28,8 +28,10 @@ def _gather(input_, dim, process_group=None, cache=False):
         return input_
 
     if input_ in axonn.intra_layer.weights_cache:
-        output = axonn.intra_layer.weights_cache[input_]
-
+        output, handle = axonn.intra_layer.retrieve_all_gathered_weight(input_)
+        if handle is not None:
+            handle.wait()
+            axonn.intra_layer.weights_cache[input_][1] = None
     else:
         input_ = input_.contiguous()
         # Size and dimension.
@@ -44,8 +46,8 @@ def _gather(input_, dim, process_group=None, cache=False):
         # Note: torch.cat already creates a contiguous tensor.
         output = torch.cat(tensor_list, dim=dim).contiguous()
 
         if cache:
-            axonn.intra_layer.weights_cache[input_] = output
+            axonn.intra_layer.weights_cache[input_] = output, None
 
     return output

@@ -142,17 +144,33 @@ def backward(ctx, grad_output):
 
 class ForwardGather_BackwardReduceScatter(torch.autograd.Function):
     @staticmethod
-    def symbolic(graph, input_, process_group=None, dim=0, overlap_comm=False, cache_all_gather=False):
+    def symbolic(
+        graph,
+        input_,
+        process_group=None,
+        dim=0,
+        overlap_comm=False,
+        cache_all_gather=False,
+    ):
         return _gather(input_, dim=dim, process_group=process_group)
 
     @staticmethod
-    def forward(ctx, input_, process_group=None, dim=0, overlap_comm=False, cache_all_gather=False):
+    def forward(
+        ctx,
+        input_,
+        process_group=None,
+        dim=0,
+        overlap_comm=False,
+        cache_all_gather=False,
+    ):
         assert dim == 0
         ctx.process_group = process_group
         ctx.dim = dim
         ctx.overlap_comm = overlap_comm
         ctx.input = input_
-        return _gather(input_, dim=dim, process_group=process_group, cache=cache_all_gather)
+        return _gather(
+            input_, dim=dim, process_group=process_group, cache=cache_all_gather
+        )
 
     @staticmethod
     def backward(ctx, grad_output):
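
The cache entries in communication.py now hold a [tensor, handle] pair so the wait on an asynchronous all-gather can be deferred until the gathered weight is first consumed. A standalone illustration of that pattern using plain torch.distributed (not AxoNN internals; assumes an already-initialized process group):

# Illustrative only: launch the all-gather asynchronously, keep the
# [tensor, handle] pair, and wait on the handle at first use, mirroring
# what _gather above now does via retrieve_all_gathered_weight.
import torch
import torch.distributed as dist

def prefetch_all_gather(shard, process_group):
    world_size = dist.get_world_size(process_group)
    full = torch.empty(shard.numel() * world_size, dtype=shard.dtype, device=shard.device)
    handle = dist.all_gather_into_tensor(full, shard, group=process_group, async_op=True)
    return [full, handle]

def use_gathered(entry):
    full, handle = entry
    if handle is not None:   # first use: finish the pending communication
        handle.wait()
        entry[1] = None      # subsequent uses skip the wait
    return full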
10 changes: 7 additions & 3 deletions axonn/intra_layer/fully_connected.py
@@ -203,7 +203,11 @@ def forward(self, x, scatter_input=True, gather_output=True):
         # gather weights from depth parallel group
         # reduce scatter in the backward pass
         weight = ForwardGather_BackwardReduceScatter.apply(
-            self.weight, self.depth_group, 0, axonn.intra_layer.OVERLAP_COMM, axonn.intra_layer.CACHE_WEIGHTS
+            self.weight,
+            self.depth_group,
+            0,
+            axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
+            axonn.intra_layer.CACHE_WEIGHTS,
         ).reshape(self.local_out_features, self.local_in_features)
 
         if not self.transpose:
@@ -215,7 +219,7 @@ def forward(self, x, scatter_input=True, gather_output=True):
                 weight,
                 self.inner_group,
                 self.outer_group,
-                axonn.intra_layer.OVERLAP_COMM,
+                axonn.intra_layer.OVERLAP_ALL_REDUCE,
                 False,
             )
             if gather_output:
@@ -231,7 +235,7 @@ def forward(self, x, scatter_input=True, gather_output=True):
                 weight,
                 self.outer_group,
                 self.inner_group,
-                axonn.intra_layer.OVERLAP_COMM,
+                axonn.intra_layer.OVERLAP_ALL_REDUCE,
                 False,
            )
             if gather_output:
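
With these changes the layer consults two independent flags: OVERLAP_REDUCE_SCATTER for the depth-group weight gather / reduce-scatter and OVERLAP_ALL_REDUCE for the row- and column-parallel all-reduces. A rough sketch of toggling only one of them; the Linear constructor arguments, the CUDA device, and the prior AxoNN initialization (ax.init) are assumptions, not taken from this diff:

# Hypothetical: overlap only the all-reduces and keep the depth-group
# reduce-scatter synchronous. Assumes AxoNN has been initialized elsewhere
# and that Linear takes (in_features, out_features).
import torch
from axonn.intra_layer import Linear, optimize_communication

layer = Linear(1024, 4096).cuda()
x = torch.randn(8, 1024, device="cuda")

with optimize_communication(overlap_all_reduce=True, overlap_reduce_scatter=False):
    y = layer(x)          # forward/backward all-reduces read OVERLAP_ALL_REDUCE
    y.sum().backward()    # weight-grad reduce-scatter stays blocking (flag is False)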
