@@ -20,10 +20,11 @@ limitations under the License.
20
20
#include < gmock/gmock.h>
21
21
#include < gtest/gtest.h>
22
22
#include " absl/strings/string_view.h"
23
+ #include " xla/hlo/analysis/hlo_ordering.h"
23
24
#include " xla/hlo/ir/hlo_instruction.h"
24
25
#include " xla/hlo/testlib/hlo_hardware_independent_test_base.h"
25
26
#include " xla/hlo/utils/hlo_matchers.h"
26
- #include " tsl/platform/statusor.h"
27
+ #include " xla/ tsl/platform/statusor.h"
27
28
28
29
namespace op = xla::testing::opcode_matchers;
29
30
@@ -824,5 +825,159 @@ ENTRY main {
824
825
op::Outfeed (op::GetTupleElement (),
825
826
op::GetTupleElement (op::Conditional (), 2 ))));
826
827
}
828
+
829
+ TEST_F (InfeedTokenPropagationTest, ConditionalMixedInfeed) {
830
+ constexpr absl::string_view hlo = R"(
831
+ HloModule main
832
+
833
+ true_comp {
834
+ arg.0 = () parameter(0)
835
+ token.0 = after-all()
836
+ host_infeed.0 = (s32[], token[]) infeed(token.0)
837
+ ROOT tuple.0 = tuple()
838
+ }
839
+
840
+ false_comp {
841
+ arg.0 = () parameter(0)
842
+ ROOT tuple.0 = tuple()
843
+ }
844
+
845
+ ENTRY main {
846
+ token.0 = after-all()
847
+ core_infeed.0 = ((), token[]) infeed(token.0), infeed_config="core"
848
+ pred.0 = pred[] constant(true)
849
+ true_tuple.0 = tuple()
850
+ false_tuple.0 = tuple()
851
+ ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp
852
+ }
853
+ )" ;
854
+ TF_ASSERT_OK_AND_ASSIGN (auto module, ParseAndReturnVerifiedModule (hlo));
855
+ InfeedTokenPropagation itp;
856
+ TF_ASSERT_OK_AND_ASSIGN (bool changed, itp.Run (module.get ()));
857
+ EXPECT_TRUE (changed);
858
+
859
+ // The core infeed and host infeed should not be connected.
860
+ DependencyHloOrdering ordering = DependencyHloOrdering (module.get ());
861
+ HloInstruction* core_infeed = FindInstruction (module.get (), " core_infeed.0" );
862
+ HloInstruction* host_infeed = FindInstruction (module.get (), " host_infeed.0" );
863
+ EXPECT_EQ (ordering.GetExecutionConstraint (core_infeed, host_infeed),
864
+ HloOrdering::ExecutionConstraint::kUnordered );
865
+ }
866
+
867
+ TEST_F (InfeedTokenPropagationTest, ConditionalMixedOutfeed) {
868
+ constexpr absl::string_view hlo = R"(
869
+ HloModule main
870
+
871
+ true_comp {
872
+ arg.0 = s32[] parameter(0)
873
+ token.0 = after-all()
874
+ host_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[]
875
+ ROOT tuple.0 = tuple()
876
+ }
877
+
878
+ false_comp {
879
+ arg.0 = s32[] parameter(0)
880
+ ROOT tuple.0 = tuple()
881
+ }
882
+
883
+ ENTRY main {
884
+ arg.0 = s32[] parameter(0)
885
+ token.0 = after-all()
886
+ core_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[], outfeed_config="core"
887
+ pred.0 = pred[] constant(true)
888
+ ROOT cond.0 = () conditional(pred.0, arg.0, arg.0), true_computation=true_comp, false_computation=false_comp
889
+ }
890
+ )" ;
891
+ TF_ASSERT_OK_AND_ASSIGN (auto module, ParseAndReturnVerifiedModule (hlo));
892
+ InfeedTokenPropagation itp;
893
+ TF_ASSERT_OK_AND_ASSIGN (bool changed, itp.Run (module.get ()));
894
+ EXPECT_TRUE (changed);
895
+
896
+ // The core outfeed and host outfeed should not be connected.
897
+ DependencyHloOrdering ordering = DependencyHloOrdering (module.get ());
898
+ HloInstruction* core_outfeed =
899
+ FindInstruction (module.get (), " core_outfeed.0" );
900
+ HloInstruction* host_outfeed =
901
+ FindInstruction (module.get (), " host_outfeed.0" );
902
+ EXPECT_EQ (ordering.GetExecutionConstraint (core_outfeed, host_outfeed),
903
+ HloOrdering::ExecutionConstraint::kUnordered );
904
+ }
905
+
906
+ TEST_F (InfeedTokenPropagationTest, WhileMixedInfeed) {
907
+ constexpr absl::string_view hlo = R"(
908
+ HloModule main
909
+
910
+ comp {
911
+ arg.0 = () parameter(0)
912
+ token.0 = after-all()
913
+ host_infeed.0 = (s32[], token[]) infeed(token.0)
914
+ ROOT tuple.0 = tuple()
915
+ }
916
+
917
+ cond {
918
+ arg.0 = () parameter(0)
919
+ ROOT true.0 = pred[] constant(true)
920
+ }
921
+
922
+ ENTRY main {
923
+ token.0 = after-all()
924
+ core_infeed.0 = ((), token[]) infeed(token.0), infeed_config="core"
925
+ while_tuple.0 = tuple()
926
+ ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp
927
+ }
928
+ )" ;
929
+ TF_ASSERT_OK_AND_ASSIGN (auto module, ParseAndReturnVerifiedModule (hlo));
930
+ InfeedTokenPropagation itp;
931
+ TF_ASSERT_OK_AND_ASSIGN (bool changed, itp.Run (module.get ()));
932
+ EXPECT_TRUE (changed);
933
+
934
+ // The core infeed and host infeed should not be connected.
935
+ DependencyHloOrdering ordering = DependencyHloOrdering (module.get ());
936
+ HloInstruction* core_infeed = FindInstruction (module.get (), " core_infeed.0" );
937
+ HloInstruction* host_infeed = FindInstruction (module.get (), " host_infeed.0" );
938
+ EXPECT_EQ (ordering.GetExecutionConstraint (core_infeed, host_infeed),
939
+ HloOrdering::ExecutionConstraint::kUnordered );
940
+ }
941
+
942
+ TEST_F (InfeedTokenPropagationTest, WhileMixedOutfeed) {
943
+ constexpr absl::string_view hlo = R"(
944
+ HloModule main
945
+
946
+ comp {
947
+ arg.0 = (s32[]) parameter(0)
948
+ token.0 = after-all()
949
+ host_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[])
950
+ gte.0 = get-tuple-element(arg.0), index=0
951
+ ROOT tuple.0 = tuple(gte.0)
952
+ }
953
+
954
+ cond {
955
+ arg.0 = (s32[]) parameter(0)
956
+ ROOT true.0 = pred[] constant(true)
957
+ }
958
+
959
+ ENTRY main {
960
+ arg.0 = s32[] parameter(0)
961
+ token.0 = after-all()
962
+ core_outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=s32[], outfeed_config="core"
963
+ while_tuple.0 = tuple(arg.0)
964
+ ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp
965
+ }
966
+ )" ;
967
+
968
+ TF_ASSERT_OK_AND_ASSIGN (auto module, ParseAndReturnVerifiedModule (hlo));
969
+ InfeedTokenPropagation itp;
970
+ TF_ASSERT_OK_AND_ASSIGN (bool changed, itp.Run (module.get ()));
971
+ EXPECT_TRUE (changed);
972
+
973
+ // The core outfeed and host outfeed should not be connected.
974
+ DependencyHloOrdering ordering = DependencyHloOrdering (module.get ());
975
+ HloInstruction* core_outfeed =
976
+ FindInstruction (module.get (), " core_outfeed.0" );
977
+ HloInstruction* host_outfeed =
978
+ FindInstruction (module.get (), " host_outfeed.0" );
979
+ EXPECT_EQ (ordering.GetExecutionConstraint (core_outfeed, host_outfeed),
980
+ HloOrdering::ExecutionConstraint::kUnordered );
981
+ }
827
982
} // namespace
828
983
} // namespace xla
0 commit comments