Skip to content

Commit

Permalink
PR #18152: [XLA:GPU] Avoid fusion-wrapping copies
Browse files Browse the repository at this point in the history
Imported from GitHub PR #18152

Wrapping copies in fusions breaks the logic for detecting copies inserted by copy-insertion in the rematerialization pass.

This patch avoids wrapping copy instructions and instead emits them directly in IrEmitterUnnested.

This should fix #17922
Copybara import of the project:

--
49daad1 by Jaroslav Sevcik <jsevcik@nvidia.com>:

Avoid fusion-wrapping copies

Merging this change closes #18152

COPYBARA_INTEGRATE_REVIEW=#18152 from jaro-sevcik:avoid-fusion-wrapping-copies 49daad1
PiperOrigin-RevId: 686055013
  • Loading branch information
jaro-sevcik authored and Google-ML-Automation committed Oct 15, 2024
1 parent c4c0f4a commit 8b93301
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 30 deletions.
9 changes: 7 additions & 2 deletions xla/layout_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,8 @@ absl::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
return CopyLayoutInternal(src, dst);
}

/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
const Shape& rhs) {
/* static */ bool LayoutUtil::LayoutsInShapesEqual(
const Shape& lhs, const Shape& rhs, std::optional<Layout::Equal> equal) {
if (lhs.IsTuple()) {
if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) !=
ShapeUtil::TupleElementCount(rhs)) {
Expand All @@ -647,6 +647,11 @@ absl::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
if (!lhs.has_layout() || !rhs.has_layout()) {
return false;
}

if (equal.has_value()) {
return equal.value()(lhs.layout(), rhs.layout());
}

return LayoutUtil::Equal(lhs.layout(), rhs.layout());
}
// Layouts of non-array and non-tuple shapes are ignored.
Expand Down
4 changes: 3 additions & 1 deletion xla/layout_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ class LayoutUtil {
// lhs and rhs need not be compatible to have the same layout but the two
// shapes must have the same tuple structure (if any) and arrays must have the
// same rank. Element type is ignored.
static bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs);
static bool LayoutsInShapesEqual(
const Shape& lhs, const Shape& rhs,
std::optional<Layout::Equal> equal = std::nullopt);

// Returns whether the given dimensions are consecutive in the given layout,
// not necessarily in the order given.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4469,12 +4469,15 @@ triton_dot {
cvt1 = f32[3,3,2,16]{1,3,2,0} convert(p1)
p0 = f16[9,32]{0,1} parameter(0)
b0 = f16[3,3,2,16]{1,0,3,2} bitcast(p0)
cp0 = f16[3,3,2,16]{1,3,2,0} copy(b0)
cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0)
cp0b0 = f16[2,16,3,3]{3,2,1,0} bitcast(b0)
cp0t0 = f16[3,2,16,3]{3,2,1,0} transpose(cp0b0), dimensions={2,0,1,3}
cp0b1 = f16[3,3,2,16]{1,3,2,0} bitcast(cp0t0)
cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0b1)
m = f32[3,3,2,16]{1,3,2,0} multiply(cvt1, cvt0)
cvt2 = f16[3,3,2,16]{1,3,2,0} convert(m)
cp1 = f16[3,3,2,16]{3,2,1,0} copy(cvt2)
b1 = f16[9,32]{1,0} bitcast(cp1)
cp1b0 = f16[3,2,16,3]{3,2,1,0} bitcast(cvt2)
cp1t0 = f16[3,3,2,16]{3,2,1,0} transpose(cp1b0), dimensions={0,3,1,2}
b1 = f16[9,32]{1,0} bitcast(cp1t0)
p2 = f16[32,32]{1,0} parameter(2)
ROOT r = f16[9,32]{1,0} dot(b1, p2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
Expand All @@ -4498,12 +4501,15 @@ ENTRY e {
cvt1 = f32[3,3,2,16]{1,3,2,0} convert(p1)
p0 = f16[9,32]{0,1} parameter(0)
b0 = f16[3,3,2,16]{1,0,3,2} bitcast(p0)
cp0 = f16[3,3,2,16]{1,3,2,0} copy(b0)
cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0)
cp0b0 = f16[2,16,3,3]{3,2,1,0} bitcast(b0)
cp0t0 = f16[3,2,16,3]{3,2,1,0} transpose(cp0b0), dimensions={2,0,1,3}
cp0b1 = f16[3,3,2,16]{1,3,2,0} bitcast(cp0t0)
cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0b1)
m = f32[3,3,2,16]{1,3,2,0} multiply(cvt1, cvt0)
cvt2 = f16[3,3,2,16]{1,3,2,0} convert(m)
cp1 = f16[3,3,2,16]{3,2,1,0} copy(cvt2)
b1 = f16[9,32]{1,0} bitcast(cp1)
cp1b0 = f16[3,2,16,3]{3,2,1,0} bitcast(cvt2)
cp1t0 = f16[3,3,2,16]{3,2,1,0} transpose(cp1b0), dimensions={0,3,1,2}
b1 = f16[9,32]{1,0} bitcast(cp1t0)
p2 = f16[32,32]{1,0} parameter(2)
ROOT r = f16[9,32]{1,0} dot(b1, p2),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
Expand Down
21 changes: 19 additions & 2 deletions xla/service/gpu/ir_emitter_unnested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,22 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) {
return absl::OkStatus();
}

