Skip to content

Commit 0b24d47

Browse files
majnemerThe ml_dtypes Authors
authored and
The ml_dtypes Authors
committed
Remove more initialization by const-ref
While we are here, implement double/long double -> float8 conversions via conversion to float. This lets us avoid doing arithmetic using 64-bit types during the float8 rounding step. This also ensures that we can correctly round exotic types like `long double`. PiperOrigin-RevId: 577251071
1 parent 161db24 commit 0b24d47

File tree

2 files changed

+73
-13
lines changed

2 files changed

+73
-13
lines changed

ml_dtypes/include/float8.h

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,13 @@ class float8_base {
6161

6262
template <typename T>
6363
explicit EIGEN_DEVICE_FUNC float8_base(
64-
T f, std::enable_if_t<std::is_arithmetic_v<T>, int> = 0)
65-
: float8_base(ConvertFrom(static_cast<float>(f)).rep(),
64+
T i, std::enable_if_t<std::is_integral_v<T>, int> = 0)
65+
: float8_base(ConvertFrom(static_cast<float>(i)).rep(),
6666
ConstructFromRepTag{}) {}
67-
explicit EIGEN_DEVICE_FUNC float8_base(double f64)
68-
: float8_base(ConvertFrom(f64).rep(), ConstructFromRepTag{}) {}
69-
explicit EIGEN_DEVICE_FUNC float8_base(float f32)
70-
: float8_base(ConvertFrom(f32).rep(), ConstructFromRepTag{}) {}
67+
template <typename T>
68+
explicit EIGEN_DEVICE_FUNC float8_base(
69+
T f, std::enable_if_t<std::is_floating_point_v<T>, int> = 0)
70+
: float8_base(ConvertFrom(f).rep(), ConstructFromRepTag{}) {}
7171
explicit EIGEN_DEVICE_FUNC float8_base(Eigen::bfloat16 bf16)
7272
: float8_base(ConvertFrom(bf16).rep(), ConstructFromRepTag{}) {}
7373
explicit EIGEN_DEVICE_FUNC float8_base(Eigen::half f16)
@@ -112,10 +112,10 @@ class float8_base {
112112

113113
// Conversions allowing saturation and truncation.
114114
template <bool kSaturate = false, bool kTruncate = false, typename From>
115-
static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(const From& from);
115+
static inline EIGEN_DEVICE_FUNC Derived ConvertFrom(From from);
116116

117117
template <typename To, bool kSaturate = false, bool kTruncate = false>
118-
static inline EIGEN_DEVICE_FUNC To ConvertTo(const Derived& from);
118+
static inline EIGEN_DEVICE_FUNC To ConvertTo(Derived from);
119119

120120
// Operators via float32.
121121
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
@@ -634,7 +634,8 @@ struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base {
634634
}
635635
static constexpr float8_e4m3fnuz infinity() {
636636
return float8_e4m3fnuz::FromRep(0x80);
637-
} // NaN.
637+
}
638+
// NaN.
638639
static constexpr float8_e4m3fnuz quiet_NaN() {
639640
return float8_e4m3fnuz::FromRep(0x80);
640641
}
@@ -1239,13 +1240,38 @@ struct ConvertImpl<float8_e5m2, Eigen::half, kSaturate, kTruncate> {
12391240

12401241
template <typename Derived>
12411242
template <bool kSaturate, bool kTruncate, typename From>
1242-
EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From& from) {
1243-
return ConvertImpl<From, Derived, kSaturate, kTruncate>::run(from);
1243+
EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From from) {
1244+
// We are rounding double/long double -> float -> float8. This can induce
1245+
// double-rounding which may alter the results. We can correct for this using
1246+
// a trick explained in: Boldo, Sylvie, and Guillaume Melquiond. "When double
1247+
// rounding is odd." 17th IMACS World Congress. 2005.
1248+
if constexpr (std::is_floating_point_v<From> &&
1249+
sizeof(From) > sizeof(float)) {
1250+
// binary64, float80, binary128, etc. end up here.
1251+
static_assert(std::numeric_limits<From>::digits >=
1252+
std::numeric_limits<float>::digits + 2);
1253+
static_assert(std::numeric_limits<float>::min_exponent >=
1254+
std::numeric_limits<From>::min_exponent + 2);
1255+
static_assert(std::numeric_limits<float>::radix == 2);
1256+
float from_rnd_float = static_cast<float>(from);
1257+
1258+
// Round-to-odd involves us setting the LSB if we dropped any bits while
1259+
// rounding.
1260+
if (std::isfinite(from_rnd_float) &&
1261+
static_cast<From>(from_rnd_float) != from) {
1262+
from_rnd_float = Eigen::numext::bit_cast<float>(
1263+
Eigen::numext::bit_cast<uint32_t>(from_rnd_float) | 1);
1264+
}
1265+
return ConvertImpl<float, Derived, kSaturate, kTruncate>::run(
1266+
from_rnd_float);
1267+
} else {
1268+
return ConvertImpl<From, Derived, kSaturate, kTruncate>::run(from);
1269+
}
12441270
}
12451271

12461272
template <typename Derived>
12471273
template <typename To, bool kSaturate, bool kTruncate>
1248-
EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(const Derived& from) {
1274+
EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(const Derived from) {
12491275
return ConvertImpl<Derived, To, kSaturate, kTruncate>::run(from);
12501276
}
12511277

ml_dtypes/tests/float8_test.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,40 @@ TYPED_TEST(Float8Test, ConvertTo) {
407407
}
408408
}
409409

410+
template <typename SrcType, typename IntermediateType, typename Float8>
411+
static SrcType DoubleRoundHelper() {
412+
// If we have a number of the form 1.0..010..010.., two rounds of RTNE can
413+
// cause the last-set bit to get rounded down due to RTNE which in turn will
414+
// cause the other bit to get rounded down due to RTNE. RTNE's tie breaking
415+
// semantics *should* not apply here as there is no tie but double-rounding
416+
// may confuse us.
417+
SrcType x{1.0};
418+
x += std::ldexp(SrcType{1.0}, -std::numeric_limits<Float8>::digits);
419+
x += std::ldexp(SrcType{1.0}, -std::numeric_limits<IntermediateType>::digits);
420+
auto rounded_x = static_cast<Float8>(x);
421+
return static_cast<SrcType>(rounded_x);
422+
}
423+
424+
// This test tries to capture mistakes in `float8_base::ConverFrom` where it is
425+
// implemented by a series of conversions. e.g. converting a double to a float
426+
// to a float8 introduces double-rounding which makes the final rounding step
427+
// unfaithful. Craft a variety of numbers which try to detect if this happens.
428+
TYPED_TEST(Float8Test, DoubleRound) {
429+
using Float8 = TypeParam;
430+
431+
// We expect that our number results in rounding up to the number after 1.
432+
// Incorrect rounding will result in 1.
433+
const double expected =
434+
1.0 + static_cast<double>(std::numeric_limits<Float8>::epsilon());
435+
436+
// Don't use long double on targets which don't support it.
437+
#if !defined(EIGEN_USE_GPU) && !defined(EIGEN_GPU_COMPILE_PHASE)
438+
EXPECT_EQ((DoubleRoundHelper<long double, double, Float8>()), expected);
439+
EXPECT_EQ((DoubleRoundHelper<long double, float, Float8>()), expected);
440+
#endif
441+
EXPECT_EQ((DoubleRoundHelper<double, float, Float8>()), expected);
442+
}
443+
410444
TEST(Float8Test, Float8E5m2_To_Float8E4m3) {
411445
// Saturation.
412446
float8_e5m2 max = std::numeric_limits<float8_e5m2>::max();
@@ -677,7 +711,7 @@ TYPED_TEST(Float8Test, CallTheConstOperator) {
677711
}
678712
}
679713

680-
TEST(Float855m2Test, SmallCastToDenormal) {
714+
TEST(Float8E5m2Test, SmallCastToDenormal) {
681715
// Special edge-case where rounding to a normalized value would
682716
// normally round down, but rounding to a subnormal rounds up.
683717
float x = std::ldexp(1.3125, -15);

0 commit comments

Comments
 (0)