Skip to content

Commit d401ee1

Browse files
PR tensorflow#22437: Added frontend attribute handling to explicit_stream_annotation_async_wrapper
Imported from GitHub PR openxla/xla#22437. This is a small change that ensures the frontend attributes are correctly passed to both instructions of the created `async-start`/`async-done` pair. It also clears the scheduling attributes that are set directly on the call operation and on its inner ops. The specific goal of this change is to have stable support for combining scheduling group ids with stream annotations in JAX. ```python with set_xla_metadata(_scheduling_group_id=1): result = compute_on("gpu_stream:1")(jitted_func)(...) ``` Currently, the issue stems from the `set_xla_metadata` context manager, which applies the frontend attribute to all operations, including the ones within our `jitted_func`. When the same scheduling annotation is found in two `HloComputation`s, an error is raised in `LegalizeSchedulingAnnotations`. This change is intended to avoid hitting that check by cleaning up the annotations on the wrapped streamed computation. Copybara import of the project: -- 994c2eee3c946102270587681f5c17b994cbb6a9 by chaser <chaser@nvidia.com>: Added frontend attributed handling -- 9db58b2b988dc2288d42126271223f924aac19f9 by chaser <chaser@nvidia.com>: Added clearing of scheduling annotations -- a83e32a34ba5d64a29c7f01b03536f27decd8125 by chaser <chaser@nvidia.com>: Added HloInstruction.erase_frontend_attribute Merging this change closes tensorflow#22437 PiperOrigin-RevId: 731960979
1 parent 2d5f74d commit d401ee1

5 files changed

+90
-11
lines changed

third_party/xla/xla/hlo/ir/hlo_instruction.h

+4
Original file line numberDiff line numberDiff line change
@@ -1871,6 +1871,10 @@ class HloInstruction {
18711871
return it.second;
18721872
}
18731873

1874+
// Removes the frontend attribute stored under `key` from this instruction.
// Returns the number of entries erased: 1 if `key` was present, 0 otherwise.
size_t erase_frontend_attribute(const std::string& key) {
1875+
return mutable_rare()->frontend_attributes.mutable_map()->erase(key);
1876+
}
1877+
18741878
// Adds or overrides a single attribute in the HloInstruction.
18751879
void set_frontend_attribute(const std::string& key,
18761880
const std::string& value) {

third_party/xla/xla/hlo/ir/hlo_instruction_test.cc

+12
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,17 @@ TEST(HloInstruction, AddFrontendAttributes) {
6161
EXPECT_EQ(instr.get_frontend_attribute("key2").value(), "value2");
6262
}
6363

64+
// Checks that erase_frontend_attribute removes only the requested key and
// returns the number of entries it erased.
TEST(HloInstruction, EraseFrontendAttribute) {
65+
HloConstantInstruction instr(ShapeUtil::MakeShape(U32, {3, 2}));
66+
instr.add_frontend_attribute("key1", "value1");
67+
instr.add_frontend_attribute("key2", "value2");
68+
// Erasing an existing key reports one removed entry.
EXPECT_EQ(instr.erase_frontend_attribute("key2"), 1);
69+
// Erasing a key that was never added reports zero removed entries.
EXPECT_EQ(instr.erase_frontend_attribute("not_a_key"), 0);
70+
EXPECT_EQ(instr.get_frontend_attribute("key1").value(), "value1")
71+
<< "key1 should not be erased";
72+
EXPECT_EQ(instr.get_frontend_attribute("key2"), std::nullopt)
73+
<< "key2 should have been erased";
74+
}
75+
6476
} // namespace
6577
} // namespace xla

