Skip to content

Commit 96d6678

Browse files
vsytchGoogle-ML-Automation
authored andcommitted
[XLA] Googly changes
Don't mixup host {in,out}feed with other types of {in,out}feed. PiperOrigin-RevId: 727370147
1 parent 26f1a69 commit 96d6678

File tree

3 files changed

+166
-2
lines changed

3 files changed

+166
-2
lines changed

xla/hlo/transforms/collectives/BUILD

+2-1
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,13 @@ xla_cc_test(
422422
srcs = ["infeed_token_propagation_test.cc"],
423423
deps = [
424424
":infeed_token_propagation",
425+
"//xla/hlo/analysis:hlo_ordering",
425426
"//xla/hlo/ir:hlo",
426427
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
427428
"//xla/hlo/utils:hlo_matchers",
429+
"//xla/tsl/platform:statusor",
428430
"@com_google_absl//absl/strings:string_view",
429431
"@com_google_googletest//:gtest_main",
430-
"@tsl//tsl/platform:statusor",
431432
],
432433
)
433434

xla/hlo/transforms/collectives/infeed_token_propagation.cc

+8
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,12 @@ absl::Status InfeedTokenPropagation::PropagateToken(
424424
if (dangling_instruction_->opcode() != HloOpcode::kInfeed &&
425425
dangling_instruction_->opcode() != HloOpcode::kOutfeed) {
426426
for (HloInstruction* instruction : comp->instructions()) {
427+
if ((instruction->opcode() == HloOpcode::kInfeed &&
428+
!instruction->infeed_config().empty()) ||
429+
(instruction->opcode() == HloOpcode::kOutfeed &&
430+
!instruction->outfeed_config().empty())) {
431+
continue;
432+
}
427433
if (instruction->opcode() == original_opcode_) {
428434
HloInstruction* begin = ChainBegin(instruction);
429435
HloInstruction* end = ChainEnd(instruction);
@@ -497,6 +503,7 @@ absl::StatusOr<bool> InfeedTokenPropagation::Run(
497503
if (!computation->IsEntryComputation()) {
498504
for (HloInstruction* instruction : computation->instructions()) {
499505
if (instruction->opcode() == HloOpcode::kInfeed &&
506+
instruction->infeed_config().empty() &&
500507
IsDanglingInfeed(instruction)) {
501508
VLOG(1) << "Found dangling infeed: " << instruction->ToString();
502509
dangling_infeeds.push_back(instruction);
@@ -505,6 +512,7 @@ absl::StatusOr<bool> InfeedTokenPropagation::Run(
505512
}
506513
for (HloInstruction* instruction : computation->instructions()) {
507514
if (instruction->opcode() == HloOpcode::kOutfeed &&
515+
instruction->outfeed_config().empty() &&
508516
IsDanglingOutfeed(instruction)) {
509517
VLOG(1) << "Found dangling outfeed: " << instruction->ToString();
510518
dangling_outfeeds.push_back(instruction);

xla/hlo/transforms/collectives/infeed_token_propagation_test.cc

+156-1
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ limitations under the License.
2020
#include <gmock/gmock.h>
2121
#include <gtest/gtest.h>
2222
#include "absl/strings/string_view.h"
23+
#include "xla/hlo/analysis/hlo_ordering.h"
2324
#include "xla/hlo/ir/hlo_instruction.h"
2425
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
2526
#include "xla/hlo/utils/hlo_matchers.h"
26-
#include "tsl/platform/statusor.h"
27+
#include "xla/tsl/platform/statusor.h"
2728

2829
namespace op = xla::testing::opcode_matchers;
2930

@@ -824,5 +825,159 @@ ENTRY main {
824825
op::Outfeed(op::GetTupleElement(),
825826
op::GetTupleElement(op::Conditional(), 2))));
826827
}
828+
829+
TEST_F(InfeedTokenPropagationTest, ConditionalMixedInfeed) {
830+
constexpr absl::string_view hlo = R"(
831+
HloModule main
832+
833+
true_comp {
834+
arg.0 = () parameter(0)
835+
token.0 = after-all()
836+
host_infeed.0 = (s32[], token[]) infeed(token.0)
837+
ROOT tuple.0 = tuple()
838+
}
839+
840+
false_comp {
841+
arg.0 = () parameter(0)
842+
ROOT tuple.0 = tuple()
843+
}
844+
845+
ENTRY main {
846+
token.0 = after-all()
847+
core_infeed.0 = ((), token[]) infeed(token.0), infeed_config="core"
848+
pred.0 = pred[] constant(true)
849+
true_tuple.0 = tuple()
850+
false_tuple.0 = tuple()
851+
ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
852+
}
853+
)";
854+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
855+
InfeedTokenPropagation itp;
856+
TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
857+
EXPECT_TRUE(changed);
858+
859+
// The core infeed and host infeed should not be connected.
860+
DependencyHloOrdering ordering = DependencyHloOrdering(module.get());
861+
HloInstruction* core_infeed = FindInstruction(module.get(), "core_infeed.0");
862+
HloInstruction* host_infeed = FindInstruction(module.get(), "host_infeed.0");
863+
EXPECT_EQ(ordering.GetExecutionConstraint(core_infeed, host_infeed),
864+
HloOrdering::ExecutionConstraint::kUnordered);
865+
}
866+
867+
TEST_F(InfeedTokenPropagationTest, ConditionalMixedOutfeed) {
868+
constexpr absl::string_view hlo = R"(
869+
HloModule main
870+
871+
true_comp {
872+
arg.0 = s32[] parameter(0)
873+
token.0 = after-all()
874+
host_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[]
875+
ROOT tuple.0 = tuple()
876+
}
877+
878+
false_comp {
879+
arg.0 = s32[] parameter(0)
880+
ROOT tuple.0 = tuple()
881+
}
882+
883+
ENTRY main {
884+
arg.0 = s32[] parameter(0)
885+
token.0 = after-all()
886+
core_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[], outfeed_config="core"
887+
pred.0 = pred[] constant(true)
888+
ROOT cond.0 = () conditional(pred.0, arg.0, arg.0), true_computation=true_comp, false_computation=false_comp
889+
}
890+
)";
891+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
892+
InfeedTokenPropagation itp;
893+
TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
894+
EXPECT_TRUE(changed);
895+
896+
// The core outfeed and host outfeed should not be connected.
897+
DependencyHloOrdering ordering = DependencyHloOrdering(module.get());
898+
HloInstruction* core_outfeed =
899+
FindInstruction(module.get(), "core_outfeed.0");
900+
HloInstruction* host_outfeed =
901+
FindInstruction(module.get(), "host_outfeed.0");
902+
EXPECT_EQ(ordering.GetExecutionConstraint(core_outfeed, host_outfeed),
903+
HloOrdering::ExecutionConstraint::kUnordered);
904+
}
905+
906+
TEST_F(InfeedTokenPropagationTest, WhileMixedInfeed) {
907+
constexpr absl::string_view hlo = R"(
908+
HloModule main
909+
910+
comp {
911+
arg.0 = () parameter(0)
912+
token.0 = after-all()
913+
host_infeed.0 = (s32[], token[]) infeed(token.0)
914+
ROOT tuple.0 = tuple()
915+
}
916+
917+
cond {
918+
arg.0 = () parameter(0)
919+
ROOT true.0 = pred[] constant(true)
920+
}
921+
922+
ENTRY main {
923+
token.0 = after-all()
924+
core_infeed.0 = ((), token[]) infeed(token.0), infeed_config="core"
925+
while_tuple.0 = tuple()
926+
ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp
927+
}
928+
)";
929+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
930+
InfeedTokenPropagation itp;
931+
TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
932+
EXPECT_TRUE(changed);
933+
934+
// The core infeed and host infeed should not be connected.
935+
DependencyHloOrdering ordering = DependencyHloOrdering(module.get());
936+
HloInstruction* core_infeed = FindInstruction(module.get(), "core_infeed.0");
937+
HloInstruction* host_infeed = FindInstruction(module.get(), "host_infeed.0");
938+
EXPECT_EQ(ordering.GetExecutionConstraint(core_infeed, host_infeed),
939+
HloOrdering::ExecutionConstraint::kUnordered);
940+
}
941+
942+
TEST_F(InfeedTokenPropagationTest, WhileMixedOutfeed) {
943+
constexpr absl::string_view hlo = R"(
944+
HloModule main
945+
946+
comp {
947+
arg.0 = (s32[]) parameter(0)
948+
token.0 = after-all()
949+
host_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[])
950+
gte.0 = get-tuple-element(arg.0), index=0
951+
ROOT tuple.0 = tuple(gte.0)
952+
}
953+
954+
cond {
955+
arg.0 = (s32[]) parameter(0)
956+
ROOT true.0 = pred[] constant(true)
957+
}
958+
959+
ENTRY main {
960+
arg.0 = s32[] parameter(0)
961+
token.0 = after-all()
962+
core_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[], outfeed_config="core"
963+
while_tuple.0 = tuple(arg.0)
964+
ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp
965+
}
966+
)";
967+
968+
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
969+
InfeedTokenPropagation itp;
970+
TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get()));
971+
EXPECT_TRUE(changed);
972+
973+
// The core outfeed and host outfeed should not be connected.
974+
DependencyHloOrdering ordering = DependencyHloOrdering(module.get());
975+
HloInstruction* core_outfeed =
976+
FindInstruction(module.get(), "core_outfeed.0");
977+
HloInstruction* host_outfeed =
978+
FindInstruction(module.get(), "host_outfeed.0");
979+
EXPECT_EQ(ordering.GetExecutionConstraint(core_outfeed, host_outfeed),
980+
HloOrdering::ExecutionConstraint::kUnordered);
981+
}
827982
} // namespace
828983
} // namespace xla

0 commit comments

Comments
 (0)