[SYCL] Limit bfloat16 operators to scalar datatypes convertible to float (#12477)

Aligns the implementation with the spec:
https://github.com/intel/llvm/blob/48ec4dd6ff45b25b405b99d63f3c5b8537b1475d/sycl/doc/extensions/supported/sycl_ext_oneapi_bfloat16.asciidoc

It also re-enables the ESIMD tests for bfloat16 arithmetic.
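For illustration only (not part of the commit): a minimal standalone sketch of the constraint pattern this change adopts, where a friend operator template drops out of overload resolution unless the scalar operand is convertible to float. The bf16 and NotAScalar types below are hypothetical stand-ins, not the real sycl::ext::oneapi::bfloat16.

#include <iostream>
#include <type_traits>

struct bf16 { // hypothetical stand-in for sycl::ext::oneapi::bfloat16
  float value = 0.f;
  explicit operator float() const { return value; }

  // Participates in overload resolution only when T converts to float,
  // mirroring the std::enable_if_t / std::is_convertible_v guard added here.
  template <typename T>
  friend std::enable_if_t<std::is_convertible_v<T, float>, bf16>
  operator+(const bf16 &lhs, const T &rhs) {
    return bf16{static_cast<float>(lhs) + static_cast<float>(rhs)};
  }
};

struct NotAScalar {}; // hypothetical type that is not convertible to float

int main() {
  bf16 a{1.0f};
  bf16 ok = a + 2;                 // fine: int is convertible to float
  // bf16 bad = a + NotAScalar{};  // would no longer match this template
  std::cout << static_cast<float>(ok) << "\n"; // prints 3
}

Because the constrained template simply does not exist for non-convertible operands, other overloads (for example ESIMD's simd operators) can be selected instead of producing a hard error inside the operator body.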
fineg74 authored Feb 27, 2024
1 parent ce70cb5 commit f81b5a2
Showing 6 changed files with 127 additions and 15 deletions.
@@ -94,6 +94,8 @@ inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
  return O;
}

+template <> struct is_esimd_arithmetic_type<bfloat16, void> : std::true_type {};

} // namespace ext::intel::esimd::detail
} // namespace _V1
} // namespace sycl
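The new is_esimd_arithmetic_type specialization above opts bfloat16 into ESIMD's arithmetic element types. The exact way the ESIMD headers consume the trait is not shown in this hunk; the following is only an assumed sketch of the usual trait-gating pattern, with simd and bfloat16 as toy stand-ins rather than the real SYCL types.

#include <array>
#include <type_traits>

struct bfloat16 { float v = 0.f; }; // hypothetical stand-in

// Assumed pattern: element types must opt in through a trait.
template <typename T, typename = void>
struct is_esimd_arithmetic_type : std::false_type {};

template <>
struct is_esimd_arithmetic_type<bfloat16, void> : std::true_type {};

template <typename T, int N> struct simd { std::array<T, N> data{}; };

// Vector-plus-scalar arithmetic gated on the trait, so simd<bfloat16, N>
// only gets this operator once the specialization above exists.
template <typename T, int N,
          typename = std::enable_if_t<is_esimd_arithmetic_type<T>::value>>
simd<T, N> operator+(simd<T, N> lhs, T rhs) {
  for (auto &e : lhs.data)
    e.v += rhs.v;
  return lhs;
}

int main() {
  simd<bfloat16, 8> vec;
  vec = vec + bfloat16{10.f}; // enabled because of the trait
}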
17 changes: 4 additions & 13 deletions sycl/include/sycl/ext/oneapi/bfloat16.hpp
@@ -198,17 +198,6 @@ class bfloat16 {
    float f = static_cast<float>(lhs); \
    f op static_cast<float>(rhs); \
    return lhs = f; \
-  } \
-  template <typename T> \
-  friend bfloat16 &operator op(bfloat16 & lhs, const T & rhs) { \
-    float f = static_cast<float>(lhs); \
-    f op static_cast<float>(rhs); \
-    return lhs = f; \
-  } \
-  template <typename T> friend T &operator op(T & lhs, const bfloat16 & rhs) { \
-    float f = static_cast<float>(lhs); \
-    f op static_cast<float>(rhs); \
-    return lhs = f; \
  }
  OP(+=)
  OP(-=)
@@ -222,11 +211,13 @@ class bfloat16 {
    return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
  } \
  template <typename T> \
-  friend type operator op(const bfloat16 &lhs, const T &rhs) { \
+  friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
+      const bfloat16 & lhs, const T & rhs) { \
    return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
  } \
  template <typename T> \
-  friend type operator op(const T &lhs, const bfloat16 &rhs) { \
+  friend std::enable_if_t<std::is_convertible_v<T, float>, type> operator op( \
+      const T & lhs, const bfloat16 & rhs) { \
    return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
  }
  OP(bfloat16, +)
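For readability, the two templated overloads that OP(bfloat16, +) now generates expand to roughly the following (reconstructed from the hunk above; whitespace approximate). The return-type enable_if means these operators simply do not participate in overload resolution for operands that are not convertible to float, instead of failing inside the body.

// Approximate expansion of the two templated overloads from OP(bfloat16, +):
template <typename T>
friend std::enable_if_t<std::is_convertible_v<T, float>, bfloat16>
operator+(const bfloat16 &lhs, const T &rhs) {
  return bfloat16{static_cast<float>(lhs) + static_cast<float>(rhs)};
}
template <typename T>
friend std::enable_if_t<std::is_convertible_v<T, float>, bfloat16>
operator+(const T &lhs, const bfloat16 &rhs) {
  return bfloat16{static_cast<float>(lhs) + static_cast<float>(rhs)};
}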
@@ -91,8 +91,7 @@ int main() {
  }

#ifdef USE_BF16
-  // TODO: Reenable once the issue with bfloat16 is resolved
-  // Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
+  Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
#endif
#ifdef USE_TF32
  Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
100 changes: 100 additions & 0 deletions sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp
@@ -0,0 +1,100 @@
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out
//==- bfloat16_vector_plus_scalar.cpp - Test for bfloat16 operators ------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "../esimd_test_utils.hpp"
#include <iostream>
#include <sycl/ext/intel/esimd.hpp>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::intel::esimd;
using namespace sycl::ext::intel::experimental::esimd;

template <typename T> ESIMD_NOINLINE bool test(queue Q) {
  std::cout << "Testing T=" << esimd_test::type_name<T>() << "...\n";

  constexpr int N = 8;

  constexpr int NumOps = 4;
  constexpr int CSize = NumOps * N;

  T *Mem = malloc_shared<T>(CSize, Q);
  T TOne = static_cast<T>(1);
  T TTen = static_cast<T>(10);

  Q.single_task([=]() SYCL_ESIMD_KERNEL {
     {
       simd<T, N> Vec(TOne);
       Vec = Vec + TTen;
       Vec.copy_to(Mem);
     }
     {
       simd<T, N> Vec(TOne);
       Vec = Vec - TTen;
       Vec.copy_to(Mem + N);
     }
     {
       simd<T, N> Vec(TOne);
       Vec = Vec * TTen;
       Vec.copy_to(Mem + 2 * N);
     }
     {
       simd<T, N> Vec(TOne);
       Vec = Vec / TTen;
       Vec.copy_to(Mem + 3 * N);
     }
   }).wait();

  bool ReturnValue = true;
  for (int i = 0; i < N; ++i) {
    if (Mem[i] != TOne + TTen) {
      ReturnValue = false;
      break;
    }
    if (Mem[i + N] != TOne - TTen) {
      ReturnValue = false;
      break;
    }
    if (Mem[i + 2 * N] != TOne * TTen) {
      ReturnValue = false;
      break;
    }
    if (!((Mem[i + 3 * N] == (TOne / TTen)) ||
          (std::abs((double)(Mem[i + 3 * N] - (TOne / TTen)) /
                    (double)(TOne / TTen)) <= 0.001))) {
      ReturnValue = false;
      break;
    }
  }

  free(Mem, Q);
  return ReturnValue;
}

int main() {
  queue Q;
  esimd_test::printTestLabel(Q);

  bool SupportsHalf = Q.get_device().has(aspect::fp16);

  bool Passed = true;
  Passed &= test<int>(Q);
  Passed &= test<float>(Q);
  if (SupportsHalf) {
    Passed &= test<sycl::half>(Q);
  }
#ifdef USE_BF16
  Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
#endif
#ifdef USE_TF32
  Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
#endif
  std::cout << (Passed ? "Passed\n" : "FAILED\n");
  return Passed ? 0 : 1;
}
17 changes: 17 additions & 0 deletions sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp
@@ -0,0 +1,17 @@
//==- bfloat16_vector_plus_scalar_pvc.cpp - Test for bfloat16 operators -==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This test validates operations between simd vector and scalars for
// bfloat16 and tfloat32 types that are available only on PVC.
//===----------------------------------------------------------------------===//
// REQUIRES: gpu-intel-pvc
// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#define USE_BF16
#define USE_TF32
#include "bfloat16_vector_plus_scalar.cpp"
3 changes: 3 additions & 0 deletions sycl/test/type_traits/half_operator_types.cpp
@@ -103,11 +103,14 @@ int main() {
  check_half_math_operator_types<int, sycl::half>(Queue);
  check_half_math_operator_types<long, sycl::half>(Queue);
  check_half_math_operator_types<long long, sycl::half>(Queue);
+  check_half_math_operator_types<sycl::ext::oneapi::bfloat16,
+                                 sycl::ext::oneapi::bfloat16>(Queue);

  check_half_logical_operator_types<sycl::half>(Queue);
  check_half_logical_operator_types<double>(Queue);
  check_half_logical_operator_types<float>(Queue);
  check_half_logical_operator_types<int>(Queue);
  check_half_logical_operator_types<long>(Queue);
  check_half_logical_operator_types<long long>(Queue);
+  check_half_logical_operator_types<sycl::ext::oneapi::bfloat16>(Queue);
}
