Skip to content

Commit 0b31741

Browse files
abhigunjtensorflower-gardener
authored andcommitted
[XLA:CPU] Remove unneeded MHLO dependencies from XLA CPU compiler
PiperOrigin-RevId: 727123960
1 parent 87c216a commit 0b31741

File tree

3 files changed

+18
-22
lines changed

3 files changed

+18
-22
lines changed

third_party/xla/xla/backends/cpu/nanort/ifrt_client_test.cc

+18-18
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ TEST(NanoIfrtClientTest, BigResult) {
5757
// A program that is likely to need some temporary buffers to be allocated.
5858
absl::string_view kBigResult =
5959
R"(module {
60-
func.func @main(%arg: tensor<f32>) -> tensor<1024x1024xf32> {
61-
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[1024, 1024]> : tensor<2xi64>} : (tensor<f32>) -> tensor<1024x1024xf32>
62-
%1 = "mhlo.add"(%0, %0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
63-
%2 = "mhlo.dot"(%1, %1) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
64-
return %2 : tensor<1024x1024xf32>
65-
}
66-
})";
60+
func.func @main(%arg0: tensor<f32>) -> tensor<1024x1024xf32> {
61+
%0 = stablehlo.broadcast %arg0, sizes = [1024, 1024] : (tensor<f32>) -> tensor<1024x1024xf32>
62+
%1 = stablehlo.add %0, %0 : tensor<1024x1024xf32>
63+
%2 = stablehlo.dot %1, %1 : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
64+
return %2 : tensor<1024x1024xf32>
65+
}
66+
})";
6767
auto client = NanoIfrtClient::Create();
6868
auto compiler = client->GetDefaultCompiler();
6969

@@ -142,7 +142,7 @@ static void BM_IfRtAddScalars(benchmark::State& state) {
142142
constexpr absl::string_view program =
143143
R"(module {
144144
func.func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
145-
%0 = mhlo.add %arg0, %arg1 : tensor<f32>
145+
%0 = stablehlo.add %arg0, %arg1 : tensor<f32>
146146
return %0 : tensor<f32>
147147
}
148148
})";
@@ -185,16 +185,16 @@ static void BM_IfRtAddManyScalars(benchmark::State& state) {
185185
-> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>,
186186
tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
187187
{
188-
%0 = mhlo.add %arg0, %arg1 : tensor<f32>
189-
%1 = mhlo.add %arg0, %0 : tensor<f32>
190-
%2 = mhlo.add %arg0, %1 : tensor<f32>
191-
%3 = mhlo.add %arg0, %2 : tensor<f32>
192-
%4 = mhlo.add %arg0, %3 : tensor<f32>
193-
%5 = mhlo.add %arg0, %4 : tensor<f32>
194-
%6 = mhlo.add %arg0, %5 : tensor<f32>
195-
%7 = mhlo.add %arg0, %6 : tensor<f32>
196-
%8 = mhlo.add %arg0, %7 : tensor<f32>
197-
%9 = mhlo.add %arg0, %8 : tensor<f32>
188+
%0 = stablehlo.add %arg0, %arg1 : tensor<f32>
189+
%1 = stablehlo.add %arg0, %0 : tensor<f32>
190+
%2 = stablehlo.add %arg0, %1 : tensor<f32>
191+
%3 = stablehlo.add %arg0, %2 : tensor<f32>
192+
%4 = stablehlo.add %arg0, %3 : tensor<f32>
193+
%5 = stablehlo.add %arg0, %4 : tensor<f32>
194+
%6 = stablehlo.add %arg0, %5 : tensor<f32>
195+
%7 = stablehlo.add %arg0, %6 : tensor<f32>
196+
%8 = stablehlo.add %arg0, %7 : tensor<f32>
197+
%9 = stablehlo.add %arg0, %8 : tensor<f32>
198198
return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
199199
: tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>,
200200
tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>

third_party/xla/xla/service/cpu/BUILD

-3
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,8 @@ cc_library(
292292
"//xla/hlo/transforms/simplifiers:tree_reduction_rewriter",
293293
"//xla/hlo/transforms/simplifiers:tuple_simplifier",
294294
"//xla/hlo/transforms/simplifiers:zero_sized_hlo_elimination",
295-
"//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
296-
"//xla/hlo/translate/hlo_to_mhlo:hlo_utils",
297295
"//xla/mlir_hlo",
298296
"//xla/mlir_hlo:all_passes",
299-
"//xla/mlir_hlo:mhlo_passes",
300297
"//xla/mlir_hlo:transforms_passes",
301298
"//xla/service:all_reduce_promotion",
302299
"//xla/service:all_to_all_decomposer",

third_party/xla/xla/service/cpu/cpu_compiler.cc

-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ limitations under the License.
139139
#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h"
140140
#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h"
141141
#include "xla/hlo/transforms/while_loop_trip_count_annotator.h"
142-
#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
143142
#include "xla/literal.h"
144143
#include "xla/literal_pool.h"
145144
#include "xla/map_util.h"

0 commit comments

Comments
 (0)