Skip to content

Commit

Permalink
Bfloat16 support for sycl::vec
Browse files Browse the repository at this point in the history
  • Loading branch information
cperkinsintel committed Dec 12, 2023
1 parent 8074617 commit 1e61fe3
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 50 deletions.
62 changes: 50 additions & 12 deletions sycl/include/sycl/detail/generic_type_lists.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <sycl/detail/type_list.hpp> // for type_list, address_space_list
#include <sycl/half_type.hpp> // for half

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <cstddef> // for byte, size_t
#include <type_traits> // for conditional_t, is_signed_v, is_...

Expand Down Expand Up @@ -41,6 +43,28 @@ using scalar_vector_half_list = tl_append<scalar_half_list, vector_half_list>;
using half_list =
tl_append<scalar_half_list, vector_half_list, marray_half_list>;

// bfloat16 type lists, laid out the same way as the half lists above.

// The scalar bfloat16 type only.
using scalar_bfloat16_list = type_list<sycl::ext::oneapi::bfloat16>;

// sycl::vec of bfloat16 for element counts 1, 2, 3, 4, 8 and 16.
using vector_bfloat16_list = type_list<
vec<sycl::ext::oneapi::bfloat16, 1>, vec<sycl::ext::oneapi::bfloat16, 2>,
vec<sycl::ext::oneapi::bfloat16, 3>, vec<sycl::ext::oneapi::bfloat16, 4>,
vec<sycl::ext::oneapi::bfloat16, 8>, vec<sycl::ext::oneapi::bfloat16, 16>>;

// sycl::marray of bfloat16 for element counts 1, 2, 3, 4, 8 and 16.
using marray_bfloat16_list = type_list<marray<sycl::ext::oneapi::bfloat16, 1>,
marray<sycl::ext::oneapi::bfloat16, 2>,
marray<sycl::ext::oneapi::bfloat16, 3>,
marray<sycl::ext::oneapi::bfloat16, 4>,
marray<sycl::ext::oneapi::bfloat16, 8>,
marray<sycl::ext::oneapi::bfloat16, 16>>;

// Scalar and vec bfloat16 types (no marray).
using scalar_vector_bfloat16_list =
tl_append<scalar_bfloat16_list, vector_bfloat16_list>;

// Every bfloat16 type listed above: scalar, vec and marray.
using bfloat16_list =
tl_append<scalar_bfloat16_list, vector_bfloat16_list, marray_bfloat16_list>;

// The two 16-bit scalar floating-point types: half and bfloat16.
// Only the scalar lists are combined here; vec/marray are not included.
using half_bfloat16_list = tl_append<scalar_half_list, scalar_bfloat16_list>;

using scalar_float_list = type_list<float>;

using vector_float_list =
Expand Down Expand Up @@ -73,14 +97,14 @@ using scalar_vector_double_list =
using double_list =
tl_append<scalar_double_list, vector_double_list, marray_double_list>;

using scalar_floating_list =
tl_append<scalar_float_list, scalar_double_list, scalar_half_list>;
using scalar_floating_list = tl_append<scalar_float_list, scalar_double_list,
scalar_half_list, scalar_bfloat16_list>;

using vector_floating_list =
tl_append<vector_float_list, vector_double_list, vector_half_list>;
using vector_floating_list = tl_append<vector_float_list, vector_double_list,
vector_half_list, vector_bfloat16_list>;

using marray_floating_list =
tl_append<marray_float_list, marray_double_list, marray_half_list>;
using marray_floating_list = tl_append<marray_float_list, marray_double_list,
marray_half_list, marray_bfloat16_list>;

using scalar_vector_floating_list =
tl_append<scalar_floating_list, vector_floating_list>;
Expand All @@ -91,13 +115,19 @@ using floating_list =
// geometric floating point types
using scalar_geo_half_list = type_list<half>;

// Scalar bfloat16 entry of the geometric floating point type lists.
using scalar_geo_bfloat16_list = type_list<sycl::ext::oneapi::bfloat16>;

using scalar_geo_float_list = type_list<float>;

using scalar_geo_double_list = type_list<double>;

using vector_geo_half_list =
type_list<vec<half, 1>, vec<half, 2>, vec<half, 3>, vec<half, 4>>;

// bfloat16 vectors with 1 to 4 elements, the same element counts as
// vector_geo_half_list.
using vector_geo_bfloat16_list = type_list<
vec<sycl::ext::oneapi::bfloat16, 1>, vec<sycl::ext::oneapi::bfloat16, 2>,
vec<sycl::ext::oneapi::bfloat16, 3>, vec<sycl::ext::oneapi::bfloat16, 4>>;

