Skip to content

Commit 30a2dae

Browse files
authored
Revert "[SYCL][ESIMD]Limit bfloat16 operators to scalars to enable operations with simd vectors (#12089)" (#12462)
This reverts commit 8c92df9 to address test failures
1 parent dc37ee4 commit 30a2dae

File tree

5 files changed

+61
-282
lines changed

5 files changed

+61
-282
lines changed

sycl/include/sycl/ext/intel/esimd/detail/bfloat16_type_traits.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ inline std::ostream &operator<<(std::ostream &O, bfloat16 const &rhs) {
9494
return O;
9595
}
9696

97-
template <> struct is_esimd_arithmetic_type<bfloat16, void> : std::true_type {};
98-
9997
} // namespace ext::intel::esimd::detail
10098
} // namespace _V1
10199
} // namespace sycl

sycl/include/sycl/ext/oneapi/bfloat16.hpp

Lines changed: 58 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -132,175 +132,69 @@ class bfloat16 {
132132
#endif
133133
}
134134

135-
bfloat16 &operator+=(const bfloat16 &rhs) {
136-
value = from_float(to_float(value) + to_float(rhs.value));
137-
return *this;
138-
}
139-
140-
bfloat16 &operator-=(const bfloat16 &rhs) {
141-
value = from_float(to_float(value) - to_float(rhs.value));
142-
return *this;
143-
}
144-
145-
bfloat16 &operator*=(const bfloat16 &rhs) {
146-
value = from_float(to_float(value) * to_float(rhs.value));
147-
return *this;
148-
}
149-
150-
bfloat16 &operator/=(const bfloat16 &rhs) {
151-
value = from_float(to_float(value) / to_float(rhs.value));
152-
return *this;
153-
}
154-
155-
// Operator ++, --
156-
bfloat16 &operator++() {
157-
float f = to_float(value);
158-
value = from_float(++f);
159-
return *this;
160-
}
161-
162-
bfloat16 operator++(int) {
163-
bfloat16 ret(*this);
164-
operator++();
165-
return ret;
166-
}
167-
168-
bfloat16 &operator--() {
169-
float f = to_float(value);
170-
value = from_float(--f);
171-
return *this;
172-
}
173-
174-
bfloat16 operator--(int) {
175-
bfloat16 ret(*this);
176-
operator--();
177-
return ret;
178-
}
179-
180-
// Operator +, -, *, /
135+
// Increment and decrement operators overloading
181136
#define OP(op) \
182-
friend bfloat16 operator op(const bfloat16 lhs, const bfloat16 rhs) { \
183-
return to_float(lhs.value) op to_float(rhs.value); \
184-
} \
185-
friend double operator op(const bfloat16 lhs, const double rhs) { \
186-
return to_float(lhs.value) op rhs; \
187-
} \
188-
friend double operator op(const double lhs, const bfloat16 rhs) { \
189-
return lhs op to_float(rhs.value); \
190-
} \
191-
friend float operator op(const bfloat16 lhs, const float rhs) { \
192-
return to_float(lhs.value) op rhs; \
193-
} \
194-
friend float operator op(const float lhs, const bfloat16 rhs) { \
195-
return lhs op to_float(rhs.value); \
196-
} \
197-
friend bfloat16 operator op(const bfloat16 lhs, const int rhs) { \
198-
return to_float(lhs.value) op rhs; \
199-
} \
200-
friend bfloat16 operator op(const int lhs, const bfloat16 rhs) { \
201-
return lhs op to_float(rhs.value); \
202-
} \
203-
friend bfloat16 operator op(const bfloat16 lhs, const long rhs) { \
204-
return to_float(lhs.value) op rhs; \
205-
} \
206-
friend bfloat16 operator op(const long lhs, const bfloat16 rhs) { \
207-
return lhs op to_float(rhs.value); \
208-
} \
209-
friend bfloat16 operator op(const bfloat16 lhs, const long long rhs) { \
210-
return to_float(lhs.value) op rhs; \
211-
} \
212-
friend bfloat16 operator op(const long long lhs, const bfloat16 rhs) { \
213-
return lhs op to_float(rhs.value); \
214-
} \
215-
friend bfloat16 operator op(const bfloat16 &lhs, const unsigned int &rhs) { \
216-
return to_float(lhs.value) op rhs; \
217-
} \
218-
friend bfloat16 operator op(const unsigned int &lhs, const bfloat16 &rhs) { \
219-
return lhs op to_float(rhs.value); \
220-
} \
221-
friend bfloat16 operator op(const bfloat16 &lhs, const unsigned long &rhs) { \
222-
return to_float(lhs.value) op rhs; \
223-
} \
224-
friend bfloat16 operator op(const unsigned long &lhs, const bfloat16 &rhs) { \
225-
return lhs op to_float(rhs.value); \
226-
} \
227-
friend bfloat16 operator op(const bfloat16 &lhs, \
228-
const unsigned long long &rhs) { \
229-
return to_float(lhs.value) op rhs; \
230-
} \
231-
friend bfloat16 operator op(const unsigned long long &lhs, \
232-
const bfloat16 &rhs) { \
233-
return lhs op to_float(rhs.value); \
234-
}
235-
OP(+)
236-
OP(-)
237-
OP(*)
238-
OP(/)
239-
137+
friend bfloat16 &operator op(bfloat16 & lhs) { \
138+
float f = to_float(lhs.value); \
139+
lhs.value = from_float(op f); \
140+
return lhs; \
141+
} \
142+
friend bfloat16 operator op(bfloat16 &lhs, int) { \
143+
bfloat16 old = lhs; \
144+
operator op(lhs); \
145+
return old; \
146+
}
147+
OP(++)
148+
OP(--)
240149
#undef OP
241150

