Skip to content

Commit

Permalink
Add float8_base as friend class to FP4 and FP6 types.
Browse files Browse the repository at this point in the history
This fixes a compile error on Clang 12 and below. Clang had a bug where if a superclass instantiated an instance of a subclass with a constructor defined in the superclass and brought into the subclass with a using-declaration, and the constructor had at least two arguments, compilation would fail. See https://godbolt.org/z/aeP9sP5x5 for an example. The error would complain that the superclass does not have access to the subclass's protected members, despite the constructor being declared in the superclass itself.

The fix is to make the superclass a friend class of the subclass. This is already done in float8.h, but wasn't done in mxfloat.h.

openxla/xla#19096 was rolled back since it added an include of mxfloat.h in TensorFlow, causing an Android TensorFlow build using Android NDK r21e to fail since this NDK uses Clang 9.

PiperOrigin-RevId: 713843078
  • Loading branch information
reedwm authored and The ml_dtypes Authors committed Jan 10, 2025
1 parent f656f18 commit bd71c2a
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ml_dtypes/include/mxfloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace mxfloat_internal {
template <typename Derived>
class mxfloat6_base : public float8_internal::float8_base<Derived> {
using Base = float8_internal::float8_base<Derived>;
friend class float8_internal::float8_base<Derived>;
using Base::Base;

public:
Expand All @@ -54,6 +55,7 @@ class mxfloat6_base : public float8_internal::float8_base<Derived> {
template <typename Derived>
class mxfloat4_base : public float8_internal::float8_base<Derived> {
using Base = float8_internal::float8_base<Derived>;
friend class float8_internal::float8_base<Derived>;
using Base::Base;

public:
Expand All @@ -74,6 +76,7 @@ class float6_e2m3fn : public mxfloat6_base<float6_e2m3fn> {
// Exponent: 2, Mantissa: 3, bias: 1.
// Extended range: no inf, no NaN.
using Base = mxfloat6_base<float6_e2m3fn>;
friend class float8_internal::float8_base<float6_e2m3fn>;
using Base::Base;

public:
Expand All @@ -86,6 +89,7 @@ class float6_e3m2fn : public mxfloat6_base<float6_e3m2fn> {
// Exponent: 3, Mantissa: 2, bias: 3.
// Extended range: no inf, no NaN.
using Base = mxfloat6_base<float6_e3m2fn>;
friend class float8_internal::float8_base<float6_e3m2fn>;
using Base::Base;

public:
Expand All @@ -98,6 +102,7 @@ class float4_e2m1fn : public mxfloat4_base<float4_e2m1fn> {
// Exponent: 2, Mantissa: 1, bias: 1.
// Extended range: no inf, no NaN.
using Base = mxfloat4_base<float4_e2m1fn>;
friend class float8_internal::float8_base<float4_e2m1fn>;;
using Base::Base;

public:
Expand Down

0 comments on commit bd71c2a

Please sign in to comment.