Skip to content

Commit bf2d6e7

Browse files
nvgrwtensorflower-gardener
authored andcommitted
Migrate array_elementwise_ops_test to always use PjRt for its test backend.
PiperOrigin-RevId: 720791217
1 parent 0092ab7 commit bf2d6e7

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

third_party/xla/xla/tests/BUILD

+6-2
Original file line numberDiff line numberDiff line change
@@ -885,10 +885,14 @@ xla_test(
885885
name = "array_elementwise_ops_test",
886886
srcs = ["array_elementwise_ops_test.cc"],
887887
shard_count = 25,
888-
tags = ["test_xla_cpu_no_thunks"],
888+
tags = [
889+
"test_migrated_to_hlo_runner_pjrt",
890+
"test_xla_cpu_no_thunks",
891+
],
889892
deps = [
890893
":client_library_test_runner_mixin",
891-
":hlo_test_base",
894+
":hlo_pjrt_interpreter_reference_mixin",
895+
":hlo_pjrt_test_base",
892896
":test_macros_header",
893897
":xla_internal_test_main",
894898
"//xla:array2d",

third_party/xla/xla/tests/array_elementwise_ops_test.cc

+6-3
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ limitations under the License.
4343
#include "xla/literal_util.h"
4444
#include "xla/primitive_util.h"
4545
#include "xla/tests/client_library_test_runner_mixin.h"
46-
#include "xla/tests/hlo_test_base.h"
46+
#include "xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
47+
#include "xla/tests/hlo_pjrt_test_base.h"
4748
#include "xla/tests/test_macros.h"
4849
#include "xla/tsl/platform/statusor.h"
4950
#include "xla/tsl/platform/test.h"
@@ -104,7 +105,8 @@ void AddNegativeValuesMaybeRemoveZero(std::vector<T>& values) {
104105
}
105106

106107
class ArrayElementwiseOpTest
107-
: public ClientLibraryTestRunnerMixin<HloTestBase> {
108+
: public ClientLibraryTestRunnerMixin<
109+
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
108110
public:
109111
static constexpr float kEpsF32 = std::numeric_limits<float>::epsilon();
110112
static constexpr double kEpsF64 = std::numeric_limits<double>::epsilon();
@@ -1339,7 +1341,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
13391341
}
13401342

13411343
template <typename T>
1342-
class TotalOrderTest : public ClientLibraryTestRunnerMixin<HloTestBase> {
1344+
class TotalOrderTest : public ClientLibraryTestRunnerMixin<
1345+
HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
13431346
public:
13441347
void DoIt(ComparisonDirection direction) {
13451348
this->SetFastMathDisabled(true);

0 commit comments

Comments
 (0)