242-
// Operator ==, !=, <, >, <=, >=
151+
// Assignment operators overloading
243152
#define OP(op) \
244-
friend bool operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
245-
return to_float(lhs.value) op to_float(rhs.value); \
246-
} \
247-
friend bool operator op(const bfloat16 &lhs, const double &rhs) { \
248-
return to_float(lhs.value) op rhs; \
249-
} \
250-
friend bool operator op(const double &lhs, const bfloat16 &rhs) { \
251-
return lhs op to_float(rhs.value); \
252-
} \
253-
friend bool operator op(const bfloat16 &lhs, const float &rhs) { \
254-
return to_float(lhs.value) op rhs; \
255-
} \
256-
friend bool operator op(const float &lhs, const bfloat16 &rhs) { \
257-
return lhs op to_float(rhs.value); \
258-
} \
259-
friend bool operator op(const bfloat16 &lhs, const int &rhs) { \
260-
return to_float(lhs.value) op rhs; \
261-
} \
262-
friend bool operator op(const int &lhs, const bfloat16 &rhs) { \
263-
return lhs op to_float(rhs.value); \
264-
} \
265-
friend bool operator op(const bfloat16 &lhs, const long &rhs) { \
266-
return to_float(lhs.value) op rhs; \
267-
} \
268-
friend bool operator op(const long &lhs, const bfloat16 &rhs) { \
269-
return lhs op to_float(rhs.value); \
270-
} \
271-
friend bool operator op(const bfloat16 &lhs, const long long &rhs) { \
272-
return to_float(lhs.value) op rhs; \
273-
} \
274-
friend bool operator op(const long long &lhs, const bfloat16 &rhs) { \
275-
return lhs op to_float(rhs.value); \
276-
} \
277-
friend bool operator op(const bfloat16 &lhs, const unsigned int &rhs) { \
278-
return to_float(lhs.value) op rhs; \
279-
} \
280-
friend bool operator op(const unsigned int &lhs, const bfloat16 &rhs) { \
281-
return lhs op to_float(rhs.value); \
282-
} \
283-
friend bool operator op(const bfloat16 &lhs, const unsigned long &rhs) { \
284-
return to_float(lhs.value) op rhs; \
285-
} \
286-
friend bool operator op(const unsigned long &lhs, const bfloat16 &rhs) { \
287-
return lhs op to_float(rhs.value); \
288-
} \
289-
friend bool operator op(const bfloat16 &lhs, \
290-
const unsigned long long &rhs) { \
291-
return to_float(lhs.value) op rhs; \
292-
} \
293-
friend bool operator op(const unsigned long long &lhs, \
294-
const bfloat16 &rhs) { \
295-
return lhs op to_float(rhs.value); \
296-
}
297-
OP(==)
298-
OP(!=)
299-
OP(<)
300-
OP(>)
301-
OP(<=)
302-
OP(>=)
153+
friend bfloat16 &operator op(bfloat16 & lhs, const bfloat16 & rhs) { \
154+
float f = static_cast<float>(lhs); \
155+
f op static_cast<float>(rhs); \
156+
return lhs = f; \
157+
} \
158+
template <typename T> \
159+
friend bfloat16 &operator op(bfloat16 & lhs, const T & rhs) { \
160+
float f = static_cast<float>(lhs); \
161+
f op static_cast<float>(rhs); \
162+
return lhs = f; \
163+
} \
164+
template <typename T> friend T &operator op(T & lhs, const bfloat16 & rhs) { \
165+
float f = static_cast<float>(lhs); \
166+
f op static_cast<float>(rhs); \
167+
return lhs = f; \
168+
}
169+
OP(+=)
170+
OP(-=)
171+
OP(*=)
172+
OP(/=)
173+
#undef OP
303174

175+
// Binary operators overloading
176+
#define OP(type, op) \
177+
friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \
178+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
179+
} \
180+
template <typename T> \
181+
friend type operator op(const bfloat16 &lhs, const T &rhs) { \
182+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
183+
} \
184+
template <typename T> \
185+
friend type operator op(const T &lhs, const bfloat16 &rhs) { \
186+
return type{static_cast<float>(lhs) op static_cast<float>(rhs)}; \
187+
}
188+
OP(bfloat16, +)
189+
OP(bfloat16, -)
190+
OP(bfloat16, *)
191+
OP(bfloat16, /)
192+
OP(bool, ==)
193+
OP(bool, !=)
194+
OP(bool, <)
195+
OP(bool, >)
196+
OP(bool, <=)
197+
OP(bool, >=)
304198
#undef OP
305199

306200
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported

sycl/test-e2e/ESIMD/regression/bfloat16_half_vector_plus_eq_scalar.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,12 @@ int main() {
9191
}
9292

9393
#ifdef USE_BF16
94-
Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
94+
// TODO: Reenable once the issue with bfloat16 is resolved
95+
// Passed &= test<sycl::ext::oneapi::bfloat16>(Q);
9596
#endif
9697
#ifdef USE_TF32
9798
Passed &= test<sycl::ext::intel::experimental::esimd::tfloat32>(Q);
9899
#endif
99100
std::cout << (Passed ? "Passed\n" : "FAILED\n");
100101
return Passed ? 0 : 1;
101-
}
102+
}

sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar.cpp

Lines changed: 0 additions & 100 deletions
This file was deleted.

sycl/test-e2e/ESIMD/regression/bfloat16_vector_plus_scalar_pvc.cpp

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)