@@ -43,7 +43,8 @@ limitations under the License.
43
43
#include " xla/literal_util.h"
44
44
#include " xla/primitive_util.h"
45
45
#include " xla/tests/client_library_test_runner_mixin.h"
46
- #include " xla/tests/hlo_test_base.h"
46
+ #include " xla/tests/hlo_pjrt_interpreter_reference_mixin.h"
47
+ #include " xla/tests/hlo_pjrt_test_base.h"
47
48
#include " xla/tests/test_macros.h"
48
49
#include " xla/tsl/platform/statusor.h"
49
50
#include " xla/tsl/platform/test.h"
@@ -104,7 +105,8 @@ void AddNegativeValuesMaybeRemoveZero(std::vector<T>& values) {
104
105
}
105
106
106
107
class ArrayElementwiseOpTest
107
- : public ClientLibraryTestRunnerMixin<HloTestBase> {
108
+ : public ClientLibraryTestRunnerMixin<
109
+ HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
108
110
public:
109
111
static constexpr float kEpsF32 = std::numeric_limits<float >::epsilon();
110
112
static constexpr double kEpsF64 = std::numeric_limits<double >::epsilon();
@@ -1339,7 +1341,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
1339
1341
}
1340
1342
1341
1343
template <typename T>
1342
- class TotalOrderTest : public ClientLibraryTestRunnerMixin <HloTestBase> {
1344
+ class TotalOrderTest : public ClientLibraryTestRunnerMixin <
1345
+ HloPjRtInterpreterReferenceMixin<HloPjRtTestBase>> {
1343
1346
public:
1344
1347
void DoIt (ComparisonDirection direction) {
1345
1348
this ->SetFastMathDisabled (true );
0 commit comments