Skip to content

Commit

Permalink
fixes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cperkinsintel committed Dec 28, 2023
1 parent 1088186 commit 12a3a60
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 9 deletions.
46 changes: 37 additions & 9 deletions sycl/include/sycl/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,43 @@ using select_apply_cl_t = std::conditional_t<
template <typename T> struct vec_helper {
using RetType = T;
static constexpr RetType get(T value) { return value; }
static constexpr RetType set(T value) { return value; }
};
template <> struct vec_helper<bool> {
using RetType = select_apply_cl_t<bool, std::int8_t, std::int16_t,
std::int32_t, std::int64_t>;
static constexpr RetType get(bool value) { return value; }
static constexpr RetType set(bool value) { return value; }
};

template <> struct vec_helper<sycl::ext::oneapi::bfloat16> {
using RetType = sycl::ext::oneapi::bfloat16;
using BFloat16StorageT = sycl::ext::oneapi::detail::Bfloat16StorageT;
static constexpr RetType get(BFloat16StorageT value) {
// given that BFloat16StorageT is the storageT for bfloat16, I'd prefer
// to use a reinterpret_cast (or cast from void*). But inexplicably
// that's not allowed in constexpr (before C++20).
return sycl::bit_cast<RetType>(value);
}

static constexpr RetType get(RetType value) { return value; }

static constexpr BFloat16StorageT set(RetType value) {
return sycl::bit_cast<BFloat16StorageT>(value);
}
};

#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
template <> struct vec_helper<std::byte> {
using RetType = std::uint8_t;
static constexpr RetType get(std::byte value) { return (RetType)value; }
static constexpr RetType set(std::byte value) { return (RetType)value; }
static constexpr std::byte get(std::uint8_t value) {
return (std::byte)value;
}
static constexpr std::byte set(std::uint8_t value) {
return (std::byte)value;
}
};
#endif

