diff --git a/axonn/intra_layer/__init__.py b/axonn/intra_layer/__init__.py
index 46242bf..ee6ff97 100644
--- a/axonn/intra_layer/__init__.py
+++ b/axonn/intra_layer/__init__.py
@@ -125,6 +125,7 @@ def sync_gradients(
             model, gradient_attr_name, mean, vectorize
         )
         return
+    ax.get_timers().start("sync-gradients-non-expert")
     grads_to_sync = {
         "tensor_parallel_weights": [],
         "tensor_parallel_biases": [],
@@ -166,8 +167,11 @@ def sync_gradients(
         if mean:
             grad.div_(torch.distributed.get_world_size())
 
+    ax.get_timers().start("AR-others-world")
     for grad in grads_to_sync["others"]:
         # all other weights are purely data parallel
         dist.all_reduce(grad)
         if mean:
             grad.div_(torch.distributed.get_world_size())
+    ax.get_timers().stop("AR-others-world")
+    ax.get_timers().stop("sync-gradients-non-expert")
diff --git a/axonn/intra_layer/fully_connected.py b/axonn/intra_layer/fully_connected.py
index 6ef8337..af75590 100644
--- a/axonn/intra_layer/fully_connected.py
+++ b/axonn/intra_layer/fully_connected.py
@@ -311,12 +311,18 @@ def forward(
         if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
             # extra communication to transition from pure data parallelism
             # to 4D hybrid parallelism
-            inner_group_batch_sizes = gather_batch_sizes(x.shape[0], self.inner_group)
-            x = GatherBatchScatterChannels.apply(
-                x, inner_group_batch_sizes, self.inner_group
-            )
-            outer_group_batch_sizes = gather_batch_sizes(x.shape[0], self.outer_group)
-            x = Gatherv.apply(x, outer_group_batch_sizes, self.outer_group)
+            if self.inner_group_size > 1:
+                inner_group_batch_sizes = gather_batch_sizes(
+                    x.shape[0], self.inner_group
+                )
+                x = GatherBatchScatterChannels.apply(
+                    x, inner_group_batch_sizes, self.inner_group
+                )
+            if self.outer_group_size > 1:
+                outer_group_batch_sizes = gather_batch_sizes(
+                    x.shape[0], self.outer_group
+                )
+                x = Gatherv.apply(x, outer_group_batch_sizes, self.outer_group)
         x = AsyncLinear.apply(
             x,
             weight,
@@ -329,10 +335,12 @@ def forward(
         if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
             # extra communication to transition from 4D hybrid parallelism
             # to pure data parallelism
-            x = GatherChannelsScatterBatch.apply(
-                x, outer_group_batch_sizes, self.outer_group
-            )
-            x = Dropv.apply(x, inner_group_batch_sizes, self.inner_group)
+            if self.outer_group_size > 1:
+                x = GatherChannelsScatterBatch.apply(
+                    x, outer_group_batch_sizes, self.outer_group
+                )
+            if self.inner_group_size > 1:
+                x = Dropv.apply(x, inner_group_batch_sizes, self.inner_group)
 
         x = x.reshape(*original_shape_x[:-1], x.shape[-1])