Skip to content

Commit

Permalink
[SYCL] Refactor builtins implementation (#11956)
Browse files Browse the repository at this point in the history
See `builtins_preview.hpp` for the outline of the new design. This PR
changes the implementation under `-fpreview-breaking-changes` and
removes reliance on a python builtins generator script.

Suggested reading/review order: `builtins_preview.hpp`,
`helper_macros.hpp`,
`host_helper_macros.hpp`, then headers implementing user-visible side
with
the library implementation `sycl/source/builtins/*_functions.cpp` last.
  • Loading branch information
aelovikov-intel authored Jan 17, 2024
1 parent 5bb9a44 commit 7e9819d
Show file tree
Hide file tree
Showing 23 changed files with 2,874 additions and 27 deletions.
5 changes: 1 addition & 4 deletions sycl/include/sycl/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES

// Include the generated builtins.
#include <sycl/builtins_marray_gen.hpp>
#include <sycl/builtins_scalar_gen.hpp>
#include <sycl/builtins_vector_gen.hpp>
#include <sycl/builtins_preview.hpp>

#else // __INTEL_PREVIEW_BREAKING_CHANGES

Expand Down
270 changes: 270 additions & 0 deletions sycl/include/sycl/builtins_preview.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
//==------------------- builtins_preview.hpp -------------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Implement SYCL builtin functions. This implementation is mainly driven by the
// requirement of not including <cmath> anywhere in the SYCL headers (i.e. from
// within <sycl/sycl.hpp>), because it pollutes global namespace. Note that we
// can't avoid that using MSVC's STL as the pollution happens even from
// <vector>/<string> and other headers that have to be included per the SYCL
// specification. As such, an alternative approach might be to use math
// intrinsics with GCC/clang-based compilers and use <cmath> when using MSVC as
// a host compiler. That hasn't been tried/investigated.
//
// Current implementation splits builtins into several files following the SYCL
// 2020 (revision 8) split into common/math/geometric/relational/etc. functions.
// For each set, the implementation is split into a user-visible
// include/sycl/detail/builtins/*_functions.inc providing full device-side
// implementation as well as defining user-visible APIs and defining ABI
// implemented under source/builtins/*_functions.cpp for the host side. We
// provide both scalar/vector overloads through symbols in the SYCL runtime
// library due to the <cmath> limitation above (for scalars) and due to
// performance reasons for vector overloads (to be able to benefit from
// vectorization).
//
// Providing declaration for the host side symbols contained in the library
// comes with its own challenges. One is compilation time - blindly providing
// all those declarations takes significant time (about 10% slowdown for
// "clang++ -fsycl" when compiling just "#include <sycl/sycl.hpp>"). Another
// issue is that return type for templates is part of the mangling (and as such
// SFINAE requirements too). To overcome that we structure host side
// implementation roughly like this (in most cases):
//
// math_functions.cpp exports:
// float sycl::__sin_impl(float);
// float1 sycl::__sin_impl(float1);
// float2 sycl::__sin_impl(float2);
// ...
// /* same for other types */
//
// math_functions.inc provides an implementation based on the following idea
// (in ::sycl namespace):
// float sin(float x) {
// extern float __sin_impl(float);
// return __sin_impl(x);
// }
// template <typename T>
// enable_if_valid_type<T> sin(T x) {
// if constexpr (marray_or_swizzle) {
// ...
// call sycl::sin(vector_or_scalar)
// } else {
// extern T __sin_impl(T);
// return __sin_impl(x);
// }
// }
// That way we avoid having the full set of explicit declaration for the symbols
// in the library and instead only pay with compile time when those template
// instantiations actually happen.

#pragma once

#include <sycl/builtins_utils_vec.hpp>

namespace sycl {
inline namespace _V1 {
namespace detail {
// True when all of Ts belong to the same category - all scalar arithmetic
// types, all marrays, or all vecs/swizzles (vecs and swizzles may be mixed) -
// and every type reports the same num_elements as the first type of the pack.
template <typename... Ts>
inline constexpr bool builtin_same_shape_v =
((... && is_scalar_arithmetic_v<Ts>) || (... && is_marray_v<Ts>) ||
(... && is_vec_or_swizzle_v<Ts>)) &&
(... && (num_elements<Ts>::value ==
num_elements<typename first_type<Ts...>::type>::value));

// True when Ts have the same builtin shape (builtin_same_shape_v) and all of
// them become one and the same type once passed through
// simplify_if_swizzle_t.
template <typename... Ts>
inline constexpr bool builtin_same_or_swizzle_v =
// Use builtin_same_shape_v to filter out types unrelated to builtins.
builtin_same_shape_v<Ts...> && all_same_v<simplify_if_swizzle_t<Ts>...>;

namespace builtins {
#ifdef __SYCL_DEVICE_ONLY__
// Device-only helper (compiled under __SYCL_DEVICE_ONLY__): maps an argument
// of a SYCL builtin to the representation selected by the first matching
// branch below. Branches are tested in order, so e.g. half is handled before
// the generic scalar arithmetic branch and vec before swizzle.
// NOTE(review): presumably the result is what gets passed to the __spirv_*
// implementation - confirm against the DEVICE_IMPL_TEMPLATE macros.
template <typename T> auto convert_arg(T &&x) {
using no_cv_ref = std::remove_cv_t<std::remove_reference_t<T>>;
if constexpr (is_vec_v<no_cv_ref>) {
// vec: convert the element type recursively, then reinterpret the vec's
// vector_t either as the converted scalar (single element vec) or as a
// clang ext_vector_type of N converted elements.
using elem_type = get_elem_type_t<no_cv_ref>;
using converted_elem_type =
decltype(convert_arg(std::declval<elem_type>()));

constexpr auto N = no_cv_ref::size();
using result_type = std::conditional_t<N == 1, converted_elem_type,
converted_elem_type
__attribute__((ext_vector_type(N)))>;
// TODO: We should have this bit_cast impl inside vec::convert.
return bit_cast<result_type>(static_cast<typename no_cv_ref::vector_t>(x));
} else if constexpr (std::is_same_v<no_cv_ref, half>)
// half: cast to half_impl::BIsRepresentationT.
return static_cast<half_impl::BIsRepresentationT>(x);
else if constexpr (is_multi_ptr_v<no_cv_ref>) {
// multi_ptr: convert the decorated raw pointer it wraps.
return convert_arg(x.get_decorated());
} else if constexpr (is_scalar_arithmetic_v<no_cv_ref>) {
// E.g. on linux: long long -> int64_t (long), or char -> int8_t (signed
// char) and same for unsigned; Windows has long/long long reversed.
// TODO: Inline this scalar impl.
return static_cast<ConvertToOpenCLType_t<no_cv_ref>>(x);
} else if constexpr (std::is_pointer_v<no_cv_ref>) {
// Raw pointer: strip the decoration from the pointee, convert the pointee
// type recursively and re-apply the address space deduced from the original
// pointer type before reinterpreting the pointer.
using elem_type = remove_decoration_t<std::remove_pointer_t<no_cv_ref>>;
using converted_elem_type =
decltype(convert_arg(std::declval<elem_type>()));
using result_type =
typename DecoratedType<converted_elem_type,
deduce_AS<no_cv_ref>::value>::type *;
return reinterpret_cast<result_type>(x);
} else if constexpr (is_swizzle_v<no_cv_ref>) {
// Swizzle: materialize it as simplify_if_swizzle_t and convert that.
return convert_arg(simplify_if_swizzle_t<no_cv_ref>{x});
} else {
// Anything else is forwarded unchanged.
// TODO: should it be unreachable? What can it be?
return std::forward<T>(x);
}
}

// Device-only helper for builtin return values. If the requested return type
// RetTy is a vec, the value is bit_cast to RetTy::vector_t; any other value is
// forwarded as is.
template <typename RetTy, typename T> auto convert_result(T &&x) {
  if constexpr (!is_vec_v<RetTy>) {
    return std::forward<T>(x);
  } else {
    using native_vec_t = typename RetTy::vector_t;
    return bit_cast<native_vec_t>(x);
  }
}
#endif
} // namespace builtins

// Applies F to marray arguments. Elements are processed in pairs: each pair is
// loaded with to_vec2, passed to F and the two-element result is copied into
// the result marray. If the number of elements is odd, the last element is
// computed through the scalar form of F.
//
// F must be callable both with scalar elements (x[i]...) and with the values
// returned by to_vec2. The size of the result is taken from the first argument
// type. The element type of the result is the type returned by the scalar call.
//
// NOTE(review): the memcpy assumes that the two-element result of F occupies
// exactly two consecutive ret_elem_type slots - confirm for every F used.
template <typename FuncTy, typename... Ts>
auto builtin_marray_impl(FuncTy F, const Ts &...x) {
  using ret_elem_type = decltype(F(x[0]...));
  using T = typename first_type<Ts...>::type;
  constexpr auto N = T::size();
  marray<ret_elem_type, N> Res;
  for (size_t I = 0; I < N / 2; ++I) {
    auto PartialRes = F(to_vec2(x, I * 2)...);
    std::memcpy(&Res[I * 2], &PartialRes, sizeof(PartialRes));
  }
  // N is a compile-time constant, so resolve the odd-tail check at compile
  // time instead of emitting a runtime branch on a constant condition.
  if constexpr (N % 2 != 0)
    Res[N - 1] = F(x[N - 1]...);
  return Res;
}

template <typename FuncTy, typename... Ts>
auto builtin_default_host_impl(FuncTy F, const Ts &...x) {
// We implement support for marray/swizzle in the headers and export symbols
// for scalars/vector from the library binary. The reason is that scalar
// implementations mostly depend on <cmath> which pollutes global namespace,
// so we can't unconditionally include it from the SYCL headers. Vector
// overloads have to be implemented in the library next to scalar overloads in
// order to be vectorizable.
if constexpr ((... || is_marray_v<Ts>)) {
return builtin_marray_impl(F, x...);
} else {
return F(simplify_if_swizzle_t<Ts>{x}...);
}
}

template <typename FuncTy, typename... Ts>
auto builtin_delegate_to_scalar(FuncTy F, const Ts &...x) {
using T = typename first_type<Ts...>::type;
if constexpr (is_vec_or_swizzle_v<T>) {
using ret_elem_type = decltype(F(x[0]...));
// TODO: using r{} to avoid Werror. Not sure if ok.
vec<ret_elem_type, T::size()> r{};
loop<T::size()>([&](auto idx) { r[idx] = F(x[idx]...); });
return r;
} else {
static_assert(is_marray_v<T>);
return builtin_marray_impl(F, x...);
}
}

// Element type checker: true when the element type of T (get_elem_type_t) is
// one of the listed floating point (float, double, half) or integer types.
template <typename T>
struct any_elem_type
: std::bool_constant<check_type_in_v<
get_elem_type_t<T>, float, double, half, char, signed char, short,
int, long, long long, unsigned char, unsigned short, unsigned int,
unsigned long, unsigned long long>> {};
// Element type checker: true when the element type of T is float, double or
// half.
template <typename T>
struct fp_elem_type
: std::bool_constant<
check_type_in_v<get_elem_type_t<T>, float, double, half>> {};
// Element type checker: true only when the element type of T is float.
template <typename T>
struct float_elem_type
: std::bool_constant<check_type_in_v<get_elem_type_t<T>, float>> {};
// Element type checker: true when the element type of T is one of the listed
// signed or unsigned integer types.
template <typename T>
struct integer_elem_type
: std::bool_constant<
check_type_in_v<get_elem_type_t<T>, char, signed char, short, int,
long, long long, unsigned char, unsigned short,
unsigned int, unsigned long, unsigned long long>> {};
// Element type checker: true when the element type of T is int32_t or
// uint32_t.
template <typename T>
struct suint32_elem_type
: std::bool_constant<
check_type_in_v<get_elem_type_t<T>, int32_t, uint32_t>> {};

// Type-trait form of builtin_same_shape_v, so that it can be passed as a
// template template argument (e.g. as ExtraConditions of builtin_enable).
template <typename... Ts>
struct same_basic_shape : std::bool_constant<builtin_same_shape_v<Ts...>> {};

// True when all of Ts have the same basic shape (same_basic_shape) and the
// same element type (get_elem_type_t).
template <typename... Ts>
struct same_elem_type : std::bool_constant<same_basic_shape<Ts...>::value &&
all_same_v<get_elem_type_t<Ts>...>> {
};

// Shape checker that accepts every type, i.e. imposes no shape restriction.
template <typename> struct any_shape : std::true_type {};

// Shape checker that accepts only scalar arithmetic types.
template <typename T>
struct scalar_only : std::bool_constant<is_scalar_arithmetic_v<T>> {};

// Shape checker that accepts everything except scalar arithmetic types.
template <typename T>
struct non_scalar_only : std::bool_constant<!is_scalar_arithmetic_v<T>> {};

// Return type trait: the builtin returns the very type it is applied to.
template <typename T> struct default_ret_type {
using type = T;
};

// Return type trait: the builtin returns the element type of the type it is
// applied to (get_elem_type_t).
template <typename T> struct scalar_ret_type {
using type = get_elem_type_t<T>;
};

// SFINAE helper behind the *_t enabler aliases created with
// BUILTIN_CREATE_ENABLER. The first type of Ts has to satisfy both
// ElemTypeChecker and ShapeChecker, and the whole pack has to satisfy
// ExtraConditions. When all three hold, the nested `type` is RetTypeTrait
// applied to the first type after simplify_if_swizzle_t; otherwise there is no
// nested `type` (std::enable_if).
template <template <typename> typename RetTypeTrait,
template <typename> typename ElemTypeChecker,
template <typename> typename ShapeChecker,
template <typename...> typename ExtraConditions, typename... Ts>
struct builtin_enable
: std::enable_if<
ElemTypeChecker<typename first_type<Ts...>::type>::value &&
ShapeChecker<typename first_type<Ts...>::type>::value &&
ExtraConditions<Ts...>::value,
typename RetTypeTrait<
simplify_if_swizzle_t<typename first_type<Ts...>::type>>::type> {
};
// Defines the alias template detail::NAME_t<Ts...> as the nested `type` of
// builtin_enable instantiated with the given return type trait, element type
// checker, shape checker and extra conditions. The expansion opens and closes
// namespace detail itself, so the trait names are looked up from inside
// detail.
#define BUILTIN_CREATE_ENABLER(NAME, RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, \
SHAPE_CHECKER, EXTRA_CONDITIONS) \
namespace detail { \
template <typename... Ts> \
using NAME##_t = \
typename builtin_enable<RET_TYPE_TRAIT, ELEM_TYPE_CHECKER, \
SHAPE_CHECKER, EXTRA_CONDITIONS, Ts...>::type; \
}
} // namespace detail

// Enablers for builtins accepting any supported element type (any_elem_type).
// All of them require the arguments to have the same shape and element type
// (same_elem_type) and yield the first argument type as the return type.
// detail::builtin_enable_generic_t: no restriction on the shape.
BUILTIN_CREATE_ENABLER(builtin_enable_generic, default_ret_type, any_elem_type,
any_shape, same_elem_type)
// detail::builtin_enable_generic_scalar_t: scalar arguments only.
BUILTIN_CREATE_ENABLER(builtin_enable_generic_scalar, default_ret_type,
any_elem_type, scalar_only, same_elem_type)
// detail::builtin_enable_generic_non_scalar_t: non-scalar arguments only.
BUILTIN_CREATE_ENABLER(builtin_enable_generic_non_scalar, default_ret_type,
any_elem_type, non_scalar_only, same_elem_type)
} // namespace _V1
} // namespace sycl