Expand Down Expand Up @@ -330,7 +353,7 @@ template <typename Type, int NumElements> class vec {
// of 64 for direct params. If we drop MSVC, we can have alignment the same as
// size and use vector extensions for all sizes.
static constexpr bool IsUsingArrayOnDevice =
(IsHostHalf || IsSizeGreaterThanMaxAlign);
(IsHostHalf || IsBfloat16 || IsSizeGreaterThanMaxAlign);

#if defined(__SYCL_DEVICE_ONLY__)
static constexpr bool NativeVec = NumElements > 1 && !IsUsingArrayOnDevice;
Expand All @@ -343,7 +366,7 @@ template <typename Type, int NumElements> class vec {
#endif // defined(__INTEL_PREVIEW_BREAKING_CHANGES)

#if !defined(__INTEL_PREVIEW_BREAKING_CHANGES)
static constexpr bool IsUsingArrayOnDevice = IsHostHalf;
static constexpr bool IsUsingArrayOnDevice = IsHostHalf || IsBfloat16;
#endif // !defined(__INTEL_PREVIEW_BREAKING_CHANGES)

static constexpr int getNumElements() { return NumElements; }
Expand Down Expand Up @@ -1035,7 +1058,7 @@ template <typename Type, int NumElements> class vec {
typename std::enable_if_t< \
std::is_convertible_v<DataT, T> && \
(std::is_fundamental_v<vec_data_t<T>> || \
std::is_same_v<typename std::remove_const_t<T>, half>), \
detail::is_half_or_bf16_v<typename std::remove_const_t<T>>), \
vec> \
operator BINOP(const T & Rhs) const { \
return *this BINOP vec(static_cast<const DataT &>(Rhs)); \
Expand Down Expand Up @@ -1464,13 +1487,13 @@ template <typename Type, int NumElements> class vec {

// setValue and getValue should be able to operate on different underlying
// types: enum cl_float#N , builtin vector float#N, builtin type float.

// These versions are for N > 1.
#ifdef __SYCL_USE_EXT_VECTOR_TYPE__
template <int Num = NumElements, typename Ty = int,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(EnableIfNotHostHalf<Ty> Index, const DataT &Value,
int) {
m_Data[Index] = vec_data<DataT>::get(Value);
m_Data[Index] = vec_data<DataT>::set(Value);
}

template <int Num = NumElements, typename Ty = int,
Expand All @@ -1482,7 +1505,7 @@ template <typename Type, int NumElements> class vec {
template <int Num = NumElements, typename Ty = int,
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(EnableIfHostHalf<Ty> Index, const DataT &Value, int) {
m_Data.s[Index] = vec_data<DataT>::get(Value);
m_Data.s[Index] = vec_data<DataT>::set(Value);
}

template <int Num = NumElements, typename Ty = int,
Expand All @@ -1495,9 +1518,9 @@ template <typename Type, int NumElements> class vec {
typename = typename std::enable_if_t<1 != Num>>
constexpr void setValue(int Index, const DataT &Value, int) {
#if defined(__INTEL_PREVIEW_BREAKING_CHANGES)
m_Data[Index] = vec_data<DataT>::get(Value);
m_Data[Index] = vec_data<DataT>::set(Value);
#else
m_Data.s[Index] = vec_data<DataT>::get(Value);
m_Data.s[Index] = vec_data<DataT>::set(Value);
#endif
}

Expand All @@ -1512,10 +1535,11 @@ template <typename Type, int NumElements> class vec {
}
#endif // __SYCL_USE_EXT_VECTOR_TYPE__

// N==1 versions, used by host and device. Shouldn't trailing type be int?
template <int Num = NumElements,
typename = typename std::enable_if_t<1 == Num>>
constexpr void setValue(int, const DataT &Value, float) {
m_Data = vec_data<DataT>::get(Value);
m_Data = vec_data<DataT>::set(Value);
}

template <int Num = NumElements,
Expand All @@ -1524,6 +1548,9 @@ template <typename Type, int NumElements> class vec {
return vec_data<DataT>::get(m_Data);
}

// setValue and getValue.
// The "api" functions used by BINOP etc. These versions just dispatch
// using additional int or float arg to disambiguate vec<1> vs. vec<N>
// Special proxies as specialization is not allowed in class scope.
constexpr void setValue(int Index, const DataT &Value) {
if (NumElements == 1)
Expand Down Expand Up @@ -2544,6 +2571,7 @@ __SYCL_DEFINE_HALF_VECSTORAGE(16)
// Single element bfloat16
template <> struct VecStorage<sycl::ext::oneapi::bfloat16, 1, void> {
using DataType = sycl::ext::oneapi::detail::Bfloat16StorageT;
// using VectorDataType = sycl::ext::oneapi::bfloat16;
using VectorDataType = sycl::ext::oneapi::detail::Bfloat16StorageT;
};
// Multiple elements bfloat16
Expand Down
141 changes: 141 additions & 0 deletions sycl/test-e2e/BFloat16/bfloat16_vec.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
//==------------ bfloat16_vec.cpp - test sycl::vec<bfloat16>----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %t.out
// RUN: %if preview-breaking-changes-supported %{ %clangxx -fsycl %s -o %t.out && %t.out %}

#include <sycl/sycl.hpp>

constexpr unsigned N = 8; // + - * / for vec<1> and vec<2>

int main() {

// clang-format off
using T = sycl::ext::oneapi::bfloat16;

sycl::queue q;

T a0{ 2.0f };
T b0 { 6.0f };

T addition_ref0 = a0 + b0;
T subtraction_ref0 = a0 - b0;
T multiplication_ref0 = a0 * b0;
T division_ref0 = a0 / b0;

std::cout << " === vec<bfloat16, 1> === " << std::endl;
std::cout << " --- ON HOST --- " << std::endl;
sycl::vec<T, 1> oneA{ a0 }, oneB{ b0 };
sycl::vec<T, 1> simple_addition = oneA + oneB;
sycl::vec<T, 1> simple_subtraction = oneA - oneB;
sycl::vec<T, 1> simple_multiplication = oneA * oneB;
sycl::vec<T, 1> simple_division = oneA / oneB;

std::cout << "addition. ref: " << addition_ref0 << " vec: " << simple_addition[0] << std::endl;
std::cout << "subtraction. ref: " << subtraction_ref0 << " vec: " << simple_subtraction[0] << std::endl;
std::cout << "multiplication. ref: " << multiplication_ref0 << " vec: " << simple_multiplication[0] << std::endl;
std::cout << "division. ref: " << division_ref0 << " vec: " << simple_division[0] << std::endl;

assert(addition_ref0 == simple_addition[0]);
assert(subtraction_ref0 == simple_subtraction[0]);
assert(multiplication_ref0 == simple_multiplication[0]);
assert(division_ref0 == simple_division[0]);

std::cout << " --- ON DEVICE --- " << std::endl;
sycl::range<1> r(N);
sycl::buffer<bool, 1> buf(r);

q.submit([&](sycl::handler &cgh) {
sycl::stream out(2024, 400, cgh);
sycl::accessor acc(buf, cgh, sycl::write_only );
cgh.single_task([=](){
sycl::vec<T, 1> device_addition = oneA + oneB;
sycl::vec<T, 1> device_subtraction = oneA - oneB;
sycl::vec<T, 1> device_multiplication = oneA * oneB;
sycl::vec<T, 1> device_division = oneA / oneB;

out << "addition. ref: " << addition_ref0 << " vec: " << device_addition[0] << sycl::endl;
out << "subtraction. ref: " << subtraction_ref0 << " vec: " << device_subtraction[0] << sycl::endl;
out << "multiplication. ref: " << multiplication_ref0 << " vec: " << device_multiplication[0] << sycl::endl;
out << "division. ref: " << division_ref0 << " vec: " << device_division[0] << sycl::endl;

acc[0] = (addition_ref0 == device_addition[0]);
acc[1] = (subtraction_ref0 == device_subtraction[0]);
acc[2] = (multiplication_ref0 == device_multiplication[0]);
acc[3] = (division_ref0 == device_division[0]);

});
}).wait();


// second value
T a1 { 3.33333f };
T b1 { 6.66666f };
T addition_ref1 = a1 + b1;
T subtraction_ref1 = a1 - b1;
T multiplication_ref1 = a1 * b1;
T division_ref1 = a1 / b1;

std::cout << "\n === vec<bfloat16, 2> === " << std::endl;
std::cout << " --- ON HOST --- " << std::endl;
sycl::vec<T, 2> twoA{ a0, a1 }, twoB{ b0, b1 };
sycl::vec<T, 2> double_addition = twoA + twoB;
sycl::vec<T, 2> double_subtraction = twoA - twoB;
sycl::vec<T, 2> double_multiplication = twoA * twoB;
sycl::vec<T, 2> double_division = twoA / twoB;

std::cout << "+ ref0: " << addition_ref0 << " ref1: " << addition_ref1 << std::endl;
std::cout << "add[0]: " << double_addition[0] << " add[1]: " << double_addition[1] << std::endl;
std::cout << "- ref0: " << subtraction_ref0 << " ref1: " << subtraction_ref1 << std::endl;
std::cout << "sub[0]: " << double_subtraction[0] << " sub[1]: " << double_subtraction[1] << std::endl;
std::cout << "* ref0: " << multiplication_ref0 << " ref1: " << multiplication_ref1 << std::endl;
std::cout << "mul[0]: " << double_multiplication[0] << " mul[1]: " << double_multiplication[1] << std::endl;
std::cout << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << std::endl;
std::cout << "div[0]: " << double_division[0] << " div[1]: " << double_division[1] << std::endl;

assert(addition_ref0 == double_addition[0]); assert(addition_ref1 == double_addition[1]);
assert(subtraction_ref0 == double_subtraction[0]); assert(subtraction_ref1 == double_subtraction[1]);
assert(multiplication_ref0 == double_multiplication[0]); assert(multiplication_ref1 == double_multiplication[1]);
assert(division_ref0 == double_division[0]); assert(division_ref1 == double_division[1]);

std::cout << " --- ON DEVICE --- " << std::endl;
q.submit([&](sycl::handler &cgh) {
sycl::stream out(2024, 400, cgh);
sycl::accessor acc(buf, cgh, sycl::write_only );
cgh.single_task([=](){
sycl::vec<T, 2> device_addition = twoA + twoB;
sycl::vec<T, 2> device_subtraction = twoA - twoB;
sycl::vec<T, 2> device_multiplication = twoA * twoB;
sycl::vec<T, 2> device_division = twoA / twoB;

out << "+ ref0: " << addition_ref0 << " ref1: " << addition_ref1 << sycl::endl;
out << "add[0]: " << device_addition[0] << " add[1]: " << device_addition[1] << sycl::endl;
out << "- ref0: " << subtraction_ref0 << " ref1: " << subtraction_ref1 << sycl::endl;
out << "sub[0]: " << device_subtraction[0] << " sub[1]: " << device_subtraction[1] << sycl::endl;
out << "* ref0: " << multiplication_ref0 << " ref1: " << multiplication_ref1 << sycl::endl;
out << "mul[0]: " << device_multiplication[0] << " mul[1]: " << device_multiplication[1] << sycl::endl;
out << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << sycl::endl;
out << "div[0]: " << device_division[0] << " div[1]: " << device_division[1] << sycl::endl;

acc[4] = (addition_ref0 == device_addition[0]) && (addition_ref1 == device_addition[1]);
acc[5] = (subtraction_ref0 == device_subtraction[0]) && (subtraction_ref1 == device_subtraction[1]);
acc[6] = (multiplication_ref0 == device_multiplication[0]) && (multiplication_ref1 == device_multiplication[1]);
acc[7] = (division_ref0 == device_division[0]) && (division_ref1 == device_division[1]);

});
}).wait();

sycl::host_accessor h_acc(buf, sycl::read_only);
for(unsigned i = 0; i < N; i++){
assert(h_acc[i]);
}

// clang-format on
return 0;
}

0 comments on commit 12a3a60

Please sign in to comment.