diff --git a/sycl/include/sycl/types.hpp b/sycl/include/sycl/types.hpp index bf7b2d9cef777..79eba7a597085 100644 --- a/sycl/include/sycl/types.hpp +++ b/sycl/include/sycl/types.hpp @@ -144,20 +144,43 @@ using select_apply_cl_t = std::conditional_t< template struct vec_helper { using RetType = T; static constexpr RetType get(T value) { return value; } + static constexpr RetType set(T value) { return value; } }; template <> struct vec_helper { using RetType = select_apply_cl_t; static constexpr RetType get(bool value) { return value; } + static constexpr RetType set(bool value) { return value; } +}; + +template <> struct vec_helper { + using RetType = sycl::ext::oneapi::bfloat16; + using BFloat16StorageT = sycl::ext::oneapi::detail::Bfloat16StorageT; + static constexpr RetType get(BFloat16StorageT value) { + // given that BFloat16StorageT is the storageT for bfloat16, I'd prefer + // to use a reinterpret_cast (or cast from void*). But inexplicably + // that's not allowed in constexpr (before C++20). + return sycl::bit_cast(value); + } + + static constexpr RetType get(RetType value) { return value; } + + static constexpr BFloat16StorageT set(RetType value) { + return sycl::bit_cast(value); + } }; #if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) template <> struct vec_helper { using RetType = std::uint8_t; static constexpr RetType get(std::byte value) { return (RetType)value; } + static constexpr RetType set(std::byte value) { return (RetType)value; } static constexpr std::byte get(std::uint8_t value) { return (std::byte)value; } + static constexpr std::byte set(std::uint8_t value) { + return (std::byte)value; + } }; #endif @@ -330,7 +353,7 @@ template class vec { // of 64 for direct params. If we drop MSVC, we can have alignment the same as // size and use vector extensions for all sizes. static constexpr bool IsUsingArrayOnDevice = - (IsHostHalf || IsSizeGreaterThanMaxAlign); + (IsHostHalf || IsBfloat16 || IsSizeGreaterThanMaxAlign); #if defined(__SYCL_DEVICE_ONLY__) static constexpr bool NativeVec = NumElements > 1 && !IsUsingArrayOnDevice; @@ -343,7 +366,7 @@ template class vec { #endif // defined(__INTEL_PREVIEW_BREAKING_CHANGES) #if !defined(__INTEL_PREVIEW_BREAKING_CHANGES) - static constexpr bool IsUsingArrayOnDevice = IsHostHalf; + static constexpr bool IsUsingArrayOnDevice = IsHostHalf || IsBfloat16; #endif // !defined(__INTEL_PREVIEW_BREAKING_CHANGES) static constexpr int getNumElements() { return NumElements; } @@ -1035,7 +1058,7 @@ template class vec { typename std::enable_if_t< \ std::is_convertible_v && \ (std::is_fundamental_v> || \ - std::is_same_v, half>), \ + detail::is_half_or_bf16_v>), \ vec> \ operator BINOP(const T & Rhs) const { \ return *this BINOP vec(static_cast(Rhs)); \ @@ -1464,13 +1487,13 @@ template class vec { // setValue and getValue should be able to operate on different underlying // types: enum cl_float#N , builtin vector float#N, builtin type float. - + // These versions are for N > 1. #ifdef __SYCL_USE_EXT_VECTOR_TYPE__ template > constexpr void setValue(EnableIfNotHostHalf Index, const DataT &Value, int) { - m_Data[Index] = vec_data::get(Value); + m_Data[Index] = vec_data::set(Value); } template class vec { template > constexpr void setValue(EnableIfHostHalf Index, const DataT &Value, int) { - m_Data.s[Index] = vec_data::get(Value); + m_Data.s[Index] = vec_data::set(Value); } template class vec { typename = typename std::enable_if_t<1 != Num>> constexpr void setValue(int Index, const DataT &Value, int) { #if defined(__INTEL_PREVIEW_BREAKING_CHANGES) - m_Data[Index] = vec_data::get(Value); + m_Data[Index] = vec_data::set(Value); #else - m_Data.s[Index] = vec_data::get(Value); + m_Data.s[Index] = vec_data::set(Value); #endif } @@ -1512,10 +1535,11 @@ template class vec { } #endif // __SYCL_USE_EXT_VECTOR_TYPE__ + // N==1 versions, used by host and device. Shouldn't trailing type be int? template > constexpr void setValue(int, const DataT &Value, float) { - m_Data = vec_data::get(Value); + m_Data = vec_data::set(Value); } template class vec { return vec_data::get(m_Data); } + // setValue and getValue. + // The "api" functions used by BINOP etc. These versions just dispatch + // using additional int or float arg to disambiguate vec<1> vs. vec // Special proxies as specialization is not allowed in class scope. constexpr void setValue(int Index, const DataT &Value) { if (NumElements == 1) @@ -2544,6 +2571,7 @@ __SYCL_DEFINE_HALF_VECSTORAGE(16) // Single element bfloat16 template <> struct VecStorage { using DataType = sycl::ext::oneapi::detail::Bfloat16StorageT; + // using VectorDataType = sycl::ext::oneapi::bfloat16; using VectorDataType = sycl::ext::oneapi::detail::Bfloat16StorageT; }; // Multiple elements bfloat16 diff --git a/sycl/test-e2e/BFloat16/bfloat16_vec.cpp b/sycl/test-e2e/BFloat16/bfloat16_vec.cpp new file mode 100644 index 0000000000000..7af1a97a2d01c --- /dev/null +++ b/sycl/test-e2e/BFloat16/bfloat16_vec.cpp @@ -0,0 +1,141 @@ +//==------------ bfloat16_vec.cpp - test sycl::vec----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// RUN: %clangxx -fsycl %s -o %t.out +// RUN: %t.out +// RUN: %if preview-breaking-changes-supported %{ %clangxx -fsycl %s -o %t.out && %t.out %} + +#include + +constexpr unsigned N = 8; // + - * / for vec<1> and vec<2> + +int main() { + + // clang-format off + using T = sycl::ext::oneapi::bfloat16; + + sycl::queue q; + + T a0{ 2.0f }; + T b0 { 6.0f }; + + T addition_ref0 = a0 + b0; + T subtraction_ref0 = a0 - b0; + T multiplication_ref0 = a0 * b0; + T division_ref0 = a0 / b0; + + std::cout << " === vec === " << std::endl; + std::cout << " --- ON HOST --- " << std::endl; + sycl::vec oneA{ a0 }, oneB{ b0 }; + sycl::vec simple_addition = oneA + oneB; + sycl::vec simple_subtraction = oneA - oneB; + sycl::vec simple_multiplication = oneA * oneB; + sycl::vec simple_division = oneA / oneB; + + std::cout << "addition. ref: " << addition_ref0 << " vec: " << simple_addition[0] << std::endl; + std::cout << "subtraction. ref: " << subtraction_ref0 << " vec: " << simple_subtraction[0] << std::endl; + std::cout << "multiplication. ref: " << multiplication_ref0 << " vec: " << simple_multiplication[0] << std::endl; + std::cout << "division. ref: " << division_ref0 << " vec: " << simple_division[0] << std::endl; + + assert(addition_ref0 == simple_addition[0]); + assert(subtraction_ref0 == simple_subtraction[0]); + assert(multiplication_ref0 == simple_multiplication[0]); + assert(division_ref0 == simple_division[0]); + + std::cout << " --- ON DEVICE --- " << std::endl; + sycl::range<1> r(N); + sycl::buffer buf(r); + + q.submit([&](sycl::handler &cgh) { + sycl::stream out(2024, 400, cgh); + sycl::accessor acc(buf, cgh, sycl::write_only ); + cgh.single_task([=](){ + sycl::vec device_addition = oneA + oneB; + sycl::vec device_subtraction = oneA - oneB; + sycl::vec device_multiplication = oneA * oneB; + sycl::vec device_division = oneA / oneB; + + out << "addition. ref: " << addition_ref0 << " vec: " << device_addition[0] << sycl::endl; + out << "subtraction. ref: " << subtraction_ref0 << " vec: " << device_subtraction[0] << sycl::endl; + out << "multiplication. ref: " << multiplication_ref0 << " vec: " << device_multiplication[0] << sycl::endl; + out << "division. ref: " << division_ref0 << " vec: " << device_division[0] << sycl::endl; + + acc[0] = (addition_ref0 == device_addition[0]); + acc[1] = (subtraction_ref0 == device_subtraction[0]); + acc[2] = (multiplication_ref0 == device_multiplication[0]); + acc[3] = (division_ref0 == device_division[0]); + + }); + }).wait(); + + + // second value + T a1 { 3.33333f }; + T b1 { 6.66666f }; + T addition_ref1 = a1 + b1; + T subtraction_ref1 = a1 - b1; + T multiplication_ref1 = a1 * b1; + T division_ref1 = a1 / b1; + + std::cout << "\n === vec === " << std::endl; + std::cout << " --- ON HOST --- " << std::endl; + sycl::vec twoA{ a0, a1 }, twoB{ b0, b1 }; + sycl::vec double_addition = twoA + twoB; + sycl::vec double_subtraction = twoA - twoB; + sycl::vec double_multiplication = twoA * twoB; + sycl::vec double_division = twoA / twoB; + + std::cout << "+ ref0: " << addition_ref0 << " ref1: " << addition_ref1 << std::endl; + std::cout << "add[0]: " << double_addition[0] << " add[1]: " << double_addition[1] << std::endl; + std::cout << "- ref0: " << subtraction_ref0 << " ref1: " << subtraction_ref1 << std::endl; + std::cout << "sub[0]: " << double_subtraction[0] << " sub[1]: " << double_subtraction[1] << std::endl; + std::cout << "* ref0: " << multiplication_ref0 << " ref1: " << multiplication_ref1 << std::endl; + std::cout << "mul[0]: " << double_multiplication[0] << " mul[1]: " << double_multiplication[1] << std::endl; + std::cout << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << std::endl; + std::cout << "div[0]: " << double_division[0] << " div[1]: " << double_division[1] << std::endl; + + assert(addition_ref0 == double_addition[0]); assert(addition_ref1 == double_addition[1]); + assert(subtraction_ref0 == double_subtraction[0]); assert(subtraction_ref1 == double_subtraction[1]); + assert(multiplication_ref0 == double_multiplication[0]); assert(multiplication_ref1 == double_multiplication[1]); + assert(division_ref0 == double_division[0]); assert(division_ref1 == double_division[1]); + + std::cout << " --- ON DEVICE --- " << std::endl; + q.submit([&](sycl::handler &cgh) { + sycl::stream out(2024, 400, cgh); + sycl::accessor acc(buf, cgh, sycl::write_only ); + cgh.single_task([=](){ + sycl::vec device_addition = twoA + twoB; + sycl::vec device_subtraction = twoA - twoB; + sycl::vec device_multiplication = twoA * twoB; + sycl::vec device_division = twoA / twoB; + + out << "+ ref0: " << addition_ref0 << " ref1: " << addition_ref1 << sycl::endl; + out << "add[0]: " << device_addition[0] << " add[1]: " << device_addition[1] << sycl::endl; + out << "- ref0: " << subtraction_ref0 << " ref1: " << subtraction_ref1 << sycl::endl; + out << "sub[0]: " << device_subtraction[0] << " sub[1]: " << device_subtraction[1] << sycl::endl; + out << "* ref0: " << multiplication_ref0 << " ref1: " << multiplication_ref1 << sycl::endl; + out << "mul[0]: " << device_multiplication[0] << " mul[1]: " << device_multiplication[1] << sycl::endl; + out << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << sycl::endl; + out << "div[0]: " << device_division[0] << " div[1]: " << device_division[1] << sycl::endl; + + acc[4] = (addition_ref0 == device_addition[0]) && (addition_ref1 == device_addition[1]); + acc[5] = (subtraction_ref0 == device_subtraction[0]) && (subtraction_ref1 == device_subtraction[1]); + acc[6] = (multiplication_ref0 == device_multiplication[0]) && (multiplication_ref1 == device_multiplication[1]); + acc[7] = (division_ref0 == device_division[0]) && (division_ref1 == device_division[1]); + + }); + }).wait(); + + sycl::host_accessor h_acc(buf, sycl::read_only); + for(unsigned i = 0; i < N; i++){ + assert(h_acc[i]); + } + + // clang-format on + return 0; +} \ No newline at end of file