Skip to content

Commit

Permalink
Add support for FP8 types to reshape_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686086143
  • Loading branch information
Google-ML-Automation committed Oct 15, 2024
1 parent 0d637e4 commit b46d80b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 19 deletions.
9 changes: 7 additions & 2 deletions xla/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2142,20 +2142,25 @@ xla_test(
":test_macros_header",
":xla_internal_test_main",
"//xla:array2d",
"//xla:array3d",
"//xla:array4d",
"//xla:error_spec",
"//xla:literal",
"//xla:literal_util",
"//xla:reference_util",
"//xla:shape_util",
"//xla:status_macros",
"//xla:test",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/client:global_data",
"//xla/client:local_client",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
57 changes: 40 additions & 17 deletions xla/tests/reshape_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,46 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <numeric>
#include <random>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/array2d.h"
#include "xla/array3d.h"
#include "xla/array4d.h"
#include "xla/client/global_data.h"
#include "xla/client/local_client.h"
#include "xla/error_spec.h"
#include "xla/hlo/builder/xla_builder.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/layout_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/reference_util.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/test.h"
#include "xla/tests/client_library_test_base.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/test_macros.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/ml_dtypes.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

// Use a bool parameter to indicate whether to use bfloat16.
class ReshapeTest : public ::testing::WithParamInterface<bool>,
class ReshapeTest : public ::testing::WithParamInterface<PrimitiveType>,
public ClientLibraryTestBase {
public:
ReshapeTest() { set_use_bfloat16(GetParam()); }
ReshapeTest() { set_float_type(GetParam()); }

ErrorSpec zero_error_spec_{0.0};
};
Expand Down Expand Up @@ -652,16 +657,15 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
XlaComputation computation = builder.Build().value();
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithDenseLayout(use_bfloat16() ? BF16 : F32, {2, 8},
{1, 0})
ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {2, 8}, {1, 0})
.ToProto();
Literal actual =
client_
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.value();
Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
expected = LiteralUtil::ConvertF32ToBF16(expected);
if (FloatType() != F32) {
expected = MaybeConvertLiteralToTestType(expected);
}
EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
Expand Down Expand Up @@ -808,8 +812,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {

ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithDenseLayout(use_bfloat16() ? BF16 : F32,
{7, 2, 3, 5}, {2, 3, 0, 1})
ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {7, 2, 3, 5},
{2, 3, 0, 1})
.ToProto();
Literal output_literal =
client_
Expand All @@ -819,11 +823,29 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {

// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
auto expected = LiteralUtil::ConvertF32ToBF16(input_literal);
EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
} else {
EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
switch (FloatType()) {
case F32:
EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
break;
case BF16: {
auto expected = MaybeConvertLiteralToTestType(input_literal);
EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
break;
}
case F8E4M3FN: {
auto expected = MaybeConvertLiteralToTestType(input_literal);
EXPECT_EQ(expected.data<tsl::float8_e4m3fn>(),
output_literal.data<tsl::float8_e4m3fn>());
break;
}
case F8E5M2: {
auto expected = MaybeConvertLiteralToTestType(input_literal);
EXPECT_EQ(expected.data<tsl::float8_e5m2>(),
output_literal.data<tsl::float8_e5m2>());
break;
}
default:
LOG(FATAL) << "Unsupported float type: " << FloatType();
}
}

Expand Down Expand Up @@ -1017,7 +1039,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
zero_error_spec_, &expected.shape());
}

INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool());
INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest,
::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN}));

using ReshapeHloTest = HloTestBase;

Expand Down

0 comments on commit b46d80b

Please sign in to comment.