@@ -132,175 +132,69 @@ class bfloat16 {
132
132
#endif
133
133
}
134
134
135
- bfloat16 &operator +=(const bfloat16 &rhs) {
136
- value = from_float (to_float (value) + to_float (rhs.value ));
137
- return *this ;
138
- }
139
-
140
- bfloat16 &operator -=(const bfloat16 &rhs) {
141
- value = from_float (to_float (value) - to_float (rhs.value ));
142
- return *this ;
143
- }
144
-
145
- bfloat16 &operator *=(const bfloat16 &rhs) {
146
- value = from_float (to_float (value) * to_float (rhs.value ));
147
- return *this ;
148
- }
149
-
150
- bfloat16 &operator /=(const bfloat16 &rhs) {
151
- value = from_float (to_float (value) / to_float (rhs.value ));
152
- return *this ;
153
- }
154
-
155
- // Operator ++, --
156
- bfloat16 &operator ++() {
157
- float f = to_float (value);
158
- value = from_float (++f);
159
- return *this ;
160
- }
161
-
162
- bfloat16 operator ++(int ) {
163
- bfloat16 ret (*this );
164
- operator ++();
165
- return ret;
166
- }
167
-
168
- bfloat16 &operator --() {
169
- float f = to_float (value);
170
- value = from_float (--f);
171
- return *this ;
172
- }
173
-
174
- bfloat16 operator --(int ) {
175
- bfloat16 ret (*this );
176
- operator --();
177
- return ret;
178
- }
179
-
180
- // Operator +, -, *, /
135
+ // Increment and decrement operators overloading
181
136
#define OP (op ) \
182
- friend bfloat16 operator op (const bfloat16 lhs, const bfloat16 rhs) { \
183
- return to_float (lhs.value ) op to_float (rhs.value ); \
184
- } \
185
- friend double operator op (const bfloat16 lhs, const double rhs) { \
186
- return to_float (lhs.value ) op rhs; \
187
- } \
188
- friend double operator op (const double lhs, const bfloat16 rhs) { \
189
- return lhs op to_float (rhs.value ); \
190
- } \
191
- friend float operator op (const bfloat16 lhs, const float rhs) { \
192
- return to_float (lhs.value ) op rhs; \
193
- } \
194
- friend float operator op (const float lhs, const bfloat16 rhs) { \
195
- return lhs op to_float (rhs.value ); \
196
- } \
197
- friend bfloat16 operator op (const bfloat16 lhs, const int rhs) { \
198
- return to_float (lhs.value ) op rhs; \
199
- } \
200
- friend bfloat16 operator op (const int lhs, const bfloat16 rhs) { \
201
- return lhs op to_float (rhs.value ); \
202
- } \
203
- friend bfloat16 operator op (const bfloat16 lhs, const long rhs) { \
204
- return to_float (lhs.value ) op rhs; \
205
- } \
206
- friend bfloat16 operator op (const long lhs, const bfloat16 rhs) { \
207
- return lhs op to_float (rhs.value ); \
208
- } \
209
- friend bfloat16 operator op (const bfloat16 lhs, const long long rhs) { \
210
- return to_float (lhs.value ) op rhs; \
211
- } \
212
- friend bfloat16 operator op (const long long lhs, const bfloat16 rhs) { \
213
- return lhs op to_float (rhs.value ); \
214
- } \
215
- friend bfloat16 operator op (const bfloat16 &lhs, const unsigned int &rhs) { \
216
- return to_float (lhs.value ) op rhs; \
217
- } \
218
- friend bfloat16 operator op (const unsigned int &lhs, const bfloat16 &rhs) { \
219
- return lhs op to_float (rhs.value ); \
220
- } \
221
- friend bfloat16 operator op (const bfloat16 &lhs, const unsigned long &rhs) { \
222
- return to_float (lhs.value ) op rhs; \
223
- } \
224
- friend bfloat16 operator op (const unsigned long &lhs, const bfloat16 &rhs) { \
225
- return lhs op to_float (rhs.value ); \
226
- } \
227
- friend bfloat16 operator op (const bfloat16 &lhs, \
228
- const unsigned long long &rhs) { \
229
- return to_float (lhs.value ) op rhs; \
230
- } \
231
- friend bfloat16 operator op (const unsigned long long &lhs, \
232
- const bfloat16 &rhs) { \
233
- return lhs op to_float (rhs.value ); \
234
- }
235
- OP (+)
236
- OP (-)
237
- OP (*)
238
- OP (/)
239
-
137
+ friend bfloat16 &operator op (bfloat16 & lhs) { \
138
+ float f = to_float (lhs.value ); \
139
+ lhs.value = from_float (op f); \
140
+ return lhs; \
141
+ } \
142
+ friend bfloat16 operator op (bfloat16 &lhs, int ) { \
143
+ bfloat16 old = lhs; \
144
+ operator op (lhs); \
145
+ return old; \
146
+ }
147
+ OP (++)
148
+ OP (--)
240
149
#undef OP
241
150
242
- // Operator ==, !=, <, >, <=, >=
151
+ // Assignment operators overloading
243
152
#define OP (op ) \
244
- friend bool operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
245
- return to_float (lhs.value ) op to_float (rhs.value ); \
246
- } \
247
- friend bool operator op (const bfloat16 &lhs, const double &rhs) { \
248
- return to_float (lhs.value ) op rhs; \
249
- } \
250
- friend bool operator op (const double &lhs, const bfloat16 &rhs) { \
251
- return lhs op to_float (rhs.value ); \
252
- } \
253
- friend bool operator op (const bfloat16 &lhs, const float &rhs) { \
254
- return to_float (lhs.value ) op rhs; \
255
- } \
256
- friend bool operator op (const float &lhs, const bfloat16 &rhs) { \
257
- return lhs op to_float (rhs.value ); \
258
- } \
259
- friend bool operator op (const bfloat16 &lhs, const int &rhs) { \
260
- return to_float (lhs.value ) op rhs; \
261
- } \
262
- friend bool operator op (const int &lhs, const bfloat16 &rhs) { \
263
- return lhs op to_float (rhs.value ); \
264
- } \
265
- friend bool operator op (const bfloat16 &lhs, const long &rhs) { \
266
- return to_float (lhs.value ) op rhs; \
267
- } \
268
- friend bool operator op (const long &lhs, const bfloat16 &rhs) { \
269
- return lhs op to_float (rhs.value ); \
270
- } \
271
- friend bool operator op (const bfloat16 &lhs, const long long &rhs) { \
272
- return to_float (lhs.value ) op rhs; \
273
- } \
274
- friend bool operator op (const long long &lhs, const bfloat16 &rhs) { \
275
- return lhs op to_float (rhs.value ); \
276
- } \
277
- friend bool operator op (const bfloat16 &lhs, const unsigned int &rhs) { \
278
- return to_float (lhs.value ) op rhs; \
279
- } \
280
- friend bool operator op (const unsigned int &lhs, const bfloat16 &rhs) { \
281
- return lhs op to_float (rhs.value ); \
282
- } \
283
- friend bool operator op (const bfloat16 &lhs, const unsigned long &rhs) { \
284
- return to_float (lhs.value ) op rhs; \
285
- } \
286
- friend bool operator op (const unsigned long &lhs, const bfloat16 &rhs) { \
287
- return lhs op to_float (rhs.value ); \
288
- } \
289
- friend bool operator op (const bfloat16 &lhs, \
290
- const unsigned long long &rhs) { \
291
- return to_float (lhs.value ) op rhs; \
292
- } \
293
- friend bool operator op (const unsigned long long &lhs, \
294
- const bfloat16 &rhs) { \
295
- return lhs op to_float (rhs.value ); \
296
- }
297
- OP (==)
298
- OP (!=)
299
- OP (<)
300
- OP (>)
301
- OP (<=)
302
- OP (>=)
153
+ friend bfloat16 &operator op (bfloat16 & lhs, const bfloat16 & rhs) { \
154
+ float f = static_cast <float >(lhs); \
155
+ f op static_cast <float >(rhs); \
156
+ return lhs = f; \
157
+ } \
158
+ template <typename T> \
159
+ friend bfloat16 &operator op (bfloat16 & lhs, const T & rhs) { \
160
+ float f = static_cast <float >(lhs); \
161
+ f op static_cast <float >(rhs); \
162
+ return lhs = f; \
163
+ } \
164
+ template <typename T> friend T &operator op (T & lhs, const bfloat16 & rhs) { \
165
+ float f = static_cast <float >(lhs); \
166
+ f op static_cast <float >(rhs); \
167
+ return lhs = f; \
168
+ }
169
+ OP (+=)
170
+ OP (-=)
171
+ OP (*=)
172
+ OP (/=)
173
+ #undef OP
303
174
175
+ // Binary operators overloading
176
+ #define OP (type, op ) \
177
+ friend type operator op (const bfloat16 &lhs, const bfloat16 &rhs) { \
178
+ return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
179
+ } \
180
+ template <typename T> \
181
+ friend type operator op (const bfloat16 &lhs, const T &rhs) { \
182
+ return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
183
+ } \
184
+ template <typename T> \
185
+ friend type operator op (const T &lhs, const bfloat16 &rhs) { \
186
+ return type{static_cast <float >(lhs) op static_cast <float >(rhs)}; \
187
+ }
188
+ OP (bfloat16, +)
189
+ OP (bfloat16, -)
190
+ OP (bfloat16, *)
191
+ OP (bfloat16, /)
192
+ OP (bool , ==)
193
+ OP (bool , !=)
194
+ OP (bool , <)
195
+ OP (bool , >)
196
+ OP (bool , <=)
197
+ OP (bool , >=)
304
198
#undef OP
305
199
306
200
// Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported
0 commit comments