From 221ac9b200e08c6b692d40301073d30dfc119cd9 Mon Sep 17 00:00:00 2001 From: Yunlong Liu Date: Mon, 14 Oct 2024 14:25:20 -0700 Subject: [PATCH] PR #18254: [XLA:GPU] Simplifies All-Reduce if they can be simplified or are degenerated. Imported from GitHub PR https://github.com/openxla/xla/pull/18254 This PR adds back the AllReduceSimplifier pass to the GPU compiler as in some cases we see JAX programs will add some degenerated psum to the HLOs, Code pointer in JAX: 1) https://cs.opensource.google/jax/jax/+/main:jax/experimental/shard_map.py;drc=875f44c63a6c0f082ef3b228559e86e1d0a398d1;l=1674 2) pbroadcast to psum conversions. The compiler should generally remove these unnecessary ARs so that it unblocks many optimizations like collective pipeliner. It also emits unncessary runtime thunks as well. Copybara import of the project: -- b89df553bd260e60a472d4b69268be66fff478b6 by Yunlong Liu : Removes degenerated all reduce. -- 055c409c51e9c4ef36f19da0c049cf6de8880b34 by Yunlong Liu : Creates a new e2e test target for GPU. -- 93934eb15d4fd0571123965174f3713c19d183f5 by Yunlong Liu : E2E test passed. Merging this change closes #18254 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/18254 from yliu120:ar-simplify 93934eb15d4fd0571123965174f3713c19d183f5 PiperOrigin-RevId: 685835366 --- xla/service/gpu/tests/BUILD | 14 +++++ .../gpu/tests/gpu_compiler_e2e_test.cc | 62 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 xla/service/gpu/tests/gpu_compiler_e2e_test.cc diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 1538d88668854..6870d3350014e 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -882,6 +882,8 @@ cc_library( srcs = ["simple_optimization_test.cc"], tags = tf_cuda_tests_tags(), deps = [ + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -971,3 +973,15 @@ xla_test( "@tsl//tsl/platform:test_main", ], ) + +xla_test( + name = "gpu_compiler_e2e_test", + srcs = ["gpu_compiler_e2e_test.cc"], + backends = ["gpu"], + deps = [ + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/xla/service/gpu/tests/gpu_compiler_e2e_test.cc b/xla/service/gpu/tests/gpu_compiler_e2e_test.cc new file mode 100644 index 0000000000000..2baf42312e4a6 --- /dev/null +++ b/xla/service/gpu/tests/gpu_compiler_e2e_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace m = match; + +class GpuCompilerE2ETest : public HloTestBase {}; + +TEST_F(GpuCompilerE2ETest, DegeneratedAllReduceRemoval) { + constexpr absl::string_view kHloText = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +main { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}, + to_apply=sum +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kHloText, /*replica_count=*/1, + /*num_partitions=*/8)); + module->mutable_config().set_use_spmd_partitioning(true); + TF_ASSERT_OK_AND_ASSIGN(auto optimized_module, + GetOptimizedModule(std::move(module))); + EXPECT_THAT(optimized_module->entry_computation()->root_instruction(), + GmockMatch(m::Copy(m::Parameter(0)))); +} + +} // namespace +} // namespace gpu +} // namespace xla