Skip to content

Commit

Permalink
[SYCL] add preliminary support for bfloat16 to sycl::vec (#12261)
Browse files Browse the repository at this point in the history
This does not yet add support for the math builtins for
`sycl::vec<bfloat16>`. That will come later.
  • Loading branch information
cperkinsintel authored Jan 25, 2024
1 parent ba6ce4d commit bbbe883
Show file tree
Hide file tree
Showing 13 changed files with 500 additions and 63 deletions.
9 changes: 5 additions & 4 deletions sycl/include/sycl/bit_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,22 @@
#if __cpp_lib_bit_cast
// first choice std::bit_cast
#include <bit>
#define __SYCL_BC_CONSTEXPR constexpr
#define __SYCL_BITCAST_IS_CONSTEXPR 1
#elif __SYCL_HAS_BUILTIN_BIT_CAST
// second choice __builtin_bit_cast
#define __SYCL_BC_CONSTEXPR constexpr
#define __SYCL_BITCAST_IS_CONSTEXPR 1
#else
// fallback memcpy
#include <sycl/detail/memcpy.hpp>
#define __SYCL_BC_CONSTEXPR
#endif

namespace sycl {
inline namespace _V1 {

template <typename To, typename From>
__SYCL_BC_CONSTEXPR
#if defined(__SYCL_BITCAST_IS_CONSTEXPR)
constexpr
#endif
std::enable_if_t<sizeof(To) == sizeof(From) &&
std::is_trivially_copyable<From>::value &&
std::is_trivially_copyable<To>::value,
Expand Down
40 changes: 36 additions & 4 deletions sycl/include/sycl/detail/generic_type_lists.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <sycl/detail/type_list.hpp> // for type_list, address_space_list
#include <sycl/half_type.hpp> // for half

#include <sycl/ext/oneapi/bfloat16.hpp> // bfloat16

#include <cstddef> // for byte, size_t
#include <type_traits> // for conditional_t, is_signed_v, is_...

Expand Down Expand Up @@ -41,6 +43,28 @@ using scalar_vector_half_list = tl_append<scalar_half_list, vector_half_list>;
using half_list =
tl_append<scalar_half_list, vector_half_list, marray_half_list>;

// bfloat16 type lists, parallel to the half lists declared above.

// The scalar bfloat16 type only.
using scalar_bfloat16_list = type_list<sycl::ext::oneapi::bfloat16>;

// sycl::vec of bfloat16 for element counts 1, 2, 3, 4, 8 and 16.
using vector_bfloat16_list = type_list<
vec<sycl::ext::oneapi::bfloat16, 1>, vec<sycl::ext::oneapi::bfloat16, 2>,
vec<sycl::ext::oneapi::bfloat16, 3>, vec<sycl::ext::oneapi::bfloat16, 4>,
vec<sycl::ext::oneapi::bfloat16, 8>, vec<sycl::ext::oneapi::bfloat16, 16>>;

// sycl::marray of bfloat16 for the same element counts as the vec list
// (1, 2, 3, 4, 8 and 16); other marray sizes are not listed here.
using marray_bfloat16_list = type_list<marray<sycl::ext::oneapi::bfloat16, 1>,
marray<sycl::ext::oneapi::bfloat16, 2>,
marray<sycl::ext::oneapi::bfloat16, 3>,
marray<sycl::ext::oneapi::bfloat16, 4>,
marray<sycl::ext::oneapi::bfloat16, 8>,
marray<sycl::ext::oneapi::bfloat16, 16>>;

// Scalar bfloat16 plus every vec<bfloat16, N> listed above.
using scalar_vector_bfloat16_list =
tl_append<scalar_bfloat16_list, vector_bfloat16_list>;

// Every bfloat16 form: scalar, vec and marray.
using bfloat16_list =
tl_append<scalar_bfloat16_list, vector_bfloat16_list, marray_bfloat16_list>;

// The two scalar 16-bit floating-point types: half and bfloat16.
// Only the scalar lists are appended; vec and marray forms are not included.
using half_bfloat16_list = tl_append<scalar_half_list, scalar_bfloat16_list>;

using scalar_float_list = type_list<float>;

using vector_float_list =
Expand Down Expand Up @@ -73,14 +97,22 @@ using scalar_vector_double_list =
using double_list =
tl_append<scalar_double_list, vector_double_list, marray_double_list>;

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
using scalar_floating_list = tl_append<scalar_float_list, scalar_double_list,
scalar_half_list, scalar_bfloat16_list>;
#else
// Presently, this is used only by builtins_legacy_scalar.hpp for defining math
// funcs. bfloat16 provides its own scalar math definitions so we skip its
// inclusion.
using scalar_floating_list =
tl_append<scalar_float_list, scalar_double_list, scalar_half_list>;
#endif

using vector_floating_list =
tl_append<vector_float_list, vector_double_list, vector_half_list>;
using vector_floating_list = tl_append<vector_float_list, vector_double_list,
vector_half_list, vector_bfloat16_list>;

using marray_floating_list =
tl_append<marray_float_list, marray_double_list, marray_half_list>;
using marray_floating_list = tl_append<marray_float_list, marray_double_list,
marray_half_list, marray_bfloat16_list>;

using scalar_vector_floating_list =
tl_append<scalar_floating_list, vector_floating_list>;
Expand Down
22 changes: 15 additions & 7 deletions sycl/include/sycl/detail/generic_type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ inline constexpr bool is_genfloath_v = is_contained_v<T, gtl::half_list>;
template <typename T>
inline constexpr bool is_half_v = is_contained_v<T, gtl::scalar_half_list>;

// True when T is found in gtl::scalar_bfloat16_list.
template <typename T>
inline constexpr bool is_bfloat16_v =
is_contained_v<T, gtl::scalar_bfloat16_list>;

// True when T is found in gtl::half_bfloat16_list.
// NOTE(review): presumably that list holds scalar half and scalar bfloat16
// only -- confirm against generic_type_lists.hpp.
template <typename T>
inline constexpr bool is_half_or_bf16_v =
is_contained_v<T, gtl::half_bfloat16_list>;

template <typename T>
inline constexpr bool is_svgenfloath_v =
is_contained_v<T, gtl::scalar_vector_half_list>;
Expand Down Expand Up @@ -539,12 +547,12 @@ using select_cl_scalar_t = std::conditional_t<
std::is_integral_v<T>, select_cl_scalar_integral_t<T>,
std::conditional_t<
std::is_floating_point_v<T>, select_cl_scalar_float_t<T>,
// half is a special case: it is implemented differently on
// host and device and therefore, might lower to different
// types
std::conditional_t<is_half_v<T>,
sycl::detail::half_impl::BIsRepresentationT,
select_cl_scalar_complex_or_T_t<T>>>>;
// half and bfloat16 are special cases: they are implemented differently
// on host and device and therefore might lower to different types
std::conditional_t<
is_half_v<T>, sycl::detail::half_impl::BIsRepresentationT,
std::conditional_t<is_bfloat16_v<T>, T,
select_cl_scalar_complex_or_T_t<T>>>>>;

// select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type
// of a vector type T, pointer type substitution, and scalar type substitution.
Expand All @@ -556,7 +564,7 @@ template <typename T>
struct select_cl_vector_or_scalar_or_ptr<
T, typename std::enable_if_t<is_vgentype_v<T>>> {
using type =
// select_cl_scalar_t returns _Float16, so, we try to instantiate vec
// select_cl_scalar_t may return _Float16, so, we try to instantiate vec
// class with _Float16 DataType, which is not expected there
// So, leave vector<half, N> as-is
vec<std::conditional_t<is_half_v<mptr_or_vec_elem_type_t<T>>,
Expand Down
63 changes: 60 additions & 3 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#pragma once

#include <sycl/aliases.hpp> // for half
#include <sycl/builtins.hpp> // for isnan
#include <sycl/detail/defines_elementary.hpp> // for __DPCPP_SYCL_EXTERNAL
#include <sycl/half_type.hpp> // for half

Expand All @@ -22,6 +21,13 @@ __devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept;

namespace sycl {
inline namespace _V1 {

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
// forward declaration of sycl::isnan built-in.
// extern __DPCPP_SYCL_EXTERNAL bool isnan(float a);
bool isnan(float a);
#endif

namespace ext::oneapi {

class bfloat16;
Expand All @@ -30,9 +36,31 @@ namespace detail {
using Bfloat16StorageT = uint16_t;
Bfloat16StorageT bfloat16ToBits(const bfloat16 &Value);
bfloat16 bitsToBfloat16(const Bfloat16StorageT Value);

// sycl::vec support
// Storage types for 2-, 3-, 4-, 8- and 16-element groups of
// Bfloat16StorageT. The representation depends on the compilation pass:
//  - when __SYCL_DEVICE_ONLY__ is defined, each alias is a compiler vector
//    type declared with the ext_vector_type attribute;
//  - otherwise each alias is a std::array of the same element count.
namespace bf16 {
#ifdef __SYCL_DEVICE_ONLY__
using Vec2StorageT = Bfloat16StorageT __attribute__((ext_vector_type(2)));
using Vec3StorageT = Bfloat16StorageT __attribute__((ext_vector_type(3)));
using Vec4StorageT = Bfloat16StorageT __attribute__((ext_vector_type(4)));
using Vec8StorageT = Bfloat16StorageT __attribute__((ext_vector_type(8)));
using Vec16StorageT = Bfloat16StorageT __attribute__((ext_vector_type(16)));
#else
// NOTE(review): std::array requires <array>; this view does not show whether
// it is included directly or only transitively -- confirm.
using Vec2StorageT = std::array<Bfloat16StorageT, 2>;
using Vec3StorageT = std::array<Bfloat16StorageT, 3>;
using Vec4StorageT = std::array<Bfloat16StorageT, 4>;
using Vec8StorageT = std::array<Bfloat16StorageT, 8>;
using Vec16StorageT = std::array<Bfloat16StorageT, 16>;
#endif
} // namespace bf16

#ifndef __INTEL_PREVIEW_BREAKING_CHANGES
// Reports whether x is NaN without calling any math built-in: NaN is the
// only float value that does not compare equal to itself.
inline bool float_is_nan(float x) {
  const bool EqualsItself = (x == x);
  return !EqualsItself;
}
#endif
} // namespace detail

class bfloat16 {
protected:
detail::Bfloat16StorageT value;

friend inline detail::Bfloat16StorageT
Expand All @@ -42,13 +70,21 @@ class bfloat16 {

public:
bfloat16() = default;
bfloat16(const bfloat16 &) = default;
constexpr bfloat16(const bfloat16 &) = default;
constexpr bfloat16(bfloat16 &&) = default;
constexpr bfloat16 &operator=(const bfloat16 &rhs) = default;
~bfloat16() = default;

private:
static detail::Bfloat16StorageT from_float_fallback(const float &a) {
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
if (sycl::isnan(a))
return 0xffc1;
#else
if (detail::float_is_nan(a))
return 0xffc1;
#endif

union {
uint32_t intStorage;
float floatValue;
Expand Down Expand Up @@ -92,6 +128,14 @@ class bfloat16 {
#endif
}

protected:
friend class sycl::vec<bfloat16, 1>;
friend class sycl::vec<bfloat16, 2>;
friend class sycl::vec<bfloat16, 3>;
friend class sycl::vec<bfloat16, 4>;
friend class sycl::vec<bfloat16, 8>;
friend class sycl::vec<bfloat16, 16>;

public:
// Implicit conversion from float to bfloat16
bfloat16(const float &a) { value = from_float(a); }
Expand Down Expand Up @@ -128,7 +172,7 @@ class bfloat16 {
#elif defined(__SYCL_DEVICE_ONLY__) && defined(__SPIR__)
return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)};
#else
return -to_float(lhs.value);
return bfloat16{-to_float(lhs.value)};
#endif
}

Expand Down Expand Up @@ -199,6 +243,19 @@ class bfloat16 {

// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
// for floating-point types.

// Stream Operator << and >>
// Prints a bfloat16 by converting it to float and inserting that float into
// the output stream; returns the same stream so insertions can be chained.
inline friend std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
  return O << static_cast<float>(rhs);
}

// Reads a bfloat16 by extracting a float from the input stream and assigning
// it to rhs; returns the same stream so extractions can be chained.
inline friend std::istream &operator>>(std::istream &I, bfloat16 &rhs) {
  // Zero-initialised so rhs receives a defined value whatever the
  // extraction below does.
  float Parsed = 0.0f;
  I >> Parsed;
  rhs = Parsed;
  return I;
}
};

namespace detail {
Expand Down
4 changes: 4 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
// ===--------------------------------------------------------------------=== //

#pragma once

#include <CL/__spirv/spirv_types.hpp> // __spv namespace
#include <optional> // std::optional

namespace sycl {
inline namespace _V1 {
namespace ext {
Expand Down
41 changes: 35 additions & 6 deletions sycl/include/sycl/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
#include <sycl/detail/item_base.hpp> // for id, range
#include <sycl/detail/owner_less_base.hpp> // for OwnerLessBase
#include <sycl/ext/oneapi/bfloat16.hpp> // for bfloat16
#include <sycl/group.hpp> // for group
#include <sycl/h_item.hpp> // for h_item
#include <sycl/half_type.hpp> // for half, operator-, operator<
Expand Down Expand Up @@ -83,10 +84,8 @@ constexpr size_t MAX_ARRAY_SIZE =
constexpr unsigned FLUSH_BUF_OFFSET_SIZE = 2;

template <class F, class T = void>
using EnableIfFP = typename std::enable_if_t<std::is_same_v<F, float> ||
std::is_same_v<F, double> ||
std::is_same_v<F, half>,
T>;
using EnableIfFP = typename std::enable_if_t<
detail::check_type_in_v<F, float, double, half, ext::oneapi::bfloat16>, T>;

using GlobalBufAccessorT = accessor<char, 1, sycl::access::mode::read_write,
sycl::access::target::device>;
Expand Down Expand Up @@ -334,9 +333,9 @@ checkForInfNan(char *Buf, T Val) {
return append(Buf, "nan");

// Extract the sign from the bits
const uint16_t Sign = reinterpret_cast<uint16_t &>(Val) & 0x8000;
const uint16_t Sign = sycl::bit_cast<uint16_t>(Val) & 0x8000;
// Extract the exponent from the bits
const uint16_t Exp16 = (reinterpret_cast<uint16_t &>(Val) & 0x7c00) >> 10;
const uint16_t Exp16 = (sycl::bit_cast<uint16_t>(Val) & 0x7c00) >> 10;

if (Exp16 == 0x1f) {
if (Sign)
Expand All @@ -346,6 +345,26 @@ checkForInfNan(char *Buf, T Val) {
return 0;
}

/// Writes the textual form of a non-finite bfloat16 value into \p Buf.
///
/// The bit pattern is decoded with the masks used below: bit 15 is the sign
/// and bits 14..7 are the 8-bit exponent (the remaining 7 bits are mantissa).
///
/// \param Buf destination buffer handed to append().
/// \param Val value to classify.
/// \returns the value returned by append() after writing "nan", "-inf" or
///          "inf" when \p Val is NaN or an infinity; 0 when \p Val is finite
///          (nothing is written in that case).
template <typename T>
inline typename std::enable_if_t<std::is_same_v<T, ext::oneapi::bfloat16>,
                                 unsigned>
checkForInfNan(char *Buf, T Val) {
  // NaN is the only value that compares unequal to itself.
  if (Val != Val)
    return append(Buf, "nan");

  // Read the raw 16-bit pattern once and decode both fields from it.
  const uint16_t Bits = sycl::bit_cast<uint16_t>(Val);
  // Extract the sign from the bits
  const uint16_t Sign = Bits & 0x8000;
  // Extract the 8-bit exponent from the bits
  const uint16_t Exp16 = (Bits & 0x7f80) >> 7;

  // An all-ones exponent (0xff for an 8-bit field) marks inf/NaN; NaN was
  // handled above, so this is an infinity. Comparing against 0x7f would
  // instead match every value in [1.0, 2.0) and never match a real infinity.
  if (Exp16 == 0xff) {
    if (Sign)
      return append(Buf, "-inf");
    return append(Buf, "inf");
  }
  return 0;
}

template <typename T>
EnableIfFP<T, unsigned> floatingPointToDecStr(T AbsVal, char *Digits,
int Precision, bool IsSci) {
Expand Down Expand Up @@ -1053,6 +1072,8 @@ class __SYCL_EXPORT __SYCL_SPECIAL_CLASS __SYCL_TYPE(stream) stream
friend const stream &operator<<(const stream &, const float &);
friend const stream &operator<<(const stream &, const double &);
friend const stream &operator<<(const stream &, const half &);
friend const stream &operator<<(const stream &,
const ext::oneapi::bfloat16 &);

friend const stream &operator<<(const stream &, const stream_manipulator);

Expand Down Expand Up @@ -1159,6 +1180,14 @@ inline const stream &operator<<(const stream &Out, const half &RHS) {
return Out;
}

// Streams a bfloat16 value: forwards the value, together with the stream's
// flush buffer, buffer size, work-item offset, flags, width and precision,
// to detail::writeFloatingPoint, then returns the stream for chaining.
inline const stream &operator<<(const stream &Out,
                                const ext::oneapi::bfloat16 &RHS) {
  using ValueT = ext::oneapi::bfloat16;
  detail::writeFloatingPoint<ValueT>(Out.GlobalFlushBuf, Out.FlushBufferSize,
                                     Out.WIOffset, Out.get_flags(),
                                     Out.get_width(), Out.get_precision(),
                                     RHS);
  return Out;
}

// Pointer

template <typename ElementType, access::address_space Space,
Expand Down
Loading

0 comments on commit bbbe883

Please sign in to comment.