// Emits a thunk for a standalone (non-fusion-wrapped) kCopy instruction by
// lowering it directly to a device-to-device memcpy.
absl::Status IrEmitterUnnested::EmitCopy(const HloInstruction* instr) {
// A flat memcpy is only valid when source and destination shapes share the
// same physical element order, so compare layouts minor-to-major only
// (other layout properties do not affect the byte-for-byte copy).
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
instr->operand(0)->shape(), instr->shape(),
Layout::Equal().MinorToMajorOnly()));
// Resolve the buffer slices assigned to the operand (source) and to the
// copy's own output (destination).
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer,
GetAllocationSliceForHlo(instr->operand(0)));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer,
GetAllocationSliceForHlo(instr));
// Enqueue the copy thunk; the full source slice is copied, which matches
// the destination size given the layout check above.
AddThunkToThunkSequence(std::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo::WithProfileAnnotation(instr),
/*source_buffer=*/src_buffer,
/*destination_buffer=*/dst_buffer,
/*mem_size=*/src_buffer.size()));
return absl::OkStatus();
}

absl::Status IrEmitterUnnested::EmitAsyncCustomCallStart(
const HloInstruction* instr) {
const HloInstruction* wrapped = instr->async_wrapped_instruction();
Expand Down Expand Up @@ -2773,9 +2789,10 @@ absl::Status IrEmitterUnnested::EmitHloInstruction(
}
return EmitCustomCallThunk(custom_call);
}
case HloOpcode::kFusion: {
case HloOpcode::kFusion:
return EmitFusion(Cast<HloFusionInstruction>(instr));
}
case HloOpcode::kCopy:
return EmitCopy(instr);
case HloOpcode::kInfeed:
return EmitInfeed(Cast<HloInfeedInstruction>(instr));
case HloOpcode::kOutfeed:
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/ir_emitter_unnested.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class IrEmitterUnnested : public IrEmitter {
absl::Status EmitCustomCallThunk(const HloCustomCallInstruction* instr);
absl::Status EmitFftThunk(const HloFftInstruction* instr);
absl::Status EmitFusion(const HloFusionInstruction* instr);
absl::Status EmitCopy(const HloInstruction* instr);
absl::Status EmitAsyncCustomCallStart(const HloInstruction* instr);
absl::Status EmitSelectAndScatter(
const HloSelectAndScatterInstruction* instr);
Expand Down
6 changes: 4 additions & 2 deletions xla/service/gpu/tests/sorting_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ compare {
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
ENTRY TestComputation {
x = f32[3, 2]{1, 0} parameter(0)
x.copy = f32[3, 2]{0, 1} copy(x)
ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare
tr = f32[2, 3]{1, 0} transpose(x), dimensions={1,0}
b = f32[3, 2]{0, 1} bitcast(tr)
ROOT sort = f32[3, 2]{0, 1} sort(b), dimensions={1}, to_apply=compare
}
)";
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/transforms/fusion_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ absl::StatusOr<bool> FusionWrapper::Run(
case HloOpcode::kConcatenate:
case HloOpcode::kConvolution:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kDivide:
case HloOpcode::kDot:
Expand Down
28 changes: 20 additions & 8 deletions xla/service/gpu/transforms/fusion_wrapper_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,22 @@ TEST_F(FusionWrapperTest, ControlDependency) {
// CHECK-SAME: control-predecessors={%fusion})");
}

// Verifies that FusionWrapper leaves a standalone kCopy instruction
// unwrapped (no kLoop fusion is created around it).
TEST_F(FusionWrapperTest, Copy) {
// Copies must stay unwrapped so that the rematerialization pass can still
// detect copies inserted by copy-insertion; rematerializing such a copy
// could read data that has already been overwritten.
RunAndFilecheckHloRewrite(R"(
HloModule Copy
ENTRY %main (parameter.1: f32[5]) -> f32[5] {
%parameter.1 = f32[5]{0} parameter(0)
ROOT %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
})",
FusionWrapper(),
// std::nullopt: the pass is expected to make no change to the module.
std::nullopt);
}

TEST_F(FusionWrapperTest, While) {
RunAndFilecheckHloRewrite(R"(
HloModule While
Expand All @@ -148,8 +164,8 @@ TEST_F(FusionWrapperTest, While) {
})",
FusionWrapper(), R"(
// CHECK: %wrapped_broadcast_computation {{.*}} {
// CHECK: %param_0.1 = f32[] parameter(0)
// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={}
// CHECK: %param_0 = f32[] parameter(0)
// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0), dimensions={}
// CHECK: }
// CHECK: %body {{.*}} {
// CHECK: %parameter.5 = (f32[5]{0}) parameter(0)
Expand All @@ -161,14 +177,10 @@ TEST_F(FusionWrapperTest, While) {
// CHECK: %parameter.12 = (f32[5]{0}) parameter(0)
// CHECK: ROOT %constant_1 = pred[] constant(false)
// CHECK: }
// CHECK: %wrapped_copy_computation {{.*}} {
// CHECK: %param_0 = f32[5]{0} parameter(0)
// CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0)
// CHECK: }
// CHECK: ENTRY %main {{.*}} {
// CHECK: %parameter.1 = f32[5]{0} parameter(0)
// CHECK: %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation
// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy)
// CHECK: %copy.3 = f32[5]{0} copy(%parameter.1)
// CHECK: %tuple = (f32[5]{0}) tuple(%copy.3)
// CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body
// CHECK: })");
}
Expand Down
8 changes: 8 additions & 0 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1818,7 +1818,15 @@ xla_test(
":test_macros_header",
":test_utils",
":xla_internal_test_main",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Expand Down
32 changes: 26 additions & 6 deletions xla/tests/reduce_hlo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,29 @@ limitations under the License.
==============================================================================*/

#include <array>

#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/error_spec.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/shape.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_macros.h"
#include "xla/tests/test_utils.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

// Tests the Reduce HLO in ways that can't be done using the ComputationBuilder
Expand Down Expand Up @@ -64,8 +81,9 @@ Sum {
ENTRY reduce.1 {
parameter = f32[2,2,2,3]{3,2,1,0} parameter(0)
init_value = f32[] constant(0)
reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, to_apply=Sum
ROOT copy = f32[2,2,3]{2,1,0} copy(reduce)
reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, to_apply=Sum
transpose = f32[2,2,3]{2,1,0} transpose(reduce), dimensions={0,1,2}
ROOT bitcast = f32[2,2,3]{2,1,0} bitcast(transpose)
}
)";

Expand All @@ -79,8 +97,10 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_TPU(Reduce)) {
}

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, GetParsedModule());
HloInstruction* reduce_instruction =
module->entry_computation()->root_instruction()->mutable_operand(0);
HloInstruction* reduce_instruction = module->entry_computation()
->root_instruction()
->mutable_operand(0)
->mutable_operand(0);
ASSERT_EQ(reduce_instruction->opcode(), HloOpcode::kReduce);

const ReduceLayout& reduce_layout = GetParam();
Expand Down Expand Up @@ -110,7 +130,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_TPU(Reduce)) {
{-0.241772294, -0.245131493, -0.160247207},
{-0.179881215, -0.23383224, -0.121976733}}}});

auto reduce_input_relaid =
Literal reduce_input_relaid =
reduce_input.Relayout(reduce_input_shape->layout());
EXPECT_TRUE(RunAndCompareNoHloPasses(
std::move(module), {&reduce_input_relaid}, ErrorSpec(1e-5)));
Expand Down

0 comments on commit 8b93301

Please sign in to comment.