Skip to content

Commit

Permalink
[SYCL] Add operator overloading for aggregate types in annotated_ref (#…
Browse files Browse the repository at this point in the history
…11971)

Added several operator overloading for aggregate types for annotated_ref
class.

This https://godbolt.org/z/h5cTTr17K would be a good example to show why
this is not working without this fix.

This PR includes several changes:
1. Propogate operators including all binaries and unary operators,
including arithmetic, comparator, and logical.
2. Using perfecting forwarding for binaries operators, compound
operators, and unary operators.

This covers cases in which the sequence of conversions will be correct
when implicit conversion is involved. i.e
```
annotated_ref<int> a;
double b;
auto p = a + b; // expected to be (double)a + b, and p should be double
```
without this fix, 
```
T operator+(T a) const;
annotated_ref<int> a;
double b;
auto p = a + b; // this become a + (int)b, and p will be int
```

---------

Co-authored-by: Roland Schulz <roland.schulz@intel.com>
  • Loading branch information
Brox Chen and rolandschulz authored Dec 7, 2023
1 parent 19c17e2 commit 4ab007d
Show file tree
Hide file tree
Showing 4 changed files with 560 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,32 @@ class annotated_ref {
public:
annotated_ref(const annotated_ref&) = delete;
operator T() const;
T operator=(T) const;

template <typename O> //available only if O is not an annotated_ref type
T operator=(O&&) const;
T operator=(const annotated_ref&) const;

// OP is: +=, -=, *=, /=, %=, <<=, >>=, &=, |=, ^=
T operatorOP(T) const;
template <typename O> //available only if O is not an annotated_ref type
T operatorOP(O&& a) const;
T operatorOP(const annotated_ref &b) const;

T operator++() const;
T operator++(int) const;
T operator--() const;
T operator--(int) const;

// OP is: +, -, *, /, %, <<, >>, &, |, ^, <, <=, >, >=, ==, ~=, &&, ||
template <typename O>
auto friend operatorOP(O&& a, const annotated_ref& b) ->
decltype(std::forward<O>(a) OP std::declval<T>());
template <typename O> //available only if O is not an annotated_ref type
auto friend operatorOP(const annotated_ref& a, O&& b) ->
decltype(std::declval<T>() OP std::forward<O>(b));

// OP is: +, -, !, ~
template <typename O=T>
auto operatorOP() -> decltype(OP std::declval<O>());
};
} // namespace sycl::ext::oneapi::experimental
```
Expand All @@ -581,10 +599,13 @@ annotations when the object is loaded from memory.
a|
[source,c++]
----
T operator=(T) const;
template <typename O>
T operator=(O&&) const;
----
|
Writes an object of type `T` to the location referenced by this wrapper,
Writes an object of type `O` to the location referenced by this wrapper.
`O` cannot be a type of `annotated_ref`.

applying the annotations when the object is stored to memory.

// --- ROW BREAK ---
Expand All @@ -608,18 +629,47 @@ Does not rebind the reference!
a|
[source,c++]
----
T operatorOP(T) const;
template <typename O>
T operatorOP(O&& a) const;
----
a|
Where [code]#OP# is: [code]#pass:[+=]#, [code]#-=#,[code]#*=#, [code]#/=#, [code]#%=#, [code]#+<<=+#, [code]#>>=#, [code]#&=#, [code]#\|=#, [code]#^=#.

Compound assignment operators. Return result by value.
Compound assignment operators for type `O`. `O` cannot be a type of `annotated_ref`.

Return result by value.

Available only if the corresponding assignment operator OP is available for `T` taking a type of `O`.
Equivalent to:
```c++
T tmp = *this; // Reads from memory
// with annotations
tmp OP std::forward<O>(a);
*this = tmp; // Writes to memory
// with annotations
return tmp;
```
// --- ROW BREAK ---
a|
[source,c++]
----
T operatorOP(const annotated_ref &b) const;
----
a|
Where [code]#OP# is: [code]#pass:[+=]#, [code]#-=#,[code]#*=#, [code]#/=#, [code]#%=#, [code]#+<<=+#, [code]#>>=#, [code]#&=#, [code]#\|=#, [code]#^=#.

Compound assignment operators for type `annotated_ref`.

Return result by value.

Available only if the corresponding assignment operator OP is available for `T`.
Equivalent to:
```c++
T tmp = *this; // Reads from memory
// with annotations
tmp OP val;
T tmp2 = b; // Reads from memory
// with annotations
tmp OP b;
*this = tmp; // Writes to memory
// with annotations
return tmp;
Expand All @@ -638,6 +688,64 @@ Increment and decrement operator of annotated_ref. Increment/Decrement the objec
referenced by this wrapper via ``T``'s Increment/Decrement operator.

The annotations are applied when the object `T` is loaded and stored to the memory.

