-
Notifications
You must be signed in to change notification settings - Fork 752
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYCL] Refactor builtins implementation (#11956)
See `builtins_preview.hpp` for the outline of the new design. This PR changes the implementation under `-fpreview-breaking-changes` and removes reliance on a python builtins generator script. Suggested reading/review order: `builtins_preview.hpp`, `helper_macros.hpp`, `host_helper_macros.hpp`, then headers implementing user-visible side with the library implementation `sycl/source/builtins/*_functions.cpp` last.
- Loading branch information
1 parent
5bb9a44
commit 7e9819d
Showing
23 changed files
with
2,874 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
//==------------------- builtins_preview.hpp -------------------------------==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
// Implement SYCL builtin functions. This implementation is mainly driven by the | ||
// requirement of not including <cmath> anywhere in the SYCL headers (i.e. from | ||
// within <sycl/sycl.hpp>), because it pollutes global namespace. Note that we | ||
// can avoid that using MSVC's STL as the pollution happens even from | ||
// <vector>/<string> and other headers that have to be included per the SYCL | ||
// specification. As such, an alternative approach might be to use math | ||
// intrinsics with GCC/clang-based compilers and use <cmath> when using MSVC as | ||
// a host compiler. That hasn't been tried/investigated. | ||
// | ||
// Current implementation splits builtins into several files following the SYCL | ||
// 2020 (revision 8) split into common/math/geometric/relational/etc. functions. | ||
// For each set, the implementation is split into a user-visible | ||
// include/sycl/detail/builtins/*_functions.hpp providing full device-side | ||
// implementation as well as defining user-visible APIs and defining ABI | ||
// implemented under source/builtins/*_functions.cpp for the host side. We | ||
// provide both scalar/vector overloads through symbols in the SYCL runtime | ||
// library due to the <cmath> limitation above (for scalars) and due to | ||
// performance reasons for vector overloads (to be able to benefit from | ||
// vectorization). | ||
// | ||
// Providing declaration for the host side symbols contained in the library | ||
// comes with its own challenges. One is compilation time - blindly providing | ||
// all those declarations takes significant time (about 10% slowdown for | ||
// "clang++ -fsycl" when compiling just "#include <sycl/sycl.hpp>"). Another | ||
// issue is that return type for templates is part of the mangling (and as such | ||
// SFINAE requirements too). To overcome that we structure host side | ||
// implementation roughly like this (in most cases): | ||
// | ||
// math_function.cpp exports: | ||
// float sycl::__sin_impl(float); | ||
// float1 sycl::__sin_impl(float1); | ||
// float2 sycl::__sin_impl(float2); | ||
// ... | ||
// /* same for other types */ | ||
// | ||
// math_functions.hpp provide an implementation based on the following idea (in | ||
// ::sycl namespace): | ||
// float sin(float x) { | ||
// extern __sin_impl(float); | ||
// return __sin_impl(x); | ||
// } | ||
// template <typename T> | ||
// enable_if_valid_type<T> sin(T x) { | ||
// if constexpr (marray_or_swizzle) { | ||
// ... | ||
// call sycl::sin(vector_or_scalar) | ||
// } else { | ||
// extern T __sin_impl(T); | ||
// return __sin_impl(x); | ||
// } | ||
// } | ||
// That way we avoid having the full set of explicit declaration for the symbols | ||
// in the library and instead only pay with compile time when those template | ||
// instantiations actually happen. | ||
|
||
#pragma once | ||
|
||
#include <sycl/builtins_utils_vec.hpp> | ||
|
||
namespace sycl { | ||
inline namespace _V1 { | ||
namespace detail { | ||
template <typename... Ts> | ||
inline constexpr bool builtin_same_shape_v = | ||
((... && is_scalar_arithmetic_v<Ts>) || (... && is_marray_v<Ts>) || | ||
(... && is_vec_or_swizzle_v<Ts>)) && | ||
(... && (num_elements<Ts>::value == | ||
num_elements<typename first_type<Ts...>::type>::value)); | ||
|
||
template <typename... Ts> | ||
inline constexpr bool builtin_same_or_swizzle_v = | ||
// Use builtin_same_shape_v to filter out types unrelated to builtins. | ||
builtin_same_shape_v<Ts...> && all_same_v<simplify_if_swizzle_t<Ts>...>; | ||
|
||
namespace builtins { | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
template <typename T> auto convert_arg(T &&x) { | ||
using no_cv_ref = std::remove_cv_t<std::remove_reference_t<T>>; | ||
if constexpr (is_vec_v<no_cv_ref>) { | ||
using elem_type = get_elem_type_t<no_cv_ref>; | ||
using converted_elem_type = | ||
decltype(convert_arg(std::declval<elem_type>())); | ||
|
||
constexpr auto N = no_cv_ref::size(); | ||
using result_type = std::conditional_t<N == 1, converted_elem_type, | ||
converted_elem_type | ||
__attribute__((ext_vector_type(N)))>; | ||
// TODO: We should have this bit_cast impl inside vec::convert. | ||
return bit_cast<result_type>(static_cast<typename no_cv_ref::vector_t>(x)); | ||
} else if constexpr (std::is_same_v<no_cv_ref, half>) | ||
return static_cast<half_impl::BIsRepresentationT>(x); | ||
else if constexpr (is_multi_ptr_v<no_cv_ref>) { | ||
return convert_arg(x.get_decorated()); | ||
} else if constexpr (is_scalar_arithmetic_v<no_cv_ref>) { | ||
// E.g. on linux: long long -> int64_t (long), or char -> int8_t (signed | ||
// char) and same for unsigned; Windows has long/long long reversed. | ||
// TODO: Inline this scalar impl. | ||
return static_cast<ConvertToOpenCLType_t<no_cv_ref>>(x); | ||
} else if constexpr (std::is_pointer_v<no_cv_ref>) { | ||
using elem_type = remove_decoration_t<std::remove_pointer_t<no_cv_ref>>; | ||
using converted_elem_type = | ||
decltype(convert_arg(std::declval<elem_type>())); | ||
using result_type = | ||
typename DecoratedType<converted_elem_type, | ||
deduce_AS<no_cv_ref>::value>::type *; | ||
return reinterpret_cast<result_type>(x); | ||
} else if constexpr (is_swizzle_v<no_cv_ref>) { | ||
return convert_arg(simplify_if_swizzle_t<no_cv_ref>{x}); | ||
} else { | ||
// TODO: should it be unreachable? What can it be? | ||
return std::forward<T>(x); | ||
} | ||
} | ||
|
||
template <typename RetTy, typename T> auto convert_result(T &&x) { | ||
if constexpr (is_vec_v<RetTy>) { | ||
return bit_cast<typename RetTy::vector_t>(x); | ||
} else { | ||
return std::forward<T>(x); | ||
} | ||
} | ||
#endif | ||
} // namespace builtins | ||
|
||
template <typename FuncTy, typename... Ts> | ||
auto builtin_marray_impl(FuncTy F, const Ts &...x) { | ||
using ret_elem_type = decltype(F(x[0]...)); | ||
using T = typename first_type<Ts...>::type; | ||
marray<ret_elem_type, T::size()> Res; | ||
constexpr auto N = T::size(); | ||
for (size_t I = 0; I < N / 2; ++I) { | ||
auto PartialRes = F(to_vec2(x, I * 2)...); | ||
std::memcpy(&Res[I * 2], &PartialRes, sizeof(decltype(PartialRes))); | ||
} | ||
if (N % 2) | ||
Res[N - 1] = F(x[N - 1]...); | ||
return Res; | ||
} | ||
|
||
template <typename FuncTy, typename... Ts> | ||
auto builtin_default_host_impl(FuncTy F, const Ts &...x) { | ||
// We implement support for marray/swizzle in the headers and export symbols | ||
// for scalars/vector from the library binary. The reason is that scalar | ||
// implementations mostly depend on <cmath> which pollutes global namespace, | ||
// so we can't unconditionally include it from the SYCL headers. Vector | ||
// overloads have to be implemented in the library next to scalar overloads in | ||
// order to be vectorizable. | ||
if constexpr ((... || is_marray_v<Ts>)) { | ||
return builtin_marray_impl(F, x...); | ||
} else { | ||
return F(simplify_if_swizzle_t<Ts>{x}...); | ||
} | ||
} | ||
|
||
template <typename FuncTy, typename... Ts> | ||
auto builtin_delegate_to_scalar(FuncTy F, const Ts &...x) { | ||
using T = typename first_type<Ts...>::type; | ||
if constexpr (is_vec_or_swizzle_v<T>) { | ||
using ret_elem_type = decltype(F(x[0]...)); | ||
// TODO: using r{} to avoid Werror. Not sure if ok. | ||
vec<ret_elem_type, T::size()> r{}; | ||
loop<T::size()>([&](auto idx) { r[idx] = F(x[idx]...); }); | ||
return r; | ||
} else { | ||
static_assert(is_marray_v<T>); | ||
return builtin_marray_impl(F, x...); | ||
} | ||
} | ||
|
||
template <typename T> | ||
struct any_elem_type | ||
: std::bool_constant<check_type_in_v< | ||
get_elem_type_t<T>, float, double, half, char, signed char, short, | ||
int, long, long long, unsigned char, unsigned short, unsigned int, | ||
unsigned long, unsigned long long>> {}; | ||
template <typename T> | ||
struct fp_elem_type | ||
: std::bool_constant< | ||
check_type_in_v<get_elem_type_t<T>, float, double, half>> {}; | ||
template <typename T> | ||
struct float_elem_type | ||
: std::bool_constant<check_type_in_v<get_elem_type_t<T>, float>> {}; | ||
template <typename T> | ||
struct integer_elem_type | ||
: std::bool_constant< | ||
check_type_in_v<get_elem_type_t<T>, char, signed char, short, int, | ||
long, long long, unsigned char, unsigned short, | ||
unsigned int, unsigned long, unsigned long long>> {}; | ||
template <typename T> | ||
struct suint32_elem_type | ||
: std::bool_constant< | ||
check_type_in_v<get_elem_type_t<T>, int32_t, uint32_t>> {}; | ||
|
||
template <typename... Ts> | ||
struct same_basic_shape : std::bool_constant<builtin_same_shape_v<Ts...>> {}; | ||
|
||
template <typename... Ts> | ||
struct same_elem_type : std::bool_constant<same_basic_shape<Ts...>::value && | ||
all_same_v<get_elem_type_t<Ts>...>> { | ||
}; | ||
|
||
template <typename> struct any_shape : std::true_type {}; | ||
|
||
template <typename T> | ||
struct scalar_only : std::bool_constant<is_scalar_arithmetic_v<T>> {}; | ||
|
||
template <typename T> | ||
struct non_scalar_only : std::bool_constant<!is_scalar_arithmetic_v<T>> {}; | ||
|
||
template <typename T> struct default_ret_type { | ||
using type = T; | ||
}; | ||
|
||
template <typename T> struct scalar_ret_type { | ||
using type = get_elem_type_t<T>; | ||
}; | ||
|
||
template <template <typename> typename RetTypeTrait, | ||
template <typename> typename ElemTypeChecker, | ||
template <typename> typename ShapeChecker, | ||
template <typename...> typename ExtraConditions, typename... Ts> | ||
struct builtin_enable | ||
: std::enable_if< | ||
ElemTypeChecker<typename first_type<Ts...>::type>::value && | ||
ShapeChecker<typename first_type<Ts...>::type>::value && | ||
ExtraConditions<Ts...>::value, | ||
typename RetTypeTrait< | ||
simplify_if_swizzle_t<typename first_type<Ts...>::type>>::type> { | ||
}; | ||
#define BUILTIN_CREATE_ENABLER(NAME, RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, \ | ||
SHAPE_CHECKER, EXTRA_CONDITIONS) \ | ||
namespace detail { \ | ||
template <typename... Ts> \ | ||
using NAME##_t = \ | ||
typename builtin_enable<RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, \ | ||
SHAPE_CHECKER, EXTRA_CONDITIONS, Ts...>::type; \ | ||
} | ||
} // namespace detail | ||
|
||
BUILTIN_CREATE_ENABLER(builtin_enable_generic, default_ret_type, any_elem_type, | ||
any_shape, same_elem_type) | ||
BUILTIN_CREATE_ENABLER(builtin_enable_generic_scalar, default_ret_type, | ||
any_elem_type, scalar_only, same_elem_type) | ||
BUILTIN_CREATE_ENABLER(builtin_enable_generic_non_scalar, default_ret_type, | ||
any_elem_type, non_scalar_only, same_elem_type) | ||
} // namespace _V1 | ||
} // namespace sycl | ||
|
||
// The headers below are specifically implemented without including all the | ||
// necessary headers to allow preprocessing them on their own and providing | ||
// human-friendly result. One can use a command like this to achieve that: | ||
// clang++ -[DU]__SYCL_DEVICE_ONLY__ -x c++ math_functions.inc \ | ||
// -I <..>/llvm/sycl/include -E -o - \ | ||
// | grep -v '^#' | clang-format > math_functions.{host|device}.ii | ||
|
||
#include <sycl/detail/builtins/common_functions.inc> | ||
#include <sycl/detail/builtins/geometric_functions.inc> | ||
#include <sycl/detail/builtins/half_precision_math_functions.inc> | ||
#include <sycl/detail/builtins/integer_functions.inc> | ||
#include <sycl/detail/builtins/math_functions.inc> | ||
#include <sycl/detail/builtins/native_math_functions.inc> | ||
#include <sycl/detail/builtins/relational_functions.inc> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
//==------------------- common_functions.hpp -------------------------------==// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
// Intentionally insufficient set of includes and no "#pragma once". | ||
|
||
#include <sycl/detail/builtins/helper_macros.hpp> | ||
|
||
namespace sycl { | ||
inline namespace _V1 { | ||
BUILTIN_CREATE_ENABLER(builtin_enable_common, default_ret_type, fp_elem_type, | ||
any_shape, same_elem_type) | ||
BUILTIN_CREATE_ENABLER(builtin_enable_common_non_scalar, default_ret_type, | ||
fp_elem_type, non_scalar_only, same_elem_type) | ||
|
||
#ifdef __SYCL_DEVICE_ONLY__ | ||
#define BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) \ | ||
DEVICE_IMPL_TEMPLATE(NUM_ARGS, NAME, builtin_enable_common_t, SPIRV_IMPL) | ||
#else | ||
#define BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) \ | ||
HOST_IMPL_TEMPLATE(NUM_ARGS, NAME, builtin_enable_common_t, common, \ | ||
default_ret_type) | ||
#endif | ||
|
||
BUILTIN_COMMON(ONE_ARG, degrees, __spirv_ocl_degrees) | ||
BUILTIN_COMMON(ONE_ARG, radians, __spirv_ocl_radians) | ||
BUILTIN_COMMON(ONE_ARG, sign, __spirv_ocl_sign) | ||
|
||
BUILTIN_COMMON(THREE_ARGS, mix, __spirv_ocl_mix) | ||
template <typename T0, typename T1> | ||
detail::builtin_enable_common_non_scalar_t<T0, T1> | ||
mix(T0 x, T1 y, detail::get_elem_type_t<T0> z) { | ||
return mix(detail::simplify_if_swizzle_t<T0>{x}, | ||
detail::simplify_if_swizzle_t<T0>{y}, | ||
detail::simplify_if_swizzle_t<T0>{z}); | ||
} | ||
|
||
BUILTIN_COMMON(TWO_ARGS, step, __spirv_ocl_step) | ||
template <typename T> | ||
detail::builtin_enable_common_non_scalar_t<T> step(detail::get_elem_type_t<T> x, | ||
T y) { | ||
return step(detail::simplify_if_swizzle_t<T>{x}, | ||
detail::simplify_if_swizzle_t<T>{y}); | ||
} | ||
|
||
BUILTIN_COMMON(THREE_ARGS, smoothstep, __spirv_ocl_smoothstep) | ||
template <typename T> | ||
detail::builtin_enable_common_non_scalar_t<T> | ||
smoothstep(detail::get_elem_type_t<T> x, detail::get_elem_type_t<T> y, T z) { | ||
return smoothstep(detail::simplify_if_swizzle_t<T>{x}, | ||
detail::simplify_if_swizzle_t<T>{y}, | ||
detail::simplify_if_swizzle_t<T>{z}); | ||
} | ||
|
||
BUILTIN_COMMON(TWO_ARGS, max, __spirv_ocl_fmax_common) | ||
template <typename T> | ||
detail::builtin_enable_common_non_scalar_t<T> | ||
max(T x, detail::get_elem_type_t<T> y) { | ||
return max(detail::simplify_if_swizzle_t<T>{x}, | ||
detail::simplify_if_swizzle_t<T>{y}); | ||
} | ||
|
||
BUILTIN_COMMON(TWO_ARGS, min, __spirv_ocl_fmin_common) | ||
template <typename T> | ||
detail::builtin_enable_common_non_scalar_t<T> | ||
min(T x, detail::get_elem_type_t<T> y) { | ||
return min(detail::simplify_if_swizzle_t<T>{x}, | ||
detail::simplify_if_swizzle_t<T>{y}); | ||
} | ||
|
||
#undef BUILTIN_COMMON | ||
|
||
#ifdef __SYCL_DEVICE_ONLY__ | ||
DEVICE_IMPL_TEMPLATE(THREE_ARGS, clamp, builtin_enable_generic_t, | ||
[](auto... xs) { | ||
using ElemTy = detail::get_elem_type_t<T0>; | ||
if constexpr (std::is_integral_v<ElemTy>) { | ||
if constexpr (std::is_signed_v<ElemTy>) { | ||
return __spirv_ocl_s_clamp(xs...); | ||
} else { | ||
return __spirv_ocl_u_clamp(xs...); | ||
} | ||
} else { | ||
return __spirv_ocl_fclamp(xs...); | ||
} | ||
}) | ||
#else | ||
HOST_IMPL_TEMPLATE(THREE_ARGS, clamp, builtin_enable_generic_t, common, | ||
default_ret_type) | ||
#endif | ||
template <typename T> | ||
detail::builtin_enable_generic_non_scalar_t<T> | ||
clamp(T x, detail::get_elem_type_t<T> y, detail::get_elem_type_t<T> z) { | ||
return clamp(detail::simplify_if_swizzle_t<T>{x}, | ||
detail::simplify_if_swizzle_t<T>{y}, | ||
detail::simplify_if_swizzle_t<T>{z}); | ||
} | ||
} // namespace _V1 | ||
} // namespace sycl |
Oops, something went wrong.