Adding timers with cct #101

Merged
merged 3 commits on Oct 23, 2024
16 changes: 15 additions & 1 deletion axonn/axonn.py
@@ -8,11 +8,14 @@
from typing import Optional
from .communication import communication_handle
import torch
from .timers import Timers

# True when init has been called
is_initialized = False
# Communication handle for point-to-point (MPI) and collective (NCCL) communication
comm_handle = None
enable_timers = False
timers = None


def init(
@@ -22,6 +25,7 @@ def init(
G_intra_c: int = 1,
G_intra_d: int = 1,
gpus_per_node: Optional[int] = None,
enable_internal_timers: bool = False,
) -> None:
"""
Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -35,9 +39,12 @@
AxoNN just creates the required process groups.
gpus_per_node (int, optional): number of GPUs per node, if not
provided this is inferred using pytorch
enable_internal_timers (bool): enable AxoNN's internal timers. This will give
you information about time spent in synchronous communication regions
and matrix multiplications.

"""
global comm_handle, is_initialized
global comm_handle, is_initialized, enable_timers, timers
comm_handle = communication_handle(
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
)
@@ -56,6 +63,8 @@
comm_handle.intra_layer_column_parallel_rank
)
is_initialized = True
enable_timers = enable_internal_timers
timers = Timers()


def create_dataloader(
@@ -110,3 +119,8 @@ def create_dataloader(
*args,
**kwargs,
) # not working with drop_last=False


def get_timers():
global timers
return timers
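
For reference, a minimal sketch of how the new flag and accessor fit together; the process-group sizes and the training loop below are placeholders, not part of this diff:

from axonn import axonn as ax

# Initialize AxoNN with its internal timers switched on; the process-group
# sizes here are placeholders for whatever parallel configuration you use.
ax.init(G_inter=1, G_data=1, G_intra_r=2, enable_internal_timers=True)

# ... run some training iterations ...

# init() created a module-level Timers object; get_timers() hands it back so
# callers can read the accumulated CUDA-event timings.
total_times, total_events = ax.get_timers().get_times()
for key, milliseconds in total_times.items():
    print(key, milliseconds, total_events[key])
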
28 changes: 27 additions & 1 deletion axonn/intra_layer/asym_communication.py
@@ -15,6 +15,7 @@ def print_rank(msg):

@torch.no_grad()
def gather_batch_sizes(local_batch_size, process_group=None):
ax.get_timers().start("gather-batch-sizes")
world_size = dist.get_world_size(process_group)
local_batch_tensor = torch.tensor(local_batch_size, device="cuda")
global_batch_tensor = torch.empty(
@@ -23,11 +24,15 @@ def gather_batch_sizes(local_batch_size, process_group=None):
dist.all_gather_into_tensor(
global_batch_tensor, local_batch_tensor, group=process_group
)
return global_batch_tensor.cpu()

global_batch_tensor = global_batch_tensor.cpu()
ax.get_timers().stop("gather-batch-sizes")
return global_batch_tensor


@torch.no_grad()
def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("allgatherv")
output_tensor_list = []
for batch_size in rank_local_batch_sizes:
shape = list(tensor.shape)
@@ -37,6 +42,7 @@ def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
)
input_tensor_list = [tensor.contiguous() for _ in rank_local_batch_sizes]
dist.all_to_all(output_tensor_list, input_tensor_list, group=process_group)
ax.get_timers().stop("allgatherv")
return torch.cat(output_tensor_list)


@@ -57,10 +63,12 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("extra-non-expert-communication")
output = _allgatherv(input_, rank_local_batch_sizes, process_group)
ctx.save_for_backward(rank_local_batch_sizes)
# print_rank(f"Gatherv forward - {rank_local_batch_sizes}")
ctx.process_group = process_group
ax.get_timers().stop("extra-non-expert-communication")
return output

@staticmethod
@@ -105,9 +113,11 @@ def forward(ctx, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def backward(ctx, grad_output):
ax.get_timers().start("extra-non-expert-communication")
(rank_local_batch_sizes,) = ctx.saved_tensors
# print_rank("Start - DropVBack")
grad_input = _allgatherv(grad_output, rank_local_batch_sizes, ctx.process_group)
ax.get_timers().stop("extra-non-expert-communication")
# print_rank("End - DropVBack")
return grad_input, None, None

@@ -116,6 +126,8 @@ def backward(ctx, grad_output):
def _gather_batch_scatter_channels(input_, rank_local_batch_sizes, process_group=None):
# if input in GPU i is of shape [m_{i},...,k], and process group size is G
# then this returns a tensor of [sum_{i}(m_{i}),....,k/G].
ax.get_timers().start("gather-batch-scatter-channels")
ax.get_timers().start("alltoallv")
input_ = input_.contiguous()
world_size = torch.distributed.get_world_size(process_group)
send_tensors = list(torch.chunk(input_, world_size, dim=-1))
@@ -130,6 +142,8 @@ def _gather_batch_scatter_channels(input_, rank_local_batch_sizes, process_group
torch.empty(tuple(shape), device="cuda", dtype=input_.dtype)
)
torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
ax.get_timers().stop("alltoallv")
ax.get_timers().stop("gather-batch-scatter-channels")
return torch.cat(recv_tensors, dim=0)


@@ -138,6 +152,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
# if input in GPU i is of shape [m,...,k/G], and process group size is G
# then this returns a tensor of [m_{i},....,k],
# where m_{i} = rank_local_batch_sizes[i]
ax.get_timers().start("gather-channels-scatter-batch")
ax.get_timers().start("alltoallv")
input_ = input_.contiguous()
world_size = torch.distributed.get_world_size(process_group)
send_tensors = list(torch.split(input_, list(rank_local_batch_sizes), dim=0))
@@ -152,6 +168,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
)

torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
ax.get_timers().stop("gather-channels-scatter-batch")
ax.get_timers().stop("alltoallv")
return torch.cat(recv_tensors, dim=-1)


@@ -172,20 +190,24 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("extra-non-expert-communication")
output = _gather_batch_scatter_channels(
input_, rank_local_batch_sizes, process_group
)
ctx.process_group = process_group
ctx.save_for_backward(rank_local_batch_sizes)
ax.get_timers().stop("extra-non-expert-communication")
return output

@staticmethod
def backward(ctx, grad_output):
ax.get_timers().start("extra-non-expert-communication")
(rank_local_batch_sizes,) = ctx.saved_tensors
# print_rank("Start - GBSC back")
grad_input = _gather_channels_scatter_batch(
grad_output, rank_local_batch_sizes, ctx.process_group
)
ax.get_timers().stop("extra-non-expert-communication")
# print_rank("End - GBSC back")
return grad_input, None, None

@@ -208,21 +230,25 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("extra-non-expert-communication")
output = _gather_channels_scatter_batch(
input_, rank_local_batch_sizes, process_group
)
ctx.process_group = process_group
ctx.save_for_backward(rank_local_batch_sizes)
ax.get_timers().stop("extra-non-expert-communication")
return output

@staticmethod
def backward(ctx, grad_output):
ax.get_timers().start("extra-non-expert-communication")
(rank_local_batch_sizes,) = ctx.saved_tensors
# print_rank("Start - GCSB back")
grad_input = _gather_batch_scatter_channels(
grad_output, rank_local_batch_sizes, ctx.process_group
)
# print_rank("End - GCSB back")
ax.get_timers().stop("extra-non-expert-communication")
return grad_input, None, None


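
Every helper in this file follows the same bracketing pattern: start() before the communication call, stop() right after, with stops issued in reverse order of starts because the Timers class keys each interval on the current stack of open timers. A minimal sketch of the same pattern applied to a hypothetical user-defined region, assuming init() was called with enable_internal_timers=True and no other timers are open:

import torch
from axonn import axonn as ax

def my_instrumented_step(x):
    ax.get_timers().start("my-forward")   # outer region
    ax.get_timers().start("my-matmul")    # nested: key ("my-forward", "my-matmul")
    y = x @ x.t()
    ax.get_timers().stop("my-matmul")     # stop the innermost timer first
    ax.get_timers().stop("my-forward")
    return y

out = my_instrumented_step(torch.randn(64, 64, device="cuda"))
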
7 changes: 7 additions & 0 deletions axonn/lightning/axonn_strategy.py
@@ -70,6 +70,7 @@ def __init__(
G_intra_c: int = 1,
G_intra_d: int = 1,
overlap_communication=False,
enable_timers=False,
activation_checkpointing: Optional[
Union[Type[Module], List[Type[Module]]]
] = None,
@@ -91,6 +92,7 @@ def __init__(
self.G_intra_d = G_intra_d
self._backward_sync_control = _AxoNNBackwardSyncControl()
self.overlap_communication = overlap_communication
self.enable_timers = enable_timers

self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
@@ -216,6 +218,7 @@ def _setup_distributed(self) -> None:
G_intra_r=self.G_intra_r,
G_intra_c=self.G_intra_c,
G_intra_d=self.G_intra_d,
enable_internal_timers=self.enable_timers,
)

def _get_process_group_backend(self) -> str:
@@ -316,6 +319,10 @@ def module_init_context(self, empty_init: Optional[bool] = None):
def module_sharded_context(self) -> ContextManager:
return auto_parallelize()

def get_timers(self):
assert self.enable_timers, "you should set enable_timers=True in AxoNNStrategy"
return ax.get_timers()


class _AxoNNBackwardSyncControl(_BackwardSyncControl):
@override
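
On the Lightning side the new enable_timers flag is forwarded to ax.init() as enable_internal_timers, and get_timers() re-exports the same Timers object through the strategy. A hedged sketch of driving this from Lightning Fabric; the import path and the Fabric boilerplate are assumptions, only the enable_timers flag and strategy.get_timers() come from this diff:

import lightning as L
from axonn.lightning import AxoNNStrategy  # import path assumed

strategy = AxoNNStrategy(G_intra_r=2, enable_timers=True)
fabric = L.Fabric(devices=2, strategy=strategy)
fabric.launch()

# ... fabric.setup(model, optimizer), training loop ...

# get_timers() asserts that enable_timers=True was passed to the strategy.
total_times, total_events = strategy.get_timers().get_times()
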
46 changes: 46 additions & 0 deletions axonn/timers.py
@@ -0,0 +1,46 @@
import torch
from collections import defaultdict
from collections import deque
import axonn


class Timers:
def __init__(self):
self.timers = defaultdict(list)
self.curr_index = defaultdict(int)
self.stack = deque()

def start(self, key):
if not axonn.axonn.enable_timers:
return
self.stack.append(key)
key = tuple(self.stack)
index = self.curr_index[key]
timers = self.timers[key]
assert index == len(timers) or index < len(timers)
if index == len(timers):
self.timers[key].append(
[torch.cuda.Event(enable_timing=True) for _ in range(2)]
)
self.timers[key][index][0].record()

def stop(self, key):
if not axonn.axonn.enable_timers:
return
key = tuple(self.stack)
index = self.curr_index[key]
self.timers[key][index][1].record()
self.curr_index[key] += 1
self.stack.pop()

def get_times(self):
torch.cuda.synchronize()
total_times = defaultdict(float)
total_events = defaultdict(int)
for key in self.timers:
for events in self.timers[key]:
start_event, end_event = events
total_times[key] += start_event.elapsed_time(end_event)
total_events[key] = self.curr_index[key]
self.curr_index[key] = 0
return total_times, total_events
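
Since start() pushes the name onto a stack and both start() and stop() key on the whole stack, get_times() returns dictionaries keyed by tuples of nested timer names, and it synchronizes the device before reading the CUDA events. A standalone illustration, assuming a CUDA device and that AxoNN and its dependencies import cleanly; flipping axonn.axonn.enable_timers by hand is for demonstration only, in normal use init(..., enable_internal_timers=True) sets it:

import torch
import axonn.axonn
from axonn.timers import Timers

axonn.axonn.enable_timers = True  # demonstration only; init() normally sets this
t = Timers()

t.start("outer")
t.start("inner")                  # accumulated under the key ("outer", "inner")
a = torch.randn(1024, 1024, device="cuda")
b = a @ a                         # some GPU work to time
t.stop("inner")                   # stop the innermost timer first
t.stop("outer")

total_times, total_events = t.get_times()  # milliseconds and call counts per key
for key, ms in total_times.items():
    print(key, ms, total_events[key])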