Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more timers #104

Merged
merged 4 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion axonn/intra_layer/asym_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
)

torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
ax.get_timers().stop("alltoallv")
ax.get_timers().stop("gather-channels-scatter-batch")
ax.get_timers().stop("alltoallv")
return torch.cat(recv_tensors, dim=-1)


Expand Down
22 changes: 13 additions & 9 deletions axonn/intra_layer/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,38 @@
import torch.distributed as dist
import torch
import axonn.intra_layer.overlap_communication as overlap_communication

from axonn import axonn as ax

def _all_reduce(input_, process_group=None, overlap_comm=False):
ax.get_timers().start("all-reduce")
input_ = input_.contiguous()
if dist.get_world_size(process_group) > 1:
handle = dist.all_reduce(
input_.contiguous(), group=process_group, async_op=overlap_comm
)
if overlap_comm:
overlap_communication.register_handle(handle)
ax.get_timers().stop("all-reduce")
return input_


def _drop(input_, dim, process_group=None):
"""Divide a tensor among the tensor parallel ranks"""
if dist.get_world_size(process_group) == 1:
return input_

ax.get_timers().start("drop")
total_chunks = dist.get_world_size(process_group)
this_chunk = dist.get_rank(process_group)
assert input_.shape[dim] % total_chunks == 0
chunk_size = input_.shape[dim] // total_chunks

ax.get_timers().stop("drop")
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)


def _gather(input_, dim, process_group=None, cache=False):
"""Gather tensors and concatenate them along a dimension"""
if dist.get_world_size(process_group) == 1:
return input_

ax.get_timers().start("gather")
if input_ in overlap_communication.weights_cache:
output, handle = overlap_communication.retrieve_all_gathered_weight(
input_, delete=not cache
Expand All @@ -61,16 +62,16 @@ def _gather(input_, dim, process_group=None, cache=False):

if cache:
overlap_communication.weights_cache[input_] = output, None

ax.get_timers().stop("gather")
return output


def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
assert dim == 0, "reduce scatter only implemented for dim=0"

if dist.get_world_size(process_group) == 1:
return input_

ax.get_timers().start("reduce-scatter")
total_chunks = dist.get_world_size(process_group)
assert input_.shape[dim] % total_chunks == 0
tensor_shape = list(input_.shape)
Expand All @@ -79,6 +80,7 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
)

ax.get_timers().start("reduce-scatter-dist")
if hasattr(torch.distributed, "reduce_scatter_tensor"):
handle = torch.distributed.reduce_scatter_tensor(
output, input_, group=process_group, async_op=overlap_comm
Expand All @@ -87,9 +89,11 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
handle = torch.distributed._reduce_scatter_base(
output, input_, group=process_group, async_op=overlap_comm
)


ax.get_timers().stop("reduce-scatter-dist")
if overlap_comm:
overlap_communication.register_handle(handle)
ax.get_timers().stop("reduce-scatter")
return output


Expand Down
19 changes: 18 additions & 1 deletion axonn/intra_layer/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def forward(
local_weight_shape,
cache_weights,
):
ax.get_timers().start("forward-async")
original_weight = weight
weight = _gather(
weight, dim=0, process_group=depth_parallel_group, cache=cache_weights
Expand All @@ -115,14 +116,17 @@ def forward(
ctx.backward_all_reduce_group = backward_all_reduce_group
ctx.depth_parallel_group = depth_parallel_group
ctx.shape = local_weight_shape
ax.get_timers().start("compute")
output = input_.matmul(weight.t())
ax.get_timers().stop("compute")
dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)

ax.get_timers().stop("forward-async")
return output

@staticmethod
@version_aware_custom_bwd
def backward(ctx, grad_output):
ax.get_timers().start("backward-async")
input_, original_weight = ctx.saved_tensors
weight = _gather(
original_weight, dim=0, process_group=ctx.depth_parallel_group, cache=False
Expand All @@ -137,18 +141,22 @@ def backward(ctx, grad_output):
grad_input, grad_weight = None, None

if ctx.needs_input_grad[0]:
ax.get_timers().start("compute")
grad_input = grad_output.matmul(weight)
ax.get_timers().stop("compute")
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=overlap_all_reduce,
)
if ctx.needs_input_grad[1]:
ax.get_timers().start("compute")
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)
ax.get_timers().stop("compute")

grad_weight = grad_weight.reshape(-1)
grad_weight = _reduce_scatter(
Expand All @@ -163,16 +171,19 @@ def backward(ctx, grad_output):
if overlap_reduce_scatter and ctx.needs_input_grad[1]:
overlap_communication.accumulate_later(original_weight, grad_weight)
grad_weight = None # weight gradients are not ready yet
ax.get_timers().stop("backward-async")
return grad_input, grad_weight, None, None, None, None, None, None, None
else:
grad_input, grad_weight = None, None

if ctx.needs_input_grad[1]:
ax.get_timers().start("compute")
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
).reshape(-1)
ax.get_timers().stop("compute")
grad_weight = _reduce_scatter(
grad_weight,
dim=0,
Expand All @@ -183,7 +194,10 @@ def backward(ctx, grad_output):
grad_weight = None # weight gradients are not ready yet

if ctx.needs_input_grad[0]:
ax.get_timers().start("compute")
grad_input = grad_output.matmul(weight)
ax.get_timers().stop("compute")
ax.get_timers().stop("backward-async")
return grad_input, grad_weight, None, None, None, None, None, None, None


Expand Down Expand Up @@ -305,6 +319,7 @@ def forward(
x,
cache_weights_in_all_gather=False,
):
ax.get_timers().start("forward-linear")
original_shape_x = x.shape
x = x.reshape(-1, x.shape[-1])
weight = self.weight
Expand Down Expand Up @@ -345,11 +360,13 @@ def forward(
x = x.reshape(*original_shape_x[:-1], x.shape[-1])

if self.bias is None:
ax.get_timers().stop("forward-linear")
return x
else:
bias = self.bias
if not self.expert_mode:
bias = Gather.apply(bias, self.outer_group)
ax.get_timers().stop("forward-linear")
if self.skip_bias_add:
return x, bias
else:
Expand Down
Loading