From 76781989ace4f6907eb916906ce30a87f6ce451c Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Wed, 23 Oct 2024 11:01:23 -0400
Subject: [PATCH 1/3] infra for timers with cct

---
 axonn/axonn.py                          | 15 ++++++++++++-
 axonn/intra_layer/asym_communication.py | 28 +++++++++++++++++++++++--
 axonn/lightning/axonn_strategy.py       |  8 +++++++
 3 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/axonn/axonn.py b/axonn/axonn.py
index 82013ee..60e3a51 100644
--- a/axonn/axonn.py
+++ b/axonn/axonn.py
@@ -8,11 +8,14 @@
 from typing import Optional
 from .communication import communication_handle
 import torch
+from .timers import Timers

 # True when init has been called
 is_initialized = False
 # Communication handle for point-to-point (MPI) and collective (NCCL) communication
 comm_handle = None
+enable_timers = False
+timers = None


 def init(
@@ -22,6 +25,7 @@ def init(
     G_intra_c: int = 1,
     G_intra_d: int = 1,
     gpus_per_node: Optional[int] = None,
+    enable_internal_timers: bool = False
 ) -> None:
     """
     Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -35,9 +39,12 @@ def init(
         AxoNN just creates the required process groups.
         gpus_per_node (int, optional): number of GPUs per node, if not provided
         this is inferred using pytorch
+        enable_internal_timers (bool): enable AxoNN's internal timers. This will give
+        you information about time spent in synchronous communication regions
+        and matrix multiplications.

     """
-    global comm_handle, is_initialized
+    global comm_handle, is_initialized, enable_timers, timers
     comm_handle = communication_handle(
         G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
     )
@@ -56,6 +63,8 @@ def init(
         comm_handle.intra_layer_column_parallel_rank
     )
     is_initialized = True
+    enable_timers = enable_internal_timers
+    timers = Timers()


 def create_dataloader(
@@ -110,3 +119,7 @@ def create_dataloader(
         *args,
         **kwargs,
     )  # not working with drop_last=False
+
+def get_timers():
+    global timers
+    return timers
diff --git a/axonn/intra_layer/asym_communication.py b/axonn/intra_layer/asym_communication.py
index 4ef0e3c..3b719e1 100644
--- a/axonn/intra_layer/asym_communication.py
+++ b/axonn/intra_layer/asym_communication.py
@@ -7,7 +7,6 @@
 import torch.distributed as dist
 from axonn import axonn as ax

-
 def print_rank(msg):
     if dist.get_rank() == 0:
         print(f"{dist.get_rank()} | {msg}")
@@ -15,6 +14,7 @@ def print_rank(msg):

 @torch.no_grad()
 def gather_batch_sizes(local_batch_size, process_group=None):
+    ax.get_timers().start("gather-batch-sizes")
     world_size = dist.get_world_size(process_group)
     local_batch_tensor = torch.tensor(local_batch_size, device="cuda")
     global_batch_tensor = torch.empty(
@@ -23,11 +23,14 @@ def gather_batch_sizes(local_batch_size, process_group=None):
     dist.all_gather_into_tensor(
         global_batch_tensor, local_batch_tensor, group=process_group
     )
-    return global_batch_tensor.cpu()
+    global_batch_tensor = global_batch_tensor.cpu()
+    ax.get_timers().stop("gather-batch-sizes")
+    return global_batch_tensor


 @torch.no_grad()
 def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
+    ax.get_timers().start("allgatherv")
     output_tensor_list = []
     for batch_size in rank_local_batch_sizes:
         shape = list(tensor.shape)
@@ -37,6 +40,7 @@ def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
     )
     input_tensor_list = [tensor.contiguous() for _ in rank_local_batch_sizes]
     dist.all_to_all(output_tensor_list, input_tensor_list, group=process_group)
+    ax.get_timers().stop("allgatherv")
     return torch.cat(output_tensor_list)


@@ -57,10 +61,12 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):
     @staticmethod
     def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
+        ax.get_timers().start("extra-non-expert-communication")
         output = _allgatherv(input_, rank_local_batch_sizes, process_group)
         ctx.save_for_backward(rank_local_batch_sizes)
         # print_rank(f"Gatherv forward - {rank_local_batch_sizes}")
         ctx.process_group = process_group
+        ax.get_timers().stop("extra-non-expert-communication")
         return output

     @staticmethod
     def backward(ctx, grad_output):
@@ -105,9 +111,11 @@ def forward(ctx, input_, rank_local_batch_sizes, process_group=None):

     @staticmethod
     def backward(ctx, grad_output):
+        ax.get_timers().start("extra-non-expert-communication")
         (rank_local_batch_sizes,) = ctx.saved_tensors
         # print_rank("Start - DropVBack")
         grad_input = _allgatherv(grad_output, rank_local_batch_sizes, ctx.process_group)
+        ax.get_timers().stop("extra-non-expert-communication")
         # print_rank("End - DropVBack")
         return grad_input, None, None

@@ -116,6 +124,8 @@ def backward(ctx, grad_output):
 def _gather_batch_scatter_channels(input_, rank_local_batch_sizes, process_group=None):
     # if input in GPU i is of shape [m_{i},...,k], and process group size is G
     # then this returns a tensor of [sum_{i}(m_{i}),....,k/G].
+    ax.get_timers().start("gather-batch-scatter-channels")
+    ax.get_timers().start("alltoallv")
     input_ = input_.contiguous()
     world_size = torch.distributed.get_world_size(process_group)
     send_tensors = list(torch.chunk(input_, world_size, dim=-1))
@@ -130,6 +140,8 @@ def _gather_batch_scatter_channels(input_, rank_local_batch_sizes, process_group
             torch.empty(tuple(shape), device="cuda", dtype=input_.dtype)
         )
     torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
+    ax.get_timers().stop("alltoallv")
+    ax.get_timers().stop("gather-batch-scatter-channels")
     return torch.cat(recv_tensors, dim=0)


@@ -138,6 +150,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
     # if input in GPU i is of shape [m,...,k/G], and process group size is G
     # then this returns a tensor of [m_{i},....,k],
     # where m_{i} = rank_local_batch_sizes[i]
+    ax.get_timers().start("gather-channels-scatter-batch")
+    ax.get_timers().start("alltoallv")
     input_ = input_.contiguous()
     world_size = torch.distributed.get_world_size(process_group)
     send_tensors = list(torch.split(input_, list(rank_local_batch_sizes), dim=0))
@@ -152,6 +166,8 @@ def _gather_channels_scatter_batch(input_, rank_local_batch_sizes, process_group
         )

     torch.distributed.all_to_all(recv_tensors, send_tensors, group=process_group)
+    ax.get_timers().stop("gather-channels-scatter-batch")
+    ax.get_timers().stop("alltoallv")
     return torch.cat(recv_tensors, dim=-1)


@@ -172,20 +188,24 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

     @staticmethod
     def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
+        ax.get_timers().start("extra-non-expert-communication")
         output = _gather_batch_scatter_channels(
             input_, rank_local_batch_sizes, process_group
         )
         ctx.process_group = process_group
         ctx.save_for_backward(rank_local_batch_sizes)
+        ax.get_timers().stop("extra-non-expert-communication")
         return output

     @staticmethod
     def backward(ctx, grad_output):
+        ax.get_timers().start("extra-non-expert-communication")
         (rank_local_batch_sizes,) = ctx.saved_tensors
         # print_rank("Start - GBSC back")
         grad_input = _gather_channels_scatter_batch(
             grad_output, rank_local_batch_sizes, ctx.process_group
         )
+        ax.get_timers().stop("extra-non-expert-communication")
         # print_rank("End - GBSC back")
         return grad_input, None, None

@@ -208,21 +228,25 @@ def symbolic(graph, input_, rank_local_batch_sizes, process_group=None):

     @staticmethod
     def forward(ctx, input_, rank_local_batch_sizes, process_group=None):
+        ax.get_timers().start("extra-non-expert-communication")
         output = _gather_channels_scatter_batch(
             input_, rank_local_batch_sizes, process_group
         )
         ctx.process_group = process_group
         ctx.save_for_backward(rank_local_batch_sizes)
+        ax.get_timers().stop("extra-non-expert-communication")
         return output

     @staticmethod
     def backward(ctx, grad_output):
+        ax.get_timers().start("extra-non-expert-communication")
         (rank_local_batch_sizes,) = ctx.saved_tensors
         # print_rank("Start - GCSB back")
         grad_input = _gather_batch_scatter_channels(
             grad_output, rank_local_batch_sizes, ctx.process_group
         )
         # print_rank("End - GCSB back")
+        ax.get_timers().stop("extra-non-expert-communication")
         return grad_input, None, None

diff --git a/axonn/lightning/axonn_strategy.py b/axonn/lightning/axonn_strategy.py
index 881e3c6..91b629d 100644
--- a/axonn/lightning/axonn_strategy.py
+++ b/axonn/lightning/axonn_strategy.py
@@ -70,6 +70,7 @@ def __init__(
         G_intra_c: int = 1,
         G_intra_d: int = 1,
         overlap_communication=False,
+        enable_timers = False,
         activation_checkpointing: Optional[
             Union[Type[Module], List[Type[Module]]]
         ] = None,
@@ -91,6 +92,7 @@ def __init__(
         self.G_intra_d = G_intra_d
         self._backward_sync_control = _AxoNNBackwardSyncControl()
         self.overlap_communication = overlap_communication
+        self.enable_timers = enable_timers

         self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
             activation_checkpointing, activation_checkpointing_policy
@@ -216,6 +218,7 @@ def _setup_distributed(self) -> None:
             G_intra_r=self.G_intra_r,
             G_intra_c=self.G_intra_c,
             G_intra_d=self.G_intra_d,
+            enable_internal_timers=self.enable_timers
         )

     def _get_process_group_backend(self) -> str:
@@ -317,6 +320,11 @@ def module_sharded_context(self) -> ContextManager:
         return auto_parallelize()


+    def get_timers(self):
+        assert self.enable_timers, "you should set enable_timers=True in AxoNNStrategy"
+        return ax.get_timers()
+
+
 class _AxoNNBackwardSyncControl(_BackwardSyncControl):
     @override
     def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:

From 5d1e53116880606ade30394da7bf7f2cc48b0640 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Wed, 23 Oct 2024 11:14:35 -0400
Subject: [PATCH 2/3] push timers

---
 axonn/timers.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)
 create mode 100644 axonn/timers.py

diff --git a/axonn/timers.py b/axonn/timers.py
new file mode 100644
index 0000000..98e7be2
--- /dev/null
+++ b/axonn/timers.py
@@ -0,0 +1,43 @@
+import torch
+from collections import defaultdict
+from collections import deque
+import axonn
+
+class Timers():
+    def __init__(self):
+        self.timers = defaultdict(list)
+        self.curr_index = defaultdict(int)
+        self.stack = deque()
+
+    def start(self, key):
+        if not axonn.axonn.enable_timers:
+            return
+        self.stack.append(key)
+        key = tuple(self.stack)
+        index = self.curr_index[key]
+        timers = self.timers[key]
+        assert index == len(timers) or index < len(timers)
+        if index == len(timers):
+            self.timers[key].append([torch.cuda.Event(enable_timing=True) for _ in range(2)])
+        self.timers[key][index][0].record()
+
+    def stop(self, key):
+        if not axonn.axonn.enable_timers:
+            return
+        key = tuple(self.stack)
+        index = self.curr_index[key]
+        self.timers[key][index][1].record()
+        self.curr_index[key] += 1
+        self.stack.pop()
+
+    def get_times(self):
+        torch.cuda.synchronize()
+        total_times = defaultdict(float)
+        total_events = defaultdict(int)
+        for key in self.timers:
+            for events in self.timers[key]:
+                start_event, end_event = events
+                total_times[key] += start_event.elapsed_time(end_event)
+            total_events[key] = self.curr_index[key]
+            self.curr_index[key] = 0
+        return total_times, total_events

From 27740e88636206cc7e08a23ebd2cadaaf6dcef41 Mon Sep 17 00:00:00 2001
From: Siddharth Singh
Date: Wed, 23 Oct 2024 14:15:46 -0400
Subject: [PATCH 3/3] reformat

---
 axonn/axonn.py                          | 3 ++-
 axonn/intra_layer/asym_communication.py | 2 ++
 axonn/lightning/axonn_strategy.py       | 5 ++---
 axonn/timers.py                         | 9 ++++++---
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/axonn/axonn.py b/axonn/axonn.py
index 60e3a51..279c9f8 100644
--- a/axonn/axonn.py
+++ b/axonn/axonn.py
@@ -25,7 +25,7 @@ def init(
     G_intra_c: int = 1,
     G_intra_d: int = 1,
     gpus_per_node: Optional[int] = None,
-    enable_internal_timers: bool = False
+    enable_internal_timers: bool = False,
 ) -> None:
     """
     Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -120,6 +120,7 @@ def create_dataloader(
         **kwargs,
     )  # not working with drop_last=False

+
 def get_timers():
     global timers
     return timers
diff --git a/axonn/intra_layer/asym_communication.py b/axonn/intra_layer/asym_communication.py
index 3b719e1..454f320 100644
--- a/axonn/intra_layer/asym_communication.py
+++ b/axonn/intra_layer/asym_communication.py
@@ -7,6 +7,7 @@
 import torch.distributed as dist
 from axonn import axonn as ax

+
 def print_rank(msg):
     if dist.get_rank() == 0:
         print(f"{dist.get_rank()} | {msg}")
@@ -28,6 +29,7 @@ def gather_batch_sizes(local_batch_size, process_group=None):
     ax.get_timers().stop("gather-batch-sizes")
     return global_batch_tensor

+
 @torch.no_grad()
 def _allgatherv(tensor, rank_local_batch_sizes, process_group=None):
     ax.get_timers().start("allgatherv")
diff --git a/axonn/lightning/axonn_strategy.py b/axonn/lightning/axonn_strategy.py
index 91b629d..41de10e 100644
--- a/axonn/lightning/axonn_strategy.py
+++ b/axonn/lightning/axonn_strategy.py
@@ -70,7 +70,7 @@ def __init__(
         G_intra_c: int = 1,
         G_intra_d: int = 1,
         overlap_communication=False,
-        enable_timers = False,
+        enable_timers=False,
         activation_checkpointing: Optional[
             Union[Type[Module], List[Type[Module]]]
         ] = None,
@@ -218,7 +218,7 @@ def _setup_distributed(self) -> None:
             G_intra_r=self.G_intra_r,
             G_intra_c=self.G_intra_c,
             G_intra_d=self.G_intra_d,
-            enable_internal_timers=self.enable_timers
+            enable_internal_timers=self.enable_timers,
         )

     def _get_process_group_backend(self) -> str:
@@ -319,7 +319,6 @@ def module_init_context(self, empty_init: Optional[bool] = None):
     def module_sharded_context(self) -> ContextManager:
         return auto_parallelize()

-
     def get_timers(self):
         assert self.enable_timers, "you should set enable_timers=True in AxoNNStrategy"
         return ax.get_timers()
diff --git a/axonn/timers.py b/axonn/timers.py
index 98e7be2..5339b4e 100644
--- a/axonn/timers.py
+++ b/axonn/timers.py
@@ -3,13 +3,14 @@
 from collections import deque
 import axonn

-class Timers():
+
+class Timers:
     def __init__(self):
         self.timers = defaultdict(list)
         self.curr_index = defaultdict(int)
         self.stack = deque()

-    def start(self, key):
+    def start(self, key):
         if not axonn.axonn.enable_timers:
             return
         self.stack.append(key)
         key = tuple(self.stack)
         index = self.curr_index[key]
         timers = self.timers[key]
         assert index == len(timers) or index < len(timers)
         if index == len(timers):
-            self.timers[key].append([torch.cuda.Event(enable_timing=True) for _ in range(2)])
+            self.timers[key].append(
+                [torch.cuda.Event(enable_timing=True) for _ in range(2)]
+            )
         self.timers[key][index][0].record()

     def stop(self, key):
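
Usage note (illustrative, not part of the patches): the sketch below drives the Timers class from PATCH 2 directly, toggling the module-level axonn.axonn.enable_timers flag by hand so it runs outside a distributed job. In a real run the flag is set for you by axonn.axonn.init(..., enable_internal_timers=True) or by AxoNNStrategy(enable_timers=True), and the shared instance is fetched with ax.get_timers() or strategy.get_timers(). A CUDA device is assumed, since the timers are built on torch.cuda.Event.

    # Minimal sketch only -- assumes a CUDA device and bypasses axonn.axonn.init().
    import torch

    from axonn import axonn as ax           # importing the submodule also binds axonn.axonn
    from axonn.timers import Timers

    ax.enable_timers = True                  # normally set via init(enable_internal_timers=True)
    timers = Timers()                        # normally obtained via ax.get_timers()

    x = torch.randn(4096, 4096, device="cuda")
    w = torch.randn(4096, 4096, device="cuda")

    for _ in range(10):
        timers.start("matmul")               # pushes the key onto the internal stack
        y = x @ w
        timers.stop("matmul")                # must be called in LIFO order w.r.t. start()

    # get_times() synchronizes the GPU, sums the recorded CUDA event intervals per key
    # (a tuple of the nested timer names), and resets the per-key call counters.
    total_ms, num_calls = timers.get_times()
    for key, ms in total_ms.items():
        print(key, f"{ms:.3f} ms over {num_calls[key]} calls")

Because each key is the tuple of the currently open timers, nested regions such as "alltoallv" started inside "gather-batch-scatter-channels" are reported as separate entries, which is how the per-region communication breakdown in these patches is obtained.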