Adding timers with cct (#101)
siddharth9820 authored Oct 23, 2024
1 parent 464e4a8 commit 84de428
Showing 4 changed files with 95 additions and 2 deletions.
16 changes: 15 additions & 1 deletion axonn/axonn.py
@@ -8,11 +8,14 @@
from typing import Optional
from .communication import communication_handle
import torch
from .timers import Timers

# True when init has been called
is_initialized = False
# Communication handle for point-to-point (MPI) and collective (NCCL) communication
comm_handle = None
enable_timers = False
timers = None


def init(
@@ -22,6 +25,7 @@ def init(
G_intra_c: int = 1,
G_intra_d: int = 1,
gpus_per_node: Optional[int] = None,
enable_internal_timers: bool = False,
) -> None:
"""
Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -35,9 +39,12 @@
AxoNN just creates the required process groups.
gpus_per_node (int, optional): number of GPUs per node, if not
provided this is inferred using pytorch
enable_internal_timers (bool): enable AxoNN's internal timers. These report
the time spent in synchronous communication regions and in matrix
multiplications.
"""
global comm_handle, is_initialized
global comm_handle, is_initialized, enable_timers, timers
comm_handle = communication_handle(
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
)
@@ -56,6 +63,8 @@
comm_handle.intra_layer_column_parallel_rank
)
is_initialized = True
enable_timers = enable_internal_timers
timers = Timers()


def create_dataloader(
@@ -110,3 +119,8 @@ def create_dataloader(
*args,
**kwargs,
) # not working with drop_last=False


def get_timers():
global timers
return timers
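
For reference, a minimal usage sketch of the new flag and accessor (not part of the commit). It assumes a multi-GPU distributed launch; the G_inter/G_data/G_intra_r keywords come from the portion of init's signature not shown in this hunk, and the degrees chosen are illustrative.

from axonn import axonn as ax

# Illustrative configuration: 2-way row tensor parallelism, timers on.
ax.init(G_inter=1, G_data=1, G_intra_r=2, enable_internal_timers=True)

# ... run forward/backward passes through AxoNN's instrumented code paths ...

# get_times() returns (total_times, total_events): per-region milliseconds and
# call counts, keyed by the tuple of nested timer names (see axonn/timers.py).
total_times, total_events = ax.get_timers().get_times()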
28 changes: 27 additions & 1 deletion axonn/intra_layer/asym_communication.py
@@ -15,6 +15,7 @@ def print_rank(msg):

@torch.no_grad()
def gather_batch_sizes(local_batch_size, process_group=None):
ax.get_timers().start("gather-batch-sizes")
world_size = dist.get_world_size(process_group)
local_batch_tensor = torch.tensor(local_batch_size, device="cuda")
global_batch_tensor = torch.empty(
@@ -23,11 +24,15 @@ def gather_batch_sizes(local_batch_size, process_group=None):
dist.all_gather_into_tensor(
global_batch_tensor, local_batch_tensor, group=process_group
)
return global_batch_tensor.cpu()

global_batch_tensor = global_batch_tensor.cpu()
ax.get_timers().stop("gather-batch-sizes")
return global_batch_tensor


@torch.no_grad()
def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("allgatherv")
output_tensor_list = []
for batch_size in rank_local_batch_sizes:
shape = list(tensor.shape)
@@ -37,6 +42,7 @@ def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
)
input_tensor_list = [tensor.contiguous() for _ in rank_local_batch_sizes]
dist.all_to_all(output_tensor_list, input_tensor_list, group=process_group)
ax.get_timers().stop("allgatherv")
return torch.cat(output_tensor_list)


@@ -57,10 +63,12 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("extra-non-expert-communication")
output = _allgatherv(input_, rank_local_batch_sizes, process_group)
ctx.save_for_backward(rank_local_batch_sizes)
# print_rank(f"Gatherv forward - {rank_local_batch_sizes}")
ctx.process_group = process_group
ax.get_timers().stop("extra-non-expert-communication")
return output

@staticmethod
@@ -105,9 +113,11 @@ def forward(ctx, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def backward(ctx, grad_output):
ax.get_timers().start("extra-non-expert-communication")
(rank_local_batch_sizes,) = ctx.saved_tensors
# print_rank("Start - DropVBack")
grad_input = _allgatherv(grad_output, rank_local_batch_sizes, ctx.process_group)
ax.get_timers().stop("extra-non-expert-communication")
# print_rank("End - DropVBack")
return grad_input, None, None

@@ -116,6 +126,8 @@ def backward(ctx, grad_output):
def _gather_batch_scatter_channels(input_, rank_local_batch_sizes, process_group=None):
# if input in GPU i is of shape [m_{i},...,k], and process group size is G
# then this returns a tensor of [sum_{i}(m_{i}),....,k/G].
ax.get_timers().start("gather-batch-scatter-channels")
ax.get_timers().start("alltoallv")
input_ = input_.contiguous()
world_size = torch.distributed.get_world_size(process_group)
send_tensors = list(torch.chunk(input_, world_size, dim=-1))
@@ -130,6 +142,8 @@ def _gather_batch_scatter_channels(input_, rank_local_batch_sizes, process_group
torch.empty(tuple(shape), device="cuda", dtype=input_.dtype)
)
torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
ax.get_timers().stop("alltoallv")
ax.get_timers().stop("gather-batch-scatter-channels")
return torch.cat(recv_tensors, dim=0)


@@ -138,6 +152,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
# if input in GPU i is of shape [m,...,k/G], and process group size is G
# then this returns a tensor of [m_{i},....,k],
# where m_{i} = rank_local_batch_sizes[i]
ax.get_timers().start("gather-channels-scatter-batch")
ax.get_timers().start("alltoallv")
input_ = input_.contiguous()
world_size = torch.distributed.get_world_size(process_group)
send_tensors = list(torch.split(input_, list(rank_local_batch_sizes), dim=0))
@@ -152,6 +168,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
)

torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
ax.get_timers().stop("gather-channels-scatter-batch")
ax.get_timers().stop("alltoallv")
return torch.cat(recv_tensors, dim=-1)


@@ -172,20 +190,24 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("extra-non-expert-communication")
output = _gather_batch_scatter_channels(
input_, rank_local_batch_sizes, process_group
)
ctx.process_group = process_group
ctx.save_for_backward(rank_local_batch_sizes)
ax.get_timers().stop("extra-non-expert-communication")
return output

@staticmethod
def backward(ctx, grad_output):
ax.get_timers().start("extra-non-expert-communication")
(rank_local_batch_sizes,) = ctx.saved_tensors
# print_rank("Start - GBSC back")
grad_input = _gather_channels_scatter_batch(
grad_output, rank_local_batch_sizes, ctx.process_group
)
ax.get_timers().stop("extra-non-expert-communication")
# print_rank("End - GBSC back")
return grad_input, None, None

@@ -208,21 +230,25 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

@staticmethod
def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
ax.get_timers().start("extra-non-expert-communication")
output = _gather_channels_scatter_batch(
input_, rank_local_batch_sizes, process_group
)
ctx.process_group = process_group
ctx.save_for_backward(rank_local_batch_sizes)
ax.get_timers().stop("extra-non-expert-communication")
return output

@staticmethod
def backward(ctx, grad_output):
ax.get_timers().start("extra-non-expert-communication")
(rank_local_batch_sizes,) = ctx.saved_tensors
# print_rank("Start - GCSB back")
grad_input = _gather_batch_scatter_channels(
grad_output, rank_local_batch_sizes, ctx.process_group
)
# print_rank("End - GCSB back")
ax.get_timers().stop("extra-non-expert-communication")
return grad_input, None, None


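Because Timers keys each region by the tuple of currently active start() calls (see axonn/timers.py below), the instrumentation in this file yields nested keys such as ("extra-non-expert-communication", "allgatherv"). A hedged sketch, not part of the commit, for printing those keys after some timed iterations:

from axonn import axonn as ax

# Timer keys are tuples of nested region names; indent by nesting depth.
times_ms, counts = ax.get_timers().get_times()
for key in sorted(times_ms, key=times_ms.get, reverse=True):
    indent = "  " * (len(key) - 1)
    print(f"{indent}{key[-1]}: {times_ms[key]:.2f} ms over {counts[key]} call(s)")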
7 changes: 7 additions & 0 deletions axonn/lightning/axonn_strategy.py
@@ -70,6 +70,7 @@ def __init__(
G_intra_c: int = 1,
G_intra_d: int = 1,
overlap_communication=False,
enable_timers=False,
activation_checkpointing: Optional[
Union[Type[Module], List[Type[Module]]]
] = None,
@@ -91,6 +92,7 @@ def __init__(
self.G_intra_d = G_intra_d
self._backward_sync_control = _AxoNNBackwardSyncControl()
self.overlap_communication = overlap_communication
self.enable_timers = enable_timers

self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
@@ -216,6 +218,7 @@ def _setup_distributed(self) -> None:
G_intra_r=self.G_intra_r,
G_intra_c=self.G_intra_c,
G_intra_d=self.G_intra_d,
enable_internal_timers=self.enable_timers,
)

def _get_process_group_backend(self) -> str:
@@ -316,6 +319,10 @@ def module_init_context(self, empty_init: Optional[bool] = None):
def module_sharded_context(self) -> ContextManager:
return auto_parallelize()

def get_timers(self):
assert self.enable_timers, "you should set enable_timers=True in AxoNNStrategy"
return ax.get_timers()


class _AxoNNBackwardSyncControl(_BackwardSyncControl):
@override
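A hedged sketch (not part of the commit) of enabling the timers through the Lightning strategy. It assumes AxoNNStrategy is importable from axonn.lightning and is used with Lightning Fabric; the device count and parallelism degree are illustrative.

from lightning.fabric import Fabric
from axonn.lightning import AxoNNStrategy

# 2-way row tensor parallelism with AxoNN's internal timers switched on.
strategy = AxoNNStrategy(G_intra_r=2, enable_timers=True)
fabric = Fabric(devices=2, strategy=strategy)
fabric.launch()

# ... fabric.setup(model), fabric.setup_optimizers(...), training steps ...

# Only valid because enable_timers=True was passed above.
times_ms, counts = strategy.get_timers().get_times()
if fabric.global_rank == 0:
    print(times_ms)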
46 changes: 46 additions & 0 deletions axonn/timers.py
@@ -0,0 +1,46 @@
import torch
from collections import defaultdict
from collections import deque
import axonn


class Timers:
def __init__(self):
self.timers = defaultdict(list)
self.curr_index = defaultdict(int)
self.stack = deque()

def start(self, key):
if not axonn.axonn.enable_timers:
return
self.stack.append(key)
key = tuple(self.stack)
index = self.curr_index[key]
timers = self.timers[key]
assert index <= len(timers)
if index == len(timers):
self.timers[key].append(
[torch.cuda.Event(enable_timing=True) for _ in range(2)]
)
self.timers[key][index][0].record()

def stop(self, key):
if not axonn.axonn.enable_timers:
return
key = tuple(self.stack)
index = self.curr_index[key]
self.timers[key][index][1].record()
self.curr_index[key] += 1
self.stack.pop()

def get_times(self):
torch.cuda.synchronize()
total_times = defaultdict(float)
total_events = defaultdict(int)
for key in self.timers:
for events in self.timers[key]:
start_event, end_event = events
total_times[key] += start_event.elapsed_time(end_event)
total_events[key] = self.curr_index[key]
self.curr_index[key] = 0
return total_times, total_events
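
A standalone sketch (not part of the commit) of how the class above behaves. It needs a CUDA device and an environment where axonn.axonn imports cleanly, and it toggles the module-level enable_timers flag directly, which ax.init(enable_internal_timers=True) would normally do.

import torch
import axonn.axonn
from axonn.timers import Timers

axonn.axonn.enable_timers = True  # normally set via ax.init(enable_internal_timers=True)
timers = Timers()

a = torch.randn(1024, 1024, device="cuda")
for _ in range(3):
    timers.start("outer")
    timers.start("matmul")
    torch.matmul(a, a)
    timers.stop("matmul")   # inner region must be stopped before the outer one
    timers.stop("outer")

# get_times() synchronizes, sums the CUDA-event times in milliseconds per key,
# and resets the per-key counters.
times_ms, counts = timers.get_times()
# Keys are call-stack tuples: ("outer",) and ("outer", "matmul"); counts are 3.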