a|
[source,c++]
----
template <typename O>
auto friend operatorOP(O&& a, const annotated_ref& b) ->
decltype(std::forward<O>(a) OP std::declval<T>());
----
a|
Where [code]#OP# is: [code]#pass:[+]#, [code]#-#,[code]#*#, [code]#/#, [code]#%#, [code]#+<<+#, [code]#>>#, [code]#&#, [code]#\|#, [code]#\^#, [code]#<#, [code]#<=#, [code]#>#, [code]#>=#, [code]#==#, [code]#!=#, [code]#&&#, [code]#\|\|#.

Defines a hidden friend operator `OP` overload for type `O` and `annotated_ref`.

Let `operatorOP` denotes the operator used. The overloaded operator `operatorOP` utilizes
`operatorOP(O&&, T&&)` and is available only if `operatorOP(O&&, T&&)` is well formed. The value and result
is the same as the result of `operatorOP(O&&, T&&)` applied to the objects of
type `O` and `T`.

The annotations from `PropertyListT` are applied when the object `b` is loaded from memory.

a|
[source,c++]
----
template <typename O>
auto friend operatorOP(const annotated_ref& a, O&& b) ->
decltype(std::declval<T>() OP std::forward<O>(b));
----
a|
Where [code]#OP# is: [code]#pass:[+]#, [code]#-#,[code]#*#, [code]#/#, [code]#%#, [code]#+<<+#, [code]#>>#, [code]#&#, [code]#\|#, [code]#\^#, [code]#<#, [code]#<=#, [code]#>#, [code]#>=#, [code]#==#, [code]#!=#, [code]#&&#, [code]#\|\|#.

Defines a hidden friend operator `OP` overload for type `annotated_ref` and `O`. `O` cannot be
a type of `annotated_ref`.

Let `operatorOP` denotes the operator used. The overloaded operator `operatorOP` utilizes
`operatorOP(T&&, O&&)` and is available only if `operatorOP(T&&, O&&)` is well formed. The value and result
is the same as the result of `operatorOP(T&&, O&&)` applied to the objects of
type `T` and `O`.

The annotations from `PropertyListT` are applied when the object `a` is loaded from memory.

a|
[source,c++]
----
template <typename O=T>
auto operatorOP() -> decltype(OP std::declval<O>());
----
a|
Where [code]#OP# is: [code]#pass:[+]#, [code]#-#, [code]#!#, [code]#~#.

Defines a operator `OP` overload for types `O` where the default type is `T`.

Let `operatorOP` denotes the operator used. The overloaded operator
`operatorOP` utilizes `operatorOP(O)` and is available only if `operatorOP(O)`
is well formed. The value and result is the same as the result of `operatorOP(O)`
applied to the objects of type `O`.

The annotations from `PropertyListT` are applied when the object `a` is loaded from memory.

|===

