
A context manager to optimize communication #54

Merged
siddharth9820 merged 6 commits into develop from comm-ctx-manager on Nov 29, 2023
Conversation

siddharth9820 (Collaborator) commented Nov 24, 2023

If the forward and backward passes are done within this context, axonn overlaps all of the tensor-parallel communication in the backward pass (reduce-scatters for depth tensor parallelism, and all-reduces for row and column tensor parallelism).

Example of Usage

```python
for batch, label in loader:
    with axonn.intra_layer.optimize_communication():
        out = model(batch)
        out.backward()  # enqueues reduce-scatters asynchronously
        # gradients are not ready yet

    # as soon as we exit the context, all reduce-scatters are synchronized
    # and the gradients are ready
    optimizer.step()  # important: run OUTSIDE the context
```
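The defer-then-synchronize pattern behind this context manager can be sketched in plain Python. This is a toy illustration, not axonn's implementation: `CommQueue`, `optimize_communication`, and `fake_reduce_scatter` are hypothetical stand-ins, with a thread pool mimicking the asynchronous collectives.

```python
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
import time

class CommQueue:
    """Collects asynchronous 'collectives' and waits for all of them at once."""
    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=4)
        self._pending = []

    def enqueue(self, fn, *args):
        # fire-and-forget: the caller does not wait for the result
        self._pending.append(self._pool.submit(fn, *args))

    def synchronize(self):
        # block until every enqueued communication has finished
        for fut in self._pending:
            fut.result()
        self._pending.clear()

@contextmanager
def optimize_communication(queue):
    # Inside the context, backward-pass communication is only enqueued;
    # exiting the context synchronizes everything, which is why gradients
    # are guaranteed ready only after the `with` block.
    try:
        yield queue
    finally:
        queue.synchronize()

grads = {"w": None}  # toy stand-in for a parameter's gradient

def fake_reduce_scatter(name, value):
    time.sleep(0.01)          # stand-in for network latency
    grads[name] = value / 2   # stand-in for the actual reduction

queue = CommQueue()
with optimize_communication(queue):
    # backward() would enqueue this without waiting for it
    queue.enqueue(fake_reduce_scatter, "w", 8.0)
    # grads["w"] may still be None at this point

# after the context exits, every enqueued reduction is complete
assert grads["w"] == 4.0
```

The key design point this mimics is that synchronization cost is paid once, at context exit, rather than per-tensor during the backward pass, which lets communication overlap with the remaining backward computation.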

siddharth9820 (Collaborator) commented Nov 24, 2023

ToDo

  • CI is broken when it tries to test asynchrony.

Gotchas

  • overlap_reduce_scatter=True might cause issues with frameworks that attach their own gradient hooks to weight tensors, such as DDP or ZeRO, since those hooks may fire before the deferred reduce-scatters have completed.
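Why deferred reduce-scatters and framework-owned gradient hooks conflict can be shown with a toy sketch (no torch; `ddp_style_hook` and `deferred_reduce` are hypothetical stand-ins): a hook that fires at backward time observes the local, unreduced gradient, because the real reduction only happens at context exit.

```python
deferred = []
grad = {"w": 8.0}   # local, unreduced gradient produced by backward()

def ddp_style_hook(g):
    # A DDP/ZeRO-style hook assumes g is final when it fires;
    # it would e.g. bucket g for its own all-reduce.
    return g

def deferred_reduce():
    grad["w"] /= 2  # stand-in for the deferred reduce-scatter

# during backward(): the hook fires immediately,
# but the reduction is merely enqueued
seen_by_hook = ddp_style_hook(grad["w"])
deferred.append(deferred_reduce)

# at context exit: the deferred reductions finally run
for op in deferred:
    op()

assert seen_by_hook == 8.0   # the hook saw the unreduced value
assert grad["w"] == 4.0      # the reduced value exists only after exit
```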

@siddharth9820 siddharth9820 merged commit 3ebc34c into develop Nov 29, 2023
6 checks passed
@siddharth9820 siddharth9820 deleted the comm-ctx-manager branch November 29, 2023 19:15
Avuxon pushed a commit that referenced this pull request Jan 25, 2024