Disable unnecessary non expert comm (#103)
siddharth9820 authored Oct 25, 2024
1 parent 6252013 commit b7f7886
Showing 2 changed files with 22 additions and 10 deletions.
4 changes: 4 additions & 0 deletions axonn/intra_layer/__init__.py
@@ -125,6 +125,7 @@ def sync_gradients(
             model, gradient_attr_name, mean, vectorize
         )
         return
+    ax.get_timers().start("sync-gradients-non-expert")
     grads_to_sync = {
         "tensor_parallel_weights": [],
         "tensor_parallel_biases": [],
@@ -166,8 +167,11 @@ def sync_gradients(
         if mean:
             grad.div_(torch.distributed.get_world_size())

+    ax.get_timers().start("AR-others-world")
     for grad in grads_to_sync["others"]:
         # all other weights are purely data parallel
         dist.all_reduce(grad)
         if mean:
             grad.div_(torch.distributed.get_world_size())
+    ax.get_timers().stop("AR-others-world")
+    ax.get_timers().stop("sync-gradients-non-expert")
28 changes: 18 additions & 10 deletions axonn/intra_layer/fully_connected.py
@@ -311,12 +311,18 @@ def forward(
         if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
             # extra communication to transition from pure data parallelism
             # to 4D hybrid parallelism
-            inner_group_batch_sizes = gather_batch_sizes(x.shape[0], self.inner_group)
-            x = GatherBatchScatterChannels.apply(
-                x, inner_group_batch_sizes, self.inner_group
-            )
-            outer_group_batch_sizes = gather_batch_sizes(x.shape[0], self.outer_group)
-            x = Gatherv.apply(x, outer_group_batch_sizes, self.outer_group)
+            if self.inner_group_size > 1:
+                inner_group_batch_sizes = gather_batch_sizes(
+                    x.shape[0], self.inner_group
+                )
+                x = GatherBatchScatterChannels.apply(
+                    x, inner_group_batch_sizes, self.inner_group
+                )
+            if self.outer_group_size > 1:
+                outer_group_batch_sizes = gather_batch_sizes(
+                    x.shape[0], self.outer_group
+                )
+                x = Gatherv.apply(x, outer_group_batch_sizes, self.outer_group)
         x = AsyncLinear.apply(
             x,
             weight,
@@ -329,10 +335,12 @@ def forward(
         if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
             # extra communication to transition from 4D hybrid parallelism
             # to pure data parallelism
-            x = GatherChannelsScatterBatch.apply(
-                x, outer_group_batch_sizes, self.outer_group
-            )
-            x = Dropv.apply(x, inner_group_batch_sizes, self.inner_group)
+            if self.outer_group_size > 1:
+                x = GatherChannelsScatterBatch.apply(
+                    x, outer_group_batch_sizes, self.outer_group
+                )
+            if self.inner_group_size > 1:
+                x = Dropv.apply(x, inner_group_batch_sizes, self.inner_group)

         x = x.reshape(*original_shape_x[:-1], x.shape[-1])

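For context, the pattern behind the new "if self.inner_group_size > 1" / "if self.outer_group_size > 1" guards can be illustrated with a minimal, hypothetical sketch in plain torch.distributed (AxoNN's gather_batch_sizes, Gatherv, and GatherBatchScatterChannels helpers are not reproduced here, and equal-sized shards are assumed): a gather over a process group that contains a single rank is an identity operation, so skipping it avoids the launch and synchronization cost of a no-op collective.

# Hypothetical sketch (not AxoNN code): skip a collective when the
# process group has only one rank, since it would be an identity op.
import torch
import torch.distributed as dist


def maybe_gather_batch(x: torch.Tensor, group=None) -> torch.Tensor:
    # Gather x along the batch dimension across `group`, assuming
    # every rank holds an equally sized shard.
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        # Single-rank group: the collective would add overhead but no effect.
        return x
    shards = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(shards, x, group=group)
    return torch.cat(shards, dim=0)

When inner_group_size == 1 or outer_group_size == 1, the corresponding gather/scatter pair in forward reduces to exactly this single-rank case, which appears to be why the commit can skip that communication for non-expert layers.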