
Sort allgathers according to consumer order, reduce scatter according to producer order #592

Closed
Wants to merge 5 commits

Conversation


@kiya00 kiya00 commented Jun 13, 2024

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #574.

This PR suggests a way to sort the communication ops. Previously we used one function to sort both the allgathers and the reduce_scatters, but ran into problems with topologically equal nodes: when both are sorted in a single pass, a top-down sort makes the allgather order depend on the order of the input parameters, while a bottom-up sort can let reduce_scatters pile up before their waits.

In this PR there are two sorting functions (a rough sketch of the underlying idea follows the list):

  • sort allgather: bottom-up topological sorting; the all_gather_prim_impl and its wait nodes are ordered according to the consumer order, and each all_gather_prim_impl is placed just before its wait.

  • sort reduce_scatter: top-down topological sorting; the reduce/reduce_scatter and its wait node are ordered according to the producer order, and the distance between reduce/reduce_scatter and its wait is maximized.
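
The mechanism can be pictured as a priority-driven topological sort whose tie-breaking key encodes the desired order of topologically equal nodes. The sketch below only illustrates that idea: `priority_topo_sort`, the toy node names, and `consumer_pos` are invented for this example and are not the helpers in thunder/distributed/utils.py (the real passes also differ in direction, bottom-up for allgathers and top-down for reduce_scatters, which is not shown here).

```python
# Illustrative only: a priority-driven Kahn topological sort. Among all
# nodes whose dependencies are satisfied, the one with the smallest key is
# emitted first, so the key decides the order of topologically equal nodes.
import heapq
from collections import defaultdict


def priority_topo_sort(nodes, edges, key):
    indegree = defaultdict(int)
    consumers = defaultdict(list)
    for src, dst in edges:
        indegree[dst] += 1
        consumers[src].append(dst)
    ready = [(key(n), i, n) for i, n in enumerate(nodes) if indegree[n] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        _, _, node = heapq.heappop(ready)
        order.append(node)
        for nxt in consumers[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                heapq.heappush(ready, (key(nxt), len(order), nxt))
    return order


# Two all_gathers whose consumers appear in the order c2, c1: keying every
# node by its consumer's position emits ag2/wait2 first and keeps each
# gather directly before its wait.
nodes = ["ag1", "wait1", "ag2", "wait2", "c2", "c1"]
edges = [("ag1", "wait1"), ("wait1", "c1"), ("ag2", "wait2"), ("wait2", "c2")]
consumer_pos = {"ag1": 1, "wait1": 1, "c1": 1, "ag2": 0, "wait2": 0, "c2": 0}
print(priority_topo_sort(nodes, edges, lambda n: consumer_pos[n]))
# -> ['ag2', 'wait2', 'c2', 'ag1', 'wait1', 'c1']
```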

But since our reduce_scatter is already in the right place without the sorting, we need to think about whether this PR is really necessary.

@kiya00 kiya00 requested a review from IvanYashchuk June 13, 2024 13:40

kiya00 commented Jun 25, 2024

|       | main w/o bucket      | main w/ bucket       | this PR w/o bucket   | this PR w/ bucket    |
| ----- | -------------------- | -------------------- | -------------------- | -------------------- |
| zero3 | 444.11 ms / 28.85 GB | 455.65 ms / 34.45 GB | 432.00 ms / 29.09 GB | 456.50 ms / 34.97 GB |
| zero2 | 408.44 ms / 42.06 GB | 424.66 ms / 42.07 GB | 408.29 ms / 42.06 GB | 424.68 ms / 42.07 GB |

@kiya00 kiya00 requested a review from crcrpar June 26, 2024 12:07
@kiya00 kiya00 marked this pull request as ready for review June 26, 2024 12:07
@kiya00 kiya00 requested review from mruberry, lantiga and t-vi as code owners June 26, 2024 12:07

t-vi commented Jun 26, 2024

> But since our reduce_scatter is already in the right place without the sorting, we need to think about whether this PR is really necessary.

There is a small memory increase. Is that expected?

@kiya00 @crcrpar curious to hear your thoughts.

@IvanYashchuk

The issue description has the following expectation:

> sort the allgathers to their consumer order and list them at the beginning of the trace, the corresponding waits are right before the consumers (maximum distance between allgathers and waits).

but the description of this pull request says that allgather sorting puts the allgather op right before the wait op, so the communication wouldn't overlap with computation. Is the pull request description accurate?


kiya00 commented Jun 27, 2024

Currently the single sort_allgather puts the allgather op right before the wait op, and the limit in limit_in_flight_allgathers (INT_MAX, 3) then moves the allgathers to the proper position for zero2/zero3.
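
For illustration, here is a toy version of the effect that limiting in-flight allgathers has on a schedule in which each gather initially sits right before its wait; `schedule_with_limit` and the block tuples are invented for this sketch and are not the actual limit_in_flight_allgathers implementation.

```python
def schedule_with_limit(blocks, limit):
    """Toy scheduler: `blocks` is a list of (gather, wait, consumers) triples
    already in consumer order; returns a flat schedule in which at most
    `limit` gathers are outstanding (issued but not yet waited on)."""
    out = []
    # Issue the first `limit` gathers up front.
    for gather, _, _ in blocks[:limit]:
        out.append(gather)
    for i, (_, wait, consumers) in enumerate(blocks):
        out.append(wait)            # wait right before the consumers need the data
        out.extend(consumers)
        if i + limit < len(blocks):
            out.append(blocks[i + limit][0])  # issuing one more keeps `limit` in flight
    return out


blocks = [(f"ag{i}", f"wait{i}", [f"compute{i}"]) for i in range(4)]
print(schedule_with_limit(blocks, limit=2))
# ['ag0', 'ag1', 'wait0', 'compute0', 'ag2', 'wait1', 'compute1',
#  'ag3', 'wait2', 'compute2', 'wait3', 'compute3']
# With limit=INT_MAX every gather is issued up front (zero2-like);
# a small limit caps the number of gathered parameters held at once (zero3-like).
```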


kiya00 commented Jun 27, 2024

After some discussion with @IvanYashchuk, I think we can use the fix suggested by @kshitij12345: the reduce_scatters are already in the right position before sorting, and that is enough to solve #557. We could consider using this PR if one day we can no longer rely on the original order. For now I'll close it.

```diff
diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py
index 7e0e81a..3ff3337 100644
--- a/thunder/distributed/utils.py
+++ b/thunder/distributed/utils.py
@@ -84,8 +84,9 @@ def sort_communication_ops(execution_trace):
                 case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id:
                     return len(order_in_trace)
                 case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
-                    # Prefer larger communication ops over smaller ones
-                    return -node.bsym.args[0].numel
+                    # We want to keep the `reduce` close to its producer
+                    # (which is close to the original place in the trace).
+                    return order_in_trace[node.bsym]
                 case all_gather_prim_impl.id:
                     return len(order_in_trace) + order_in_trace[node.bsym]
                 case _:
```
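
As a hypothetical illustration of what this one-line key change does when several reduce/reduce_scatter nodes become ready at the same time (the dicts and field names below are made up; the real key function receives trace nodes):

```python
# Two reduce_scatters that are both ready: the old key prefers the larger
# communication, the new key keeps the one whose producer came first.
ready = [
    {"name": "rs_layer1", "numel": 4096, "pos": 3},  # produced earlier, smaller
    {"name": "rs_layer2", "numel": 8192, "pos": 7},  # produced later, larger
]

old_key = lambda n: -n["numel"]  # before the patch: larger comms first
new_key = lambda n: n["pos"]     # after the patch: original producer order

print([n["name"] for n in sorted(ready, key=old_key)])  # ['rs_layer2', 'rs_layer1']
print([n["name"] for n in sorted(ready, key=new_key)])  # ['rs_layer1', 'rs_layer2']
```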

Successfully merging this pull request may close this issue: Find a way to properly sort the communication operators for zero2/zero3