== Issues related to `annotated_ptr`
Expand Down Expand Up @@ -685,6 +793,7 @@ the alignment is set up.
[options="header"]
|========================================
|Rev|Date|Author|Changes
|5|2023-11-30|Brox Chen|API fixes: operators fowarding for annnotated_ref
|4|2023-06-28|Roland Schulz|API fixes: constructors and annotated_ref assignment
|3|2022-04-05|Abhishek Tiwari|*Addressed review comments*
|2|2022-03-07|Abhishek Tiwari|*Corrected API and updated description*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//
//==----------- annotated_ptr.hpp - SYCL annotated_ptr extension -----------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Expand Down Expand Up @@ -31,14 +32,6 @@ namespace oneapi {
namespace experimental {

namespace {
#define PROPAGATE_OP(op) \
T operator op##=(T rhs) const { \
T t = *this; \
t op## = rhs; \
*this = t; \
return t; \
}

// compare strings on compile time
constexpr bool compareStrs(const char *Str1, const char *Str2) {
return std::string_view(Str1) == Str2;
Expand Down Expand Up @@ -66,6 +59,7 @@ struct PropertiesFilter {
std::tuple<>>::type...>;
};
} // namespace

template <typename T, typename PropertyListT = empty_properties_t>
class annotated_ref {
// This should always fail when instantiating the unspecialized version.
Expand All @@ -74,6 +68,17 @@ class annotated_ref {
static_assert(is_valid_property_list, "Property list is invalid.");
};

namespace detail {
template <class T> struct is_ann_ref_impl : std::false_type {};
template <class T, class P>
struct is_ann_ref_impl<annotated_ref<T, P>> : std::true_type {};
template <class T, class P>
struct is_ann_ref_impl<const annotated_ref<T, P>> : std::true_type {};
template <class T>
constexpr bool is_ann_ref_v =
is_ann_ref_impl<std::remove_reference_t<T>>::value;
} // namespace detail

template <typename T, typename... Props>
class annotated_ref<T, detail::properties_t<Props...>> {
using property_list_t = detail::properties_t<Props...>;
Expand All @@ -84,11 +89,12 @@ class annotated_ref<T, detail::properties_t<Props...>> {

private:
T *m_Ptr;
annotated_ref(T *Ptr) : m_Ptr(Ptr) {}
explicit annotated_ref(T *Ptr) : m_Ptr(Ptr) {}

public:
annotated_ref(const annotated_ref &) = delete;

// implicit conversion with annotaion
operator T() const {
#ifdef __SYCL_DEVICE_ONLY__
return *__builtin_intel_sycl_ptr_annotation(
Expand All @@ -99,30 +105,100 @@ class annotated_ref<T, detail::properties_t<Props...>> {
#endif
}

T operator=(T Obj) const {
// assignment operator with annotaion
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>>
T operator=(O &&Obj) const {
#ifdef __SYCL_DEVICE_ONLY__
*__builtin_intel_sycl_ptr_annotation(
m_Ptr, detail::PropertyMetaInfo<Props>::name...,
detail::PropertyMetaInfo<Props>::value...) = Obj;
return *__builtin_intel_sycl_ptr_annotation(
m_Ptr, detail::PropertyMetaInfo<Props>::name...,
detail::PropertyMetaInfo<Props>::value...) =
std::forward<O>(Obj);
#else
*m_Ptr = Obj;
return *m_Ptr = std::forward<O>(Obj);
#endif
return Obj;
}

T operator=(const annotated_ref &Ref) const { return *this = T(Ref); }
template <class O, class P>
T operator=(const annotated_ref<O, P> &Ref) const {
O t2 = Ref;
return *this = t2;
}

// propagate compound operators
#define PROPAGATE_OP(op) \
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>> \
T operator op(O &&rhs) const { \
T t = *this; \
t op std::forward<O>(rhs); \
*this = t; \
return t; \
} \
template <class O, class P> \
T operator op(const annotated_ref<O, P> &rhs) const { \
T t = *this; \
O t2 = rhs; \
t op t2; \
*this = t; \
return t; \
}
PROPAGATE_OP(+=)
PROPAGATE_OP(-=)
PROPAGATE_OP(*=)
PROPAGATE_OP(/=)
PROPAGATE_OP(%=)
PROPAGATE_OP(^=)
PROPAGATE_OP(&=)
PROPAGATE_OP(|=)
PROPAGATE_OP(<<=)
PROPAGATE_OP(>>=)
#undef PROPAGATE_OP

// propagate binary operators
#define PROPAGATE_OP(op) \
template <class O> \
friend auto operator op(O &&a, const annotated_ref &b) \
->decltype(std::forward<O>(a) op std::declval<T>()) { \
return std::forward<O>(a) op T(b); \
} \
template <class O, typename = std::enable_if_t<!detail::is_ann_ref_v<O>>> \
friend auto operator op(const annotated_ref &a, O &&b) \
->decltype(std::declval<T>() op std::forward<O>(b)) { \
return T(a) op std::forward<O>(b); \
}
PROPAGATE_OP(+)
PROPAGATE_OP(-)
PROPAGATE_OP(*)
PROPAGATE_OP(/)
PROPAGATE_OP(%)
PROPAGATE_OP(^)
PROPAGATE_OP(&)
PROPAGATE_OP(|)
PROPAGATE_OP(&)
PROPAGATE_OP(^)
PROPAGATE_OP(<<)
PROPAGATE_OP(>>)
PROPAGATE_OP(<)
PROPAGATE_OP(<=)
PROPAGATE_OP(>)
PROPAGATE_OP(>=)
PROPAGATE_OP(==)
PROPAGATE_OP(!=)
PROPAGATE_OP(&&)
PROPAGATE_OP(||)
#undef PROPAGATE_OP

// Propagate unary operators
// by setting a default template we get SFINAE to kick in
#define PROPAGATE_OP(op) \
template <typename O = T> \
auto operator op() const->decltype(op std::declval<O>()) { \
return op O(*this); \
}
PROPAGATE_OP(+)
PROPAGATE_OP(-)
PROPAGATE_OP(!)
PROPAGATE_OP(~)
#undef PROPAGATE_OP

// Propagate inc/dec operators
T operator++() const {
T t = *this;
++t;
Expand Down Expand Up @@ -156,8 +232,6 @@ class annotated_ref<T, detail::properties_t<Props...>> {
template <class T2, class P2> friend class annotated_ptr;
};

#undef PROPAGATE_OP

#ifdef __cpp_deduction_guides
template <typename T, typename... Args>
annotated_ptr(T *, Args...)
Expand Down
Loading

0 comments on commit 4ab007d

Please sign in to comment.