PR #18254: [XLA:GPU] Simplifies All-Reduce if they can be simplified or are degenerated.

Imported from GitHub PR #18254

This PR adds the AllReduceSimplifier pass back to the GPU compiler, because in some cases JAX programs add degenerate psum operations to the HLO.

Code pointers in JAX:
1) https://cs.opensource.google/jax/jax/+/main:jax/experimental/shard_map.py;drc=875f44c63a6c0f082ef3b228559e86e1d0a398d1;l=1674
2) pbroadcast to psum conversions.

The compiler should generally remove these unnecessary all-reduces (ARs). Leaving them in place blocks optimizations such as the collective pipeliner and also causes unnecessary runtime thunks to be emitted.
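
To make "degenerate" concrete, here is a minimal standalone sketch, assuming that an all-reduce counts as degenerate when every replica group contains exactly one device (as in the test added by this commit). The names IsDegenerateGrouping and CheckDegenerateGroupingExamples are hypothetical and are not part of XLA; the real AllReduceSimplifier pass may cover more cases than this.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper, not XLA's API. Returns true when every replica group
// holds exactly one device, so the reduction has nothing to combine.
bool IsDegenerateGrouping(
    const std::vector<std::vector<int64_t>>& replica_groups) {
  // In HLO an empty group list means "all devices form one group"; this
  // sketch conservatively treats that case as not degenerate.
  if (replica_groups.empty()) return false;
  for (const auto& group : replica_groups) {
    if (group.size() != 1) return false;
  }
  return true;
}

// Usage example: the first grouping mirrors the one in this commit's test.
void CheckDegenerateGroupingExamples() {
  assert(IsDegenerateGrouping({{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}));
  assert(!IsDegenerateGrouping({{0, 1}, {2, 3}}));
}
```
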
Copybara import of the project:

--
b89df55 by Yunlong Liu <yunlongl@x.ai>:

Removes degenerated all reduce.

--
055c409 by Yunlong Liu <yunlongl@x.ai>:

Creates a new e2e test target for GPU.

--
93934eb by Yunlong Liu <yunlongl@x.ai>:

E2E test passed.

Merging this change closes #18254

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18254 from yliu120:ar-simplify 93934eb
PiperOrigin-RevId: 685835366
yliu120 authored and Google-ML-Automation committed Oct 21, 2024
1 parent 7f2f6b7 commit 84d548b
Showing 2 changed files with 76 additions and 0 deletions.
14 changes: 14 additions & 0 deletions xla/service/gpu/tests/BUILD
@@ -882,6 +882,8 @@ cc_library(
    srcs = ["simple_optimization_test.cc"],
    tags = tf_cuda_tests_tags(),
    deps = [
        "//xla/service:pattern_matcher",
        "//xla/service:pattern_matcher_gmock",
        "//xla/tests:hlo_test_base",
        "//xla/tests:xla_internal_test_main",
        "//xla/tsl/lib/core:status_test_util",
@@ -971,3 +973,15 @@ xla_test(
        "@tsl//tsl/platform:test_main",
    ],
)

xla_test(
    name = "gpu_compiler_e2e_test",
    srcs = ["gpu_compiler_e2e_test.cc"],
    backends = ["gpu"],
    deps = [
        "//xla/service:pattern_matcher",
        "//xla/service:pattern_matcher_gmock",
        "//xla/tests:hlo_test_base",
        "@tsl//tsl/platform:test_main",
    ],
)
62 changes: 62 additions & 0 deletions xla/service/gpu/tests/gpu_compiler_e2e_test.cc
@@ -0,0 +1,62 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/string_view.h"
#include "xla/service/pattern_matcher.h"
#include "xla/service/pattern_matcher_gmock.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tsl/lib/core/status_test_util.h"

namespace xla {
namespace gpu {
namespace {

namespace m = match;

class GpuCompilerE2ETest : public HloTestBase {};

TEST_F(GpuCompilerE2ETest, DegeneratedAllReduceRemoval) {
  constexpr absl::string_view kHloText = R"(
    HloModule m
    sum {
      a = f32[] parameter(0)
      b = f32[] parameter(1)
      ROOT add.2 = f32[] add(a, b)
    }
    main {
      p0 = f32[8,16] parameter(0), parameter_replication={false}
      ROOT all-reduce = f32[8,16] all-reduce(p0),
        channel_id=1,
        use_global_device_ids=true,
        replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}},
        to_apply=sum
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kHloText, /*replica_count=*/1,
                                                /*num_partitions=*/8));
  module->mutable_config().set_use_spmd_partitioning(true);
  TF_ASSERT_OK_AND_ASSIGN(auto optimized_module,
                          GetOptimizedModule(std::move(module)));
  EXPECT_THAT(optimized_module->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::Parameter(0))));
}

} // namespace
} // namespace gpu
} // namespace xla
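
The test above expects the all-reduce to disappear, leaving only a copy of the parameter. The following toy model, which is not XLA code, sketches why that is plausible: with one device per replica group, a sum all-reduce returns each device's input unchanged. SimulateAllReduceSum is a hypothetical name, and each device is assumed to hold a single float rather than an f32[8,16] array.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy model, not XLA code: each device holds one float, and an
// all-reduce(sum) replaces every device's value with the sum over the
// replica group that device belongs to.
std::vector<float> SimulateAllReduceSum(
    const std::vector<float>& per_device_value,
    const std::vector<std::vector<int64_t>>& replica_groups) {
  std::vector<float> result = per_device_value;
  const int64_t num_devices = static_cast<int64_t>(per_device_value.size());
  for (const auto& group : replica_groups) {
    float sum = 0.0f;
    for (int64_t device : group) {
      assert(device >= 0 && device < num_devices);  // Device id must be valid.
      sum += per_device_value[device];
    }
    // Every member of the group receives the same reduced value.
    for (int64_t device : group) result[device] = sum;
  }
  return result;
}
```

With singleton groups such as {{0},{1},{2},{3}} the output equals the input, which is the property the simplifier relies on in this sketch; with groups such as {{0,1},{2,3}} the values are actually combined, so the all-reduce cannot be dropped.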