Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable unnecessary non expert comm #103

Merged
merged 4 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions axonn/intra_layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def sync_gradients(
model, gradient_attr_name, mean, vectorize
)
return
ax.get_timers().start("sync-gradients-non-expert")
grads_to_sync = {
"tensor_parallel_weights": [],
"tensor_parallel_biases": [],
Expand Down Expand Up @@ -166,8 +167,11 @@ def sync_gradients(
if mean:
grad.div_(torch.distributed.get_world_size())

ax.get_timers().start("AR-others-world")
for grad in grads_to_sync["others"]:
# all other weights are purely data parallel
dist.all_reduce(grad)
if mean:
grad.div_(torch.distributed.get_world_size())
ax.get_timers().stop("AR-others-world")
ax.get_timers().stop("sync-gradients-non-expert")
28 changes: 18 additions & 10 deletions axonn/intra_layer/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,12 +311,18 @@ def forward(
if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
# extra communication to transition from pure data parallelism
# to 4D hybrid parallelism
inner_group_batch_sizes = gather_batch_sizes(x.shape[0], self.inner_group)
x = GatherBatchScatterChannels.apply(
x, inner_group_batch_sizes, self.inner_group
)
outer_group_batch_sizes = gather_batch_sizes(x.shape[0], self.outer_group)
x = Gatherv.apply(x, outer_group_batch_sizes, self.outer_group)
if self.inner_group_size > 1:
inner_group_batch_sizes = gather_batch_sizes(
x.shape[0], self.inner_group
)
x = GatherBatchScatterChannels.apply(
x, inner_group_batch_sizes, self.inner_group
)
if self.outer_group_size > 1:
outer_group_batch_sizes = gather_batch_sizes(
x.shape[0], self.outer_group
)
x = Gatherv.apply(x, outer_group_batch_sizes, self.outer_group)
x = AsyncLinear.apply(
x,
weight,
Expand All @@ -329,10 +335,12 @@ def forward(
if not self.expert_mode and (self.inner_group_size * self.outer_group_size > 1):
# extra communication to transition from 4D hybrid parallelism
# to pure data parallelism
x = GatherChannelsScatterBatch.apply(
x, outer_group_batch_sizes, self.outer_group
)
x = Dropv.apply(x, inner_group_batch_sizes, self.inner_group)
if self.outer_group_size > 1:
x = GatherChannelsScatterBatch.apply(
x, outer_group_batch_sizes, self.outer_group
)
if self.inner_group_size > 1:
x = Dropv.apply(x, inner_group_batch_sizes, self.inner_group)

x = x.reshape(*original_shape_x[:-1], x.shape[-1])

Expand Down
Loading