third_party/xla/xla/service/gpu/transforms/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,7 @@ xla_cc_test(
16081608
srcs = ["explicit_stream_annotation_async_wrapper_test.cc"],
16091609
deps = [
16101610
":explicit_stream_annotation_async_wrapper",
1611+
"//xla:side_effect_util",
16111612
"//xla/hlo/ir:hlo",
16121613
"//xla/hlo/testlib:filecheck",
16131614
"//xla/service/gpu:backend_configs_cc",

third_party/xla/xla/service/gpu/transforms/explicit_stream_annotation_async_wrapper.cc

+21-1
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,41 @@ limitations under the License.
3333
namespace xla::gpu {
3434

3535
namespace {
36+
37+
// Erases the scheduling-group-id and stream-annotation frontend attributes
// from `instr`. Attributes stored under any other key are left untouched.
void ClearSchedulingAnnotations(HloInstruction* instr) {
38+
// These attributes are only valid on the async pairs.
39+
instr->erase_frontend_attribute(kXlaSchedulingGroupIdAttr);
40+
instr->erase_frontend_attribute(kXlaStreamAnnotationAttr);
41+
}
42+
3643
static absl::StatusOr<bool> AsynchronizeInstruction(HloInstruction* instr) {
3744
if (instr->opcode() != HloOpcode::kCall ||
3845
!instr->frontend_attributes().map().contains(kXlaStreamAnnotationAttr)) {
3946
return false;
4047
}
4148
HloComputation* computation = instr->parent();
49+
auto original_attributes = instr->frontend_attributes();
50+
51+
// These annotations are only legal on the async instructions and
52+
// can cause issues if the annotations remain on the inner operations,
53+
// so we clear them before creating the async pair.
54+
for (auto* inner_instr : instr->called_computations()[0]->instructions()) {
55+
ClearSchedulingAnnotations(inner_instr);
56+
}
57+
ClearSchedulingAnnotations(instr);
58+
4259
TF_ASSIGN_OR_RETURN(
4360
HloInstruction * done,
4461
computation->CreateAsyncInstructions(
4562
instr, {},
4663
ExplicitStreamAnnotationAsyncWrapper::kExplicitExecutionThread,
4764
/*replace=*/true));
65+
// Replace the original attributes after creating the async pair.
66+
done->set_frontend_attributes(original_attributes);
67+
done->mutable_operand(0)->set_frontend_attributes(original_attributes);
4868
TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
4969
done->backend_config<GpuBackendConfig>());
50-
// Set the false delay of done op to be false so it can be scheduled
70+
// Set force_earliest_schedule of the done op to false so it can be scheduled
5171
// far apart from start.
5272
gpu_config.set_force_earliest_schedule(false);
5373
TF_RETURN_IF_ERROR(done->set_backend_config(gpu_config));

third_party/xla/xla/service/gpu/transforms/explicit_stream_annotation_async_wrapper_test.cc

+52-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "xla/hlo/ir/hlo_instruction.h"
2424
#include "xla/hlo/testlib/filecheck.h"
2525
#include "xla/service/gpu/backend_configs.pb.h"
26+
#include "xla/side_effect_util.h"
2627
#include "xla/tests/hlo_test_base.h"
2728
#include "xla/tsl/lib/core/status_test_util.h"
2829
#include "xla/tsl/platform/statusor.h"
@@ -72,21 +73,23 @@ TEST_F(ExplicitStreamAnnotationAsyncWrapperTest, OverlappingGemms) {
7273
%gemm1 (z: f32[2048,2048], w: f32[2048,2048]) -> f32[2048,2048] {
7374
%w = f32[2048,2048]{1,0} parameter(1)
7475
%z = f32[2048,2048]{1,0} parameter(0)
75-
%custom-call.1 = (f32[2048,2048]{1,0}, s8[33554432]{0}) custom-call(f32[2048,2048]{1,0} %w, f32[2048,2048]{1,0} %z), custom_call_target="__cublas$gemm"
76+
%custom-call.1 = (f32[2048,2048]{1,0}, s8[33554432]{0}) custom-call(f32[2048,2048]{1,0} %w, f32[2048,2048]{1,0} %z), custom_call_target="__cublas$gemm",
77+
frontend_attributes={_scheduling_group_id="0", _xla_stream_annotation="1"}
7678
ROOT %get-tuple-element = f32[2048,2048]{1,0} get-tuple-element((f32[2048,2048]{1,0}, s8[33554432]{0}) %custom-call.1), index=0
7779
}
7880
%gemm2 (a: f32[2048,2048], b: f32[2048,2048]) -> f32[2048,2048] {
7981
%a = f32[2048,2048]{1,0} parameter(1)
8082
%b = f32[2048,2048]{1,0} parameter(0)
81-
%custom-call.2 = (f32[2048,2048]{1,0}, s8[33554432]{0}) custom-call(f32[2048,2048]{1,0} %a, f32[2048,2048]{1,0} %b), custom_call_target="__cublas$gemm"
83+
%custom-call.2 = (f32[2048,2048]{1,0}, s8[33554432]{0}) custom-call(f32[2048,2048]{1,0} %a, f32[2048,2048]{1,0} %b), custom_call_target="__cublas$gemm",
84+
frontend_attributes={_scheduling_group_id="1", _xla_stream_annotation="2"}
8285
ROOT %get-tuple-element = f32[2048,2048]{1,0} get-tuple-element((f32[2048,2048]{1,0}, s8[33554432]{0}) %custom-call.2), index=0
8386
}
8487
8588
ENTRY %main () -> f32[2048,2048]{1,0} {
8689
%x = f32[2048,2048]{1,0} parameter(1), metadata={op_name="b" scheduling_name="x"}
8790
%y = f32[2048,2048]{1,0} parameter(0), metadata={op_name="a" scheduling_name="y"}
88-
%call1 = f32[2048,2048]{1,0} call(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y ), to_apply=%gemm1, frontend_attributes={_xla_stream_annotation="1"}
89-
ROOT %call2 = f32[2048,2048]{1,0} call(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), to_apply=%gemm2, frontend_attributes={_xla_stream_annotation="2"}
91+
%call1 = f32[2048,2048]{1,0} call(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y ), to_apply=%gemm1, frontend_attributes={_scheduling_group_id="0", _xla_stream_annotation="2"}
92+
ROOT %call2 = f32[2048,2048]{1,0} call(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), to_apply=%gemm2, frontend_attributes={_scheduling_group_id="1", _xla_stream_annotation="1"}
9093
})";
9194

9295
auto debug_options = HloTestBase::GetDebugOptionsForTest();
@@ -96,16 +99,55 @@ TEST_F(ExplicitStreamAnnotationAsyncWrapperTest, OverlappingGemms) {
9699
ExplicitStreamAnnotationAsyncWrapper wrapper_pass;
97100

98101
TF_ASSERT_OK_AND_ASSIGN(bool mutated, wrapper_pass.Run(module.get()));
102+
ASSERT_TRUE(mutated);
103+
99104
absl::StatusOr<bool> filecheck_result = RunFileCheck(module->ToString({}), R"(
100-
// CHECK: %call-start = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), async_execution_thread="explicit", to_apply=%gemm1, frontend_attributes={_xla_stream_annotation="1"}
101-
// CHECK: %call-done = f32[2048,2048]{1,0} call-done(((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) %call-start), frontend_attributes={_xla_stream_annotation="1"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false}
102-
// CHECK: %call-start.1 = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), async_execution_thread="explicit", to_apply=%gemm2, frontend_attributes={_xla_stream_annotation="2"}
103-
// CHECK: ROOT %call-done.1 = f32[2048,2048]{1,0} call-done(((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) %call-start.1), frontend_attributes={_xla_stream_annotation="2"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false}
105+
// CHECK: %call-start = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), async_execution_thread="explicit", to_apply=%gemm1, frontend_attributes={_scheduling_group_id="0",_xla_stream_annotation="2"}
106+
// CHECK: %call-done = f32[2048,2048]{1,0} call-done(((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) %call-start), frontend_attributes={_scheduling_group_id="0",_xla_stream_annotation="2"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false}
107+
// CHECK: %call-start.1 = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), async_execution_thread="explicit", to_apply=%gemm2, frontend_attributes={_scheduling_group_id="1",_xla_stream_annotation="1"}
108+
// CHECK: ROOT %call-done.1 = f32[2048,2048]{1,0} call-done(((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) %call-start.1), frontend_attributes={_scheduling_group_id="1",_xla_stream_annotation="1"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false}
104109
)");
105110
TF_ASSERT_OK(filecheck_result.status());
106111
EXPECT_TRUE(*filecheck_result);
107-
108-
ASSERT_TRUE(mutated);
112+
for (auto name : {"call-start", "call-done"}) {
113+
EXPECT_EQ(FindInstruction(module.get(), name)
114+
->frontend_attributes()
115+
.map()
116+
.find(kXlaStreamAnnotationAttr)
117+
->second,
118+
"2");
119+
EXPECT_EQ(FindInstruction(module.get(), name)
120+
->frontend_attributes()
121+
.map()
122+
.find(kXlaSchedulingGroupIdAttr)
123+
->second,
124+
"0");
125+
}
126+
for (auto name : {"call-start.1", "call-done.1"}) {
127+
EXPECT_EQ(FindInstruction(module.get(), name)
128+
->frontend_attributes()
129+
.map()
130+
.find(kXlaStreamAnnotationAttr)
131+
->second,
132+
"1");
133+
EXPECT_EQ(FindInstruction(module.get(), name)
134+
->frontend_attributes()
135+
.map()
136+
.find(kXlaSchedulingGroupIdAttr)
137+
->second,
138+
"1");
139+
}
140+
// Ensure the operations within the async computation are not annotated
141+
// anymore.
142+
for (auto annotation :
143+
{kXlaSchedulingGroupIdAttr, kXlaStreamAnnotationAttr}) {
144+
for (auto name : {"custom-call.1", "custom-call.2"}) {
145+
EXPECT_FALSE(FindInstruction(module.get(), name)
146+
->frontend_attributes()
147+
.map()
148+
.contains(annotation));
149+
}
150+
}
109151
}
110152
} // namespace
111153
} // namespace xla::gpu

0 commit comments

Comments
 (0)