// The headers below are specifically implemented without including all the
// necessary headers to allow preprocessing them on their own and providing
// human-friendly result. One can use a command like this to achieve that:
// clang++ -[DU]__SYCL_DEVICE_ONLY__ -x c++ math_functions.inc \
// -I <..>/llvm/sycl/include -E -o - \
// | grep -v '^#' | clang-format > math_functions.{host|device}.ii

#include <sycl/detail/builtins/common_functions.inc>
#include <sycl/detail/builtins/geometric_functions.inc>
#include <sycl/detail/builtins/half_precision_math_functions.inc>
#include <sycl/detail/builtins/integer_functions.inc>
#include <sycl/detail/builtins/math_functions.inc>
#include <sycl/detail/builtins/native_math_functions.inc>
#include <sycl/detail/builtins/relational_functions.inc>
103 changes: 103 additions & 0 deletions sycl/include/sycl/detail/builtins/common_functions.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
//==------------------- common_functions.inc -------------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Intentionally insufficient set of includes and no "#pragma once".

#include <sycl/detail/builtins/helper_macros.hpp>

namespace sycl {
inline namespace _V1 {
// Enablers for the common functions: the element type has to be float, double
// or half (fp_elem_type) and all arguments must have the same shape and
// element type (same_elem_type). The return type is the first argument type.
// detail::builtin_enable_common_t: no restriction on the shape.
BUILTIN_CREATE_ENABLER(builtin_enable_common, default_ret_type, fp_elem_type,
any_shape, same_elem_type)
// detail::builtin_enable_common_non_scalar_t: non-scalar arguments only.
BUILTIN_CREATE_ENABLER(builtin_enable_common_non_scalar, default_ret_type,
fp_elem_type, non_scalar_only, same_elem_type)

// BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) declares the common function NAME
// constrained by builtin_enable_common_t.
// Device compilation: expands to DEVICE_IMPL_TEMPLATE using SPIRV_IMPL.
#ifdef __SYCL_DEVICE_ONLY__
#define BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) \
DEVICE_IMPL_TEMPLATE(NUM_ARGS, NAME, builtin_enable_common_t, SPIRV_IMPL)
#else
// Host compilation: expands to HOST_IMPL_TEMPLATE with the `common` tag and
// default_ret_type; SPIRV_IMPL is not used.
#define BUILTIN_COMMON(NUM_ARGS, NAME, SPIRV_IMPL) \
HOST_IMPL_TEMPLATE(NUM_ARGS, NAME, builtin_enable_common_t, common, \
default_ret_type)
#endif

// One-argument common functions; there are no extra overloads for them here.
BUILTIN_COMMON(ONE_ARG, degrees, __spirv_ocl_degrees)
BUILTIN_COMMON(ONE_ARG, radians, __spirv_ocl_radians)
BUILTIN_COMMON(ONE_ARG, sign, __spirv_ocl_sign)

// mix generated by BUILTIN_COMMON (constrained by builtin_enable_common_t).
BUILTIN_COMMON(THREE_ARGS, mix, __spirv_ocl_mix)
// mix overload taking the third argument as a scalar of the element type.
// Every argument is converted to the first argument type with swizzles
// simplified, and the call is then forwarded to mix again.
template <typename T0, typename T1>
detail::builtin_enable_common_non_scalar_t<T0, T1>
mix(T0 x, T1 y, detail::get_elem_type_t<T0> z) {
  using arg_t = detail::simplify_if_swizzle_t<T0>;
  return mix(arg_t{x}, arg_t{y}, arg_t{z});
}

// step generated by BUILTIN_COMMON (constrained by builtin_enable_common_t).
BUILTIN_COMMON(TWO_ARGS, step, __spirv_ocl_step)
// step overload taking the first argument as a scalar of the element type.
// Both arguments are converted to T with swizzles simplified, and the call is
// then forwarded to step again.
template <typename T>
detail::builtin_enable_common_non_scalar_t<T> step(detail::get_elem_type_t<T> x,
                                                   T y) {
  using arg_t = detail::simplify_if_swizzle_t<T>;
  return step(arg_t{x}, arg_t{y});
}

// smoothstep generated by BUILTIN_COMMON (constrained by
// builtin_enable_common_t).
BUILTIN_COMMON(THREE_ARGS, smoothstep, __spirv_ocl_smoothstep)
// smoothstep overload taking the first two arguments as scalars of the element
// type. All arguments are converted to T with swizzles simplified, and the
// call is then forwarded to smoothstep again.
template <typename T>
detail::builtin_enable_common_non_scalar_t<T>
smoothstep(detail::get_elem_type_t<T> x, detail::get_elem_type_t<T> y, T z) {
  using arg_t = detail::simplify_if_swizzle_t<T>;
  return smoothstep(arg_t{x}, arg_t{y}, arg_t{z});
}

// max generated by BUILTIN_COMMON (constrained by builtin_enable_common_t).
BUILTIN_COMMON(TWO_ARGS, max, __spirv_ocl_fmax_common)
// max overload taking the second argument as a scalar of the element type.
// Both arguments are converted to T with swizzles simplified, and the call is
// then forwarded to max again.
template <typename T>
detail::builtin_enable_common_non_scalar_t<T>
max(T x, detail::get_elem_type_t<T> y) {
  using arg_t = detail::simplify_if_swizzle_t<T>;
  return max(arg_t{x}, arg_t{y});
}

// min generated by BUILTIN_COMMON (constrained by builtin_enable_common_t).
BUILTIN_COMMON(TWO_ARGS, min, __spirv_ocl_fmin_common)
// min overload taking the second argument as a scalar of the element type.
// Both arguments are converted to T with swizzles simplified, and the call is
// then forwarded to min again.
template <typename T>
detail::builtin_enable_common_non_scalar_t<T>
min(T x, detail::get_elem_type_t<T> y) {
  using arg_t = detail::simplify_if_swizzle_t<T>;
  return min(arg_t{x}, arg_t{y});
}

#undef BUILTIN_COMMON

// clamp is enabled for any supported element type (builtin_enable_generic_t).
// Device compilation: the implementation is a generic lambda picking the
// SPIR-V function by the element type of T0 - signed integer, unsigned integer
// or, for non-integral element types, the floating point variant.
// NOTE(review): T0 is not declared here; presumably it is a template parameter
// introduced by the DEVICE_IMPL_TEMPLATE expansion - confirm in
// helper_macros.hpp.
#ifdef __SYCL_DEVICE_ONLY__
DEVICE_IMPL_TEMPLATE(THREE_ARGS, clamp, builtin_enable_generic_t,
[](auto... xs) {
using ElemTy = detail::get_elem_type_t<T0>;
if constexpr (std::is_integral_v<ElemTy>) {
if constexpr (std::is_signed_v<ElemTy>) {
return __spirv_ocl_s_clamp(xs...);
} else {
return __spirv_ocl_u_clamp(xs...);
}
} else {
return __spirv_ocl_fclamp(xs...);
}
})
#else
// Host compilation: expands to HOST_IMPL_TEMPLATE with the `common` tag and
// default_ret_type.
HOST_IMPL_TEMPLATE(THREE_ARGS, clamp, builtin_enable_generic_t, common,
default_ret_type)
#endif
// clamp overload taking both bounds as scalars of the element type. All
// arguments are converted to T with swizzles simplified, and the call is then
// forwarded to clamp again.
template <typename T>
detail::builtin_enable_generic_non_scalar_t<T>
clamp(T x, detail::get_elem_type_t<T> y, detail::get_elem_type_t<T> z) {
  using arg_t = detail::simplify_if_swizzle_t<T>;
  return clamp(arg_t{x}, arg_t{y}, arg_t{z});
}
} // namespace _V1
} // namespace sycl
Loading

0 comments on commit 7e9819d

Please sign in to comment.