using vector_geo_float_list =
type_list<vec<float, 1>, vec<float, 2>, vec<float, 3>, vec<float, 4>>;

Expand All @@ -112,16 +142,21 @@ using marray_geo_double_list =

using geo_half_list = tl_append<scalar_geo_half_list, vector_geo_half_list>;

// Scalar plus vec bfloat16 geometric types (no marray), like geo_half_list.
using geo_bfloat16_list =
tl_append<scalar_geo_bfloat16_list, vector_geo_bfloat16_list>;

using geo_float_list = tl_append<scalar_geo_float_list, vector_geo_float_list>;

using geo_double_list =
tl_append<scalar_geo_double_list, vector_geo_double_list>;

using scalar_geo_list = tl_append<scalar_geo_half_list, scalar_geo_float_list,
scalar_geo_double_list>;
using scalar_geo_list =
tl_append<scalar_geo_half_list, scalar_geo_bfloat16_list,
scalar_geo_float_list, scalar_geo_double_list>;

using vector_geo_list = tl_append<vector_geo_half_list, vector_geo_float_list,
vector_geo_double_list>;
using vector_geo_list =
tl_append<vector_geo_half_list, vector_geo_bfloat16_list,
vector_geo_float_list, vector_geo_double_list>;

using marray_geo_list =
tl_append<marray_geo_float_list, marray_geo_double_list>;
Expand All @@ -131,12 +166,15 @@ using geo_list = tl_append<scalar_geo_list, vector_geo_list>;
// cross floating point types
using cross_half_list = type_list<vec<half, 3>, vec<half, 4>>;

// 3- and 4-element bfloat16 vectors, the same element counts as
// cross_half_list.
using cross_bfloat16_list = type_list<vec<sycl::ext::oneapi::bfloat16, 3>,
vec<sycl::ext::oneapi::bfloat16, 4>>;

using cross_float_list = type_list<vec<float, 3>, vec<float, 4>>;

using cross_double_list = type_list<vec<double, 3>, vec<double, 4>>;

using cross_floating_list =
tl_append<cross_float_list, cross_double_list, cross_half_list>;
using cross_floating_list = tl_append<cross_float_list, cross_double_list,
cross_half_list, cross_bfloat16_list>;

using cross_marray_list = type_list<marray<float, 3>, marray<float, 4>,
marray<double, 3>, marray<double, 4>>;
Expand Down
17 changes: 12 additions & 5 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ inline constexpr bool is_genfloath_v = is_contained_v<T, gtl::half_list>;
template <typename T>
inline constexpr bool is_half_v = is_contained_v<T, gtl::scalar_half_list>;

// True when T is contained in gtl::scalar_bfloat16_list, i.e. T is the scalar
// bfloat16 type; that list holds no vec or marray types.
template <typename T>
inline constexpr bool is_bfloat16_v =
is_contained_v<T, gtl::scalar_bfloat16_list>;

// True when T is contained in gtl::half_bfloat16_list, i.e. T is one of the
// scalar types half or bfloat16.
template <typename T>
inline constexpr bool is_half_or_bf16_v =
is_contained_v<T, gtl::half_bfloat16_list>;

template <typename T>
inline constexpr bool is_svgenfloath_v =
is_contained_v<T, gtl::scalar_vector_half_list>;
Expand Down Expand Up @@ -539,10 +547,9 @@ using select_cl_scalar_t = std::conditional_t<
std::is_integral_v<T>, select_cl_scalar_integral_t<T>,
std::conditional_t<
std::is_floating_point_v<T>, select_cl_scalar_float_t<T>,
// half is a special case: it is implemented differently on
// host and device and therefore, might lower to different
// types
std::conditional_t<is_half_v<T>,
// half and bfloat16 are special cases: they are implemented differently
// on host and device and therefore might lower to different types
std::conditional_t<is_half_or_bf16_v<T>,
sycl::detail::half_impl::BIsRepresentationT,
select_cl_scalar_complex_or_T_t<T>>>>;

Expand All @@ -559,7 +566,7 @@ struct select_cl_vector_or_scalar_or_ptr<
// select_cl_scalar_t returns _Float16, so, we try to instantiate vec
// class with _Float16 DataType, which is not expected there
// So, leave vector<half, N> and vector<bfloat16, N> as-is
vec<std::conditional_t<is_half_v<mptr_or_vec_elem_type_t<T>>,
vec<std::conditional_t<is_half_or_bf16_v<mptr_or_vec_elem_type_t<T>>,
mptr_or_vec_elem_type_t<T>,
select_cl_scalar_t<mptr_or_vec_elem_type_t<T>>>,
T::size()>;
Expand Down
63 changes: 60 additions & 3 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#pragma once

#include <sycl/aliases.hpp> // for half
#include <sycl/builtins.hpp> // for isnan
#include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
#include <sycl/half_type.hpp> // for half

Expand All @@ -22,6 +21,13 @@ __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept;

namespace sycl {
inline namespace _V1 {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
// Forward declaration of the sycl::isnan built-in for float. It is used by
// bfloat16::from_float_fallback when this macro is defined.
// NOTE(review): presumably declared here instead of including
// <sycl/builtins.hpp> to keep this header free of that dependency - confirm.
// Alternative spelling kept for reference:
// extern __DPCPP_SYCL_EXTERNAL bool isnan(float a);
bool isnan(float a);
#endif

namespace ext::oneapi {

class bfloat16;
Expand All @@ -30,9 +36,31 @@ namespace detail {
using Bfloat16StorageT = uint16_t;
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);

// sycl::vec support
// Storage types for groups of 2, 3, 4, 8 and 16 bfloat16 bit patterns
// (Bfloat16StorageT elements). No alias is defined for a single element.
namespace bf16 {
#ifdef __SYCL_DEVICE_ONLY__
// Device compilation: compiler extended-vector types (ext_vector_type).
using Vec2StorageT = Bfloat16StorageT __attribute__((ext_vector_type(2)));
using Vec3StorageT = Bfloat16StorageT __attribute__((ext_vector_type(3)));
using Vec4StorageT = Bfloat16StorageT __attribute__((ext_vector_type(4)));
using Vec8StorageT = Bfloat16StorageT __attribute__((ext_vector_type(8)));
using Vec16StorageT = Bfloat16StorageT __attribute__((ext_vector_type(16)));
#else
// Host compilation: plain std::array of the same element count.
// NOTE(review): std::array requires <array>; confirm it is included (directly
// or transitively) before this point.
using Vec2StorageT = std::array<Bfloat16StorageT, 2>;
using Vec3StorageT = std::array<Bfloat16StorageT, 3>;
using Vec4StorageT = std::array<Bfloat16StorageT, 4>;
using Vec8StorageT = std::array<Bfloat16StorageT, 8>;
using Vec16StorageT = std::array<Bfloat16StorageT, 16>;
#endif
} // namespace bf16

#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
// Returns true when x is NaN: NaN is the only float value that compares
// unequal to itself.
// Declared `inline` rather than `static`: this helper lives in a header and is
// called from bfloat16's in-class (implicitly inline) member functions. A
// `static` function would give every translation unit its own copy, so those
// inline definitions would refer to different entities per translation unit
// (an ODR violation), and unused copies would trigger -Wunused-function.
inline bool float_is_nan(float x) { return x != x; }
#endif
} // namespace detail

class bfloat16 {
protected:
detail::Bfloat16StorageT value;

friend inline detail::Bfloat16StorageT
Expand All @@ -42,13 +70,21 @@ class bfloat16 {

public:
bfloat16() = default;
bfloat16(const bfloat16 &) = default;
constexpr bfloat16(const bfloat16 &) = default;
constexpr bfloat16(bfloat16 &&) = default;
constexpr bfloat16 &operator=(const bfloat16 &rhs) = default;
~bfloat16() = default;

private:
static detail::Bfloat16StorageT from_float_fallback(const float &a) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
if (sycl::isnan(a))
return 0xffc1;
#else
if (detail::float_is_nan(a))
return 0xffc1;
#endif

union {
uint32_t intStorage;
float floatValue;
Expand Down Expand Up @@ -92,6 +128,14 @@ class bfloat16 {
#endif
}

protected:
// Grant the sycl::vec<bfloat16, N> specializations (N = 1, 2, 3, 4, 8, 16)
// access to the non-public members of bfloat16, such as the raw `value`
// storage. Friendship is not affected by the access specifier above.
// NOTE(review): sycl::vec must already be declared as a class template at this
// point - confirm against the includes.
friend class sycl::vec<bfloat16, 1>;
friend class sycl::vec<bfloat16, 2>;
friend class sycl::vec<bfloat16, 3>;
friend class sycl::vec<bfloat16, 4>;
friend class sycl::vec<bfloat16, 8>;
friend class sycl::vec<bfloat16, 16>;

public:
// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }
Expand Down Expand Up @@ -128,7 +172,7 @@ class bfloat16 {
#elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};
#else
return -to_float(lhs.value);
return bfloat16{-to_float(lhs.value)};
#endif
}

Expand Down Expand Up @@ -199,6 +243,19 @@ class bfloat16 {

// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
// for floating-point types.

// Stream Operator << and >>
// Writes rhs to the output stream O via its conversion to float and returns
// O so that insertions can be chained.
inline friend std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
const float AsFloat = static_cast<float>(rhs);
return O << AsFloat;
}

// Reads a float from the input stream I, stores it into rhs (converted to
// bfloat16 by the assignment) and returns I so that extractions can be
// chained. rhs is assigned whether or not the extraction succeeded.
inline friend std::istream &operator>>(std::istream &I, bfloat16 &rhs) {
float Parsed = 0.0f;
I >> Parsed;
rhs = Parsed;
return I;
}
};

namespace detail {
Expand Down
3 changes: 3 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
// ===--------------------------------------------------------------------=== //

#pragma once

#include <optional>

namespace sycl {
inline namespace _V1 {
namespace ext {
Expand Down
39 changes: 35 additions & 4 deletions sycl/include/sycl/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
#include <sycl/detail/item_base.hpp> // for id, range
#include <sycl/detail/owner_less_base.hpp> // for OwnerLessBase
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
#include <sycl/group.hpp> // for group
#include <sycl/h_item.hpp> // for h_item
#include <sycl/half_type.hpp> // for half, operator-, operator<
Expand Down Expand Up @@ -83,10 +84,10 @@ constexpr size_t MAX_ARRAY_SIZE =
constexpr unsigned FLUSH_BUF_OFFSET_SIZE = 2;

template <class F, class T = void>
using EnableIfFP = typename std::enable_if_t<std::is_same_v<F, float> ||
std::is_same_v<F, double> ||
std::is_same_v<F, half>,
T>;
using EnableIfFP = typename std::enable_if_t<
std::is_same_v<F, float> || std::is_same_v<F, double> ||
std::is_same_v<F, half> || std::is_same_v<F, ext::oneapi::bfloat16>,
T>;

using GlobalBufAccessorT = accessor<char, 1, sycl::access::mode::read_write,
sycl::access::target::device>;
Expand Down Expand Up @@ -346,6 +347,26 @@ checkForInfNan(char *Buf, T Val) {
return 0;
}

/// Writes the text for a bfloat16 NaN or infinity into Buf.
///
/// \param Buf destination buffer handed to append().
/// \param Val bfloat16 value to classify.
/// \return the result of append() when "nan", "inf" or "-inf" was written,
///         0 when Val is an ordinary finite value and nothing was written.
template <typename T>
inline typename std::enable_if_t<std::is_same_v<T, ext::oneapi::bfloat16>,
unsigned>
checkForInfNan(char *Buf, T Val) {
// NaN is the only value that compares unequal to itself.
if (Val != Val)
return append(Buf, "nan");

// bfloat16 bit layout: 1 sign bit, 8 exponent bits, 7 mantissa bits.
const uint16_t Bits = reinterpret_cast<uint16_t &>(Val);
// Extract the sign from the bits
const uint16_t Sign = Bits & 0x8000;
// Extract the 8-bit exponent from the bits
const uint16_t Exp16 = (Bits & 0x7f80) >> 7;

// NaN was handled above, so an all-ones exponent (0xff) can only be an
// infinity. Comparing against 0x7f (the exponent bias) would instead match
// every value with magnitude in [1, 2) and miss real infinities.
if (Exp16 == 0xff) {
if (Sign)
return append(Buf, "-inf");
return append(Buf, "inf");
}
return 0;
}

template <typename T>
EnableIfFP<T, unsigned> floatingPointToDecStr(T AbsVal, char *Digits,
int Precision, bool IsSci) {
Expand Down Expand Up @@ -1053,6 +1074,8 @@ class __SYCL_EXPORT __SYCL_SPECIAL_CLASS __SYCL_TYPE(stream) stream
friend const stream &operator<<(const stream &, const float &);
friend const stream &operator<<(const stream &, const double &);
friend const stream &operator<<(const stream &, const half &);
friend const stream &operator<<(const stream &,
const ext::oneapi::bfloat16 &);

friend const stream &operator<<(const stream &, const stream_manipulator);

Expand Down Expand Up @@ -1159,6 +1182,14 @@ inline const stream &operator<<(const stream &Out, const half &RHS) {
return Out;
}

// Streams a bfloat16 value: forwards RHS, together with the stream's flush
// buffer state and current flags/width/precision, to
// detail::writeFloatingPoint and returns the stream for chaining.
inline const stream &operator<<(const stream &Out,
const ext::oneapi::bfloat16 &RHS) {
using BF16 = ext::oneapi::bfloat16;
detail::writeFloatingPoint<BF16>(Out.GlobalFlushBuf, Out.FlushBufferSize,
Out.WIOffset, Out.get_flags(),
Out.get_width(), Out.get_precision(), RHS);
return Out;
}

// Pointer

template <typename ElementType, access::address_space Space,
Expand Down
Loading

0 comments on commit 1e61fe3

Please sign in to comment.