@@ -61,13 +61,13 @@ class float8_base {
61
61
62
62
// Explicit converting constructor template, shown as a diff hunk.
// The removed ("-") lines accepted any T satisfying std::is_arithmetic_v<T>;
// the added ("+") lines accept only T satisfying std::is_integral_v<T>.
// In both versions the argument is cast to float, converted with ConvertFrom,
// and the resulting rep() is forwarded to the ConstructFromRepTag constructor.
template <typename T>
63
63
explicit EIGEN_DEVICE_FUNC float8_base (
64
- T f , std::enable_if_t <std::is_arithmetic_v <T>, int > = 0 )
65
- : float8_base(ConvertFrom(static_cast <float >(f )).rep(),
64
+ T i , std::enable_if_t <std::is_integral_v <T>, int > = 0 )
65
+ : float8_base(ConvertFrom(static_cast <float >(i )).rep(),
66
66
ConstructFromRepTag{}) {}
67
// Floating-point constructors, shown as a diff hunk.
// The removed ("-") lines are two explicit constructors taking double and
// float. The added ("+") lines replace them with one explicit constructor
// template enabled when std::is_floating_point_v<T> holds. It passes `f` to
// ConvertFrom as-is (no intermediate cast to float) and forwards the
// resulting rep() to the ConstructFromRepTag constructor.
- explicit EIGEN_DEVICE_FUNC float8_base ( double f64 )
68
- : float8_base(ConvertFrom( f64 ).rep(), ConstructFromRepTag{}) {}
69
- explicit EIGEN_DEVICE_FUNC float8_base ( float f32 )
70
- : float8_base(ConvertFrom(f32 ).rep(), ConstructFromRepTag{}) {}
67
+ template < typename T>
68
+ explicit EIGEN_DEVICE_FUNC float8_base (
69
+ T f, std:: enable_if_t <std::is_floating_point_v<T>, int > = 0 )
70
+ : float8_base(ConvertFrom(f ).rep(), ConstructFromRepTag{}) {}
71
71
// Explicit constructor from Eigen::bfloat16 (unchanged context in this hunk):
// converts `bf16` with ConvertFrom and forwards the resulting rep() to the
// ConstructFromRepTag constructor.
explicit EIGEN_DEVICE_FUNC float8_base (Eigen::bfloat16 bf16 )
72
72
: float8_base(ConvertFrom(bf16 ).rep(), ConstructFromRepTag{}) {}
73
73
explicit EIGEN_DEVICE_FUNC float8_base (Eigen::half f16 )
@@ -112,10 +112,10 @@ class float8_base {
112
112
113
113
// Conversions allowing saturation and truncation.
114
114
// Declaration of the static conversion From -> Derived.
// kSaturate and kTruncate both default to false; From is the source type.
// The hunk changes the parameter from `const From&` (removed line) to a
// by-value `From` (added line).
template <bool kSaturate = false , bool kTruncate = false , typename From>
115
- static inline EIGEN_DEVICE_FUNC Derived ConvertFrom (const From& from);
115
+ static inline EIGEN_DEVICE_FUNC Derived ConvertFrom (From from);
116
116
117
117
// Declaration of the static conversion Derived -> To.
// To must be named explicitly by the caller; kSaturate and kTruncate both
// default to false. The hunk changes the parameter from `const Derived&`
// (removed line) to a by-value `Derived` (added line).
template <typename To, bool kSaturate = false , bool kTruncate = false >
118
- static inline EIGEN_DEVICE_FUNC To ConvertTo (const Derived& from);
118
+ static inline EIGEN_DEVICE_FUNC To ConvertTo (Derived from);
119
119
120
120
// Operators via float32.
121
121
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Derived
@@ -634,7 +634,8 @@ struct numeric_limits_float8_e4m3fnuz : public numeric_limits_float8_base {
634
634
}
635
635
// infinity(): returns the float8_e4m3fnuz built from representation 0x80.
// The trailing "// NaN." on the removed closing brace marks that encoding as
// a NaN; the hunk only moves that comment onto its own line.
// NOTE(review): presumably this format has no infinity encoding, hence the
// NaN stand-in -- confirm against the format definition.
static constexpr float8_e4m3fnuz infinity () {
636
636
return float8_e4m3fnuz::FromRep (0x80 );
637
- } // NaN.
637
+ }
638
// quiet_NaN(): returns the float8_e4m3fnuz built from representation 0x80.
// The added "// NaN." line below is the comment relocated by this hunk from
// the closing brace of the preceding function.
+ // NaN.
638
639
static constexpr float8_e4m3fnuz quiet_NaN () {
639
640
return float8_e4m3fnuz::FromRep (0x80 );
640
641
}
@@ -1239,13 +1240,38 @@ struct ConvertImpl<float8_e5m2, Eigen::half, kSaturate, kTruncate> {
1239
1240
1240
1241
template <typename Derived>
1241
1242
template <bool kSaturate , bool kTruncate , typename From>
1242
- EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From& from) {
1243
- return ConvertImpl<From, Derived, kSaturate , kTruncate >::run (from);
1243
+ EIGEN_DEVICE_FUNC Derived float8_base<Derived>::ConvertFrom(const From from) {
1244
+ // We are rounding double/long double -> float -> float8. This can induce
1245
+ // double-rounding which may alter the results. We can correct for this using
1246
+ // a trick explained in: Boldo, Sylvie, and Guillaume Melquiond. "When double
1247
+ // rounding is odd." 17th IMACS World Congress. 2005.
1248
+ if constexpr (std::is_floating_point_v<From> &&
1249
+ sizeof (From) > sizeof (float )) {
1250
+ // binary64, float80, binary128, etc. end up here.
1251
+ static_assert (std::numeric_limits<From>::digits >=
1252
+ std::numeric_limits<float >::digits + 2 );
1253
+ static_assert (std::numeric_limits<float >::min_exponent >=
1254
+ std::numeric_limits<From>::min_exponent + 2 );
1255
+ static_assert (std::numeric_limits<float >::radix == 2 );
1256
+ float from_rnd_float = static_cast <float >(from);
1257
+
1258
+ // Round-to-odd involves us setting the LSB if we dropped any bits while
1259
+ // rounding.
1260
+ if (std::isfinite (from_rnd_float) &&
1261
+ static_cast <From>(from_rnd_float) != from) {
1262
+ from_rnd_float = Eigen::numext::bit_cast<float >(
1263
+ Eigen::numext::bit_cast<uint32_t >(from_rnd_float) | 1 );
1264
+ }
1265
+ return ConvertImpl<float , Derived, kSaturate , kTruncate >::run (
1266
+ from_rnd_float);
1267
+ } else {
1268
+ return ConvertImpl<From, Derived, kSaturate , kTruncate >::run (from);
1269
+ }
1244
1270
}
1245
1271
1246
1272
// Definition of float8_base<Derived>::ConvertTo: forwards `from` to
// ConvertImpl<Derived, To, kSaturate, kTruncate>::run and returns its result.
// The hunk changes the parameter from `const Derived&` (removed line) to a
// by-value `const Derived` (added line).
template <typename Derived>
1247
1273
template <typename To, bool kSaturate , bool kTruncate >
1248
- EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(const Derived& from) {
1274
+ EIGEN_DEVICE_FUNC To float8_base<Derived>::ConvertTo(const Derived from) {
1249
1275
return ConvertImpl<Derived, To, kSaturate , kTruncate >::run (from);
1250
1276
}
1251
1277
0 commit comments