
reorderShardedAxisPass for DID loop split #4256


Merged: 55 commits from pm/preseg_reorder_sharded into main, Jun 4, 2025

Conversation

@Priya2698 (Collaborator) commented Apr 16, 2025

Issue #3900.

Key changes:

  1. Modifying allocation domain instead of logical domain
  • The previous implementation modified the logical shape of the communication inputs and outputs so that the gathered/scattered axes were outermost. The current implementation only sets the allocation domain (see the sketch after this description).
  2. hasShardingChanges -> getGatherOrScatterCommInfo
  • The new function finds any gather / scatter / reduce-scatter communication patterns and returns the logical IterDomains involved in the communication. This accommodates any DID loop split.
  3. isInnerResharding -> isCommLayoutCompliant and isAllocatedOutermost, to support the split allocation domains seen in DID loop split.
  4. Dependency on canLower is removed to decouple from issue #4382 (Extend InsertReshardingPass for loop split).

This PR does not handle ParallelType::Stream right now.

TODO: ReduceScatter tests after PR #4384
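
A minimal sketch of ideas 1 and 2, assuming nvFuser's TensorView API as used in the tests below. The CommunicationInfo fields and the helper function are illustrative, not the exact pass implementation:

#include <algorithm>
#include <vector>

// Illustrative shape of the info returned for a communication pattern:
// the kind of communication plus the sharded logical IterDomains involved.
struct CommunicationInfo {
  CommunicationType type;    // e.g. allgather / scatter / reduce-scatter
  IterDomain* p_sharded_id;  // sharded logical ID on the producer
  IterDomain* c_sharded_id;  // sharded logical ID on the consumer
};

// Idea 1: leave the logical domain alone and only reorder the allocation
// domain so the gathered/scattered axis is allocated outermost.
void allocateShardedIdOutermost(TensorView* tv, IterDomain* sharded_id) {
  std::vector<IterDomain*> allocation = tv->getLogicalDomain();
  auto it = std::find(allocation.begin(), allocation.end(), sharded_id);
  // Rotate the sharded ID to the front; relative order of the rest is kept.
  std::rotate(allocation.begin(), it, it + 1);
  tv->setAllocationDomain(allocation, /*contiguity=*/true);
}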

@Priya2698:

!test

github-actions bot commented Apr 16, 2025

Review updated until commit 50a4ee2

Description

  • Refactor communication layout compliance checks

  • Introduce CommunicationInfo struct for communication pattern analysis

  • Update canLower to use isCommunicationLayoutCompliant

  • Enhance lower_to_communication.cpp with new functions and logic


Changes walkthrough 📝

Relevant files

Enhancement (7 files)

  • lower.cpp: Update canLower to use isCommunicationLayoutCompliant (+2/-2)
  • lower_to_communication.cpp: Introduce CommunicationInfo and related functions (+188/-5)
  • communication.cpp: Add contiguity checks in postAllreduce and postReduceScatter (+16/-1)
  • utils.cpp: Remove deprecated functions and add new utility functions (+2/-86)
  • reorder_sharded_axis.cpp: Refactor pass to use new communication layout compliance checks (+173/-118)
  • lower_to_communication.h: Add declarations for CommunicationInfo and related functions (+29/-0)
  • utils.h: Remove deprecated function declarations and add new utility function declarations (+5/-15)

Tests (4 files)

  • test_multidevice_lower_communication.cpp: Update and add new tests for communication layout compliance (+127/-90)
  • test_resharding.cpp: Update tests to use new communication layout compliance checks (+3/-2)
  • test_communication.py: Re-enable and update test_reduce_scatter_noncontiguous (+0/-3)
  • test_matmul.py: Update test parameters for multidevice matmul (+1/-1)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The new function isCommunicationLayoutCompliant is used in place of isInnerResharding. Ensure that this new function correctly identifies inner resharding scenarios and that the logic is consistent with the previous implementation.

if (!ignore_inner_resharding && !isCommunicationLayoutCompliant(expr)) {
  return false;
}
Performance Concern

The new function isAllocationOrderCompliant checks if the allocation order is compliant with NCCL/UCC requirements. Ensure that this function is efficient and does not introduce significant overhead during the compilation process.

bool isAllocationOrderCompliant(TensorView* tv, IterDomain* sharded_id) {
  NVF_ERROR(
      std::find(
          tv->getLogicalDomain().begin(),
          tv->getLogicalDomain().end(),
          sharded_id) != tv->getLogicalDomain().end(),
      "The sharded ID ",
      sharded_id->toString(),
      " is not in the logical domain ",
      tv->getLogicalDomain());

  if (isLocalSizeOne(sharded_id)) {
    // Parallelized dimension, broadcast, and reduction do not affect
    // allocation.
    return true;
  }

  // This sharded logical ID may not be directly present in allocation domain.
  // This indicates allocation domain has DID transformations.
  std::optional<Layout> layout = canonicalizeLayout(tv);
  if (!layout.has_value()) {
    return false;
  }

  const std::vector<IterDomain*>& allocation_domain = layout->allocation_domain;

  NVF_ERROR(
      std::is_permutation(
          allocation_domain.begin(),
          allocation_domain.end(),
          tv->getLogicalDomain().begin(),
          tv->getLogicalDomain().end()),
      "The allocation domain returned by canonicalizeLayout",
      allocation_domain,
      " should be a permutation of the logical domain ",
      tv->getLogicalDomain());

  // Check if sharded_id appears at the front.
  for (IterDomain* id : allocation_domain) {
    if (id == sharded_id) {
      return true;
    }
    if (!isLocalSizeOne(id)) {
      return false;
    }
  }
  // Unreachable: the NVF_ERROR above guarantees sharded_id is present.
  return false;
}
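
For context, a hypothetical call site for this check inside the pass (comm_info and the copy-insertion step are assumptions, not taken from the diff):

if (!isAllocationOrderCompliant(producer, comm_info.p_sharded_id)) {
  // The pass would make the layout compliant here, e.g. by inserting a
  // copy whose allocation domain puts the sharded ID outermost.
}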
Test Coverage

The test Allgather_NonCompliantAllocation checks for non-compliant allocation. Ensure that this test covers all edge cases and that the expected behavior is correctly validated.

TensorView* tv0 = makeConcreteTensor({5, d * 3});
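// Mark the input's allocation as non-contiguous; this is what makes the
// layout non-compliant for direct lowering to an allgather.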
tv0->setAllocationDomain(tv0->getLogicalDomain(), false);

TensorView* tv1 = set(tv0);
tv1->setAllocationDomain(tv1->getLogicalDomain(), true);

tv0->setDeviceMesh(mesh);
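// DID loop split: outer-split axis 1 by d and parallelize the outer piece on DIDx.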
tv0->outer_split(1, d);
tv0->axis(1)->parallelize(ParallelType::DIDx);

tv1->setDeviceMesh(mesh);

fusion->addInput(tv0);
fusion->addOutput(tv1);

at::Tensor unsharded_in_tensor = at::randn({5, d * 3}, tensor_options);
at::Tensor in_tensor = shardTensor(unsharded_in_tensor, 1, mesh);

FusionExecutorCache executor_cache(std::move(fusion));
at::Tensor out_tensor =
    executor_cache.runFusionWithInputs({in_tensor})[0].as<at::Tensor>();

EXPECT_TRUE(out_tensor.is_contiguous());
EXPECT_TRUE(at::allclose(out_tensor, unsharded_in_tensor));

FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
EXPECT_THAT(
    runtime->fusionSegments()->groups(),
    Contains(HeuristicIs(SchedulerType::PointWise)).Times(2));

@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from b420957 to 8e64e6e Compare May 8, 2025 00:22
@Priya2698:

!test

1 similar comment
@Priya2698:

!test

@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from f8e6435 to fb07219 Compare May 10, 2025 02:49
@Priya2698:

!test --diff

1 similar comment
@Priya2698:

!test --diff

@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from 8803b66 to 58f1039 Compare May 14, 2025 20:49
@Priya2698:

!test --diff

@Priya2698:

!test

@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from 8fb5af9 to 35090b0 Compare May 20, 2025 20:18
@Priya2698 Priya2698 changed the title from "Pm/preseg reorder sharded" to "reorderShardedAxisPass for DID loop split" May 20, 2025
Priya2698 added a commit that referenced this pull request May 21, 2025
Extracted the logic for converting a split allocation domain to logical from
`alias.cpp` to ir_utils.
This avoids duplicate logic in PR #4256 when checking the allocation of
gathered/scattered axes.
@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from d635cb8 to 602de55 Compare May 21, 2025 22:18
@Priya2698 Priya2698 marked this pull request as ready for review May 22, 2025 00:53
@Priya2698 Priya2698 requested a review from wujingyue May 22, 2025 01:01
@Priya2698:

!test

@Priya2698:

!test

@Priya2698:

!test


@wujingyue wujingyue left a comment


First batch; still reviewing...

const std::vector<IterDomain*>& domain,
const IterDomain* id);

// Returns the communication info for the
Collaborator:

Add a code comment noting that this assumes expr has been decomposed.

I'd move this to lower_to_communication.h so it's closer to `convertSingleOpToCommunication`. The two functions should be kept in sync, so it makes sense to put them close to each other.

Collaborator Author:

That can lead to a circular dependency since that file imports multidevice/utils, and multidevice/utils itself needs this function.
In a separate PR, I can extend this function to be more elaborate and use it directly in convertSingleOpToCommunication to avoid any mismatch between them.

Collaborator:

> That can lead to a circular dependency since that file imports multidevice/utils, and multidevice/utils itself needs this function.

AFAICT, to avoid circular dependency, only getCommunicationInfo and isCommunicationLayoutCompliant need to go to lower_to_communication. They probably should anyhow because they are all about lowering.

In general, a large "utils.h" header file tends to cause the following issues:

  1. Increased Compilation Times: A large utils.h header will be included in many source files, leading to a lot of parsing and compilation overhead, significantly increasing build times. Any change to utils.h, even a minor one, requires recompilation of all files that include it, further slowing down development.
  2. Reduced Encapsulation and Interface Clarity: A large utils.h likely exposes many internal utility functions and data structures to clients that don't need to know about them, violating the principle of information hiding and reducing encapsulation. This makes it harder to understand the actual public interface and can lead to accidental misuse or dependencies on internal implementation details.

@Priya2698 (Collaborator, Author) commented May 28, 2025:

  1. lower.cpp and lower_to_communication.cpp import each other -> This is easy to resolve; lower.cpp does not need lower_to_communication.h.
  2. lower.cpp could hold the above functions, but it imports the presegmentation passes. This is fine for now, since reorderShardedAxisPass is removed from there. However, previously these two files were also importing each other. The same dependency exists between other preseg passes and this file: the preseg passes query canLower, whereas lower calls the preseg passes.

We need a restructuring to avoid this. Moving canLower to lower_to_communication seems sufficient, but it is part of the HostIrLower class.

Collaborator:

Circular dependencies among cpps are OK, although still not the best practice. Recall that cpp files are compiled independently. Circular dependencies among header files are much more problematic, but I don't think you are hitting any.

To avoid circular dependencies among cpps [1] for FusionExecutorCache, I think we can:

  1. Let preseg depend on lower_to_communication.
  2. Don't let lower_to_communication depend on lower. HostIrLowerParams can be avoided by passing in CommunicatorBackend directly; `HostIrLower::canLower(c)` can be removed or, if MultiDeviceExecutor needs it, moved to lower. (See the sketch after the footnotes below.)
  3. Keep lower away from FusionExecutorCache. I saw several #include lower.h in the main stack, but none of them seem necessary.

Footnotes

  1. This isn't accurate because we never include a cpp from another cpp. I'm really referring to scenarios like a.cpp including b.h and b.cpp including a.h.
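
A hedged sketch of item 2 above, using the `convertSingleOpToCommunication` entry point named earlier in this thread; the surrounding parameter list is an assumption, not quoted from the code:

// lower_to_communication.h: no dependency on lower.h. The backend enum is
// passed directly instead of a HostIrLowerParams struct from lower.h.
std::vector<Expr*> convertSingleOpToCommunication(
    Expr* expr,
    int64_t my_device_index,
    CommunicatorBackend backend);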


@Priya2698 Priya2698 requested a review from wujingyue May 27, 2025 21:03
@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from 56ebef4 to b2122d5 Compare May 28, 2025 21:38
@Priya2698 Priya2698 force-pushed the pm/preseg_reorder_sharded branch from 1fd7ea9 to 8c55e43 Compare June 3, 2025 22:56
@Priya2698:

!test

@Priya2698:

!test

@Priya2698 Priya2698 merged commit 5508e22 into main Jun 4, 2025
52 of 53 checks passed
@Priya2698 Priya2698 deleted the pm/preseg_reorder_sharded branch June 4, 2025 20:03
wujingyue added a commit that referenced this pull request Jun 16, 2025
As a follow-up to
#4256 (comment)

This makes the test more realistic and gives better coverage. It indeed
caught #4642, a new bug.