@@ -60,6 +60,7 @@ static vector<vector<int>> GetInputs_() {
60
60
61
61
vector<double > benchmarkTimes;
62
62
63
+ template <typename OutType>
63
64
int run_benchmark (
64
65
int batch_size,
65
66
int num_rows,
@@ -68,7 +69,8 @@ int run_benchmark(
68
69
bool normalize_by_lengths,
69
70
bool use_32_bit_indices = false ,
70
71
bool prefetch = false ,
71
- bool stress_multi_threading = false ) {
72
+ bool stress_multi_threading = false ,
73
+ bool is_bf16_out = false ) {
72
74
// Create embedding table
73
75
default_random_engine generator;
74
76
normal_distribution<float > embedding_distribution;
@@ -127,8 +129,8 @@ int run_benchmark(
127
129
weights[i] = embedding_distribution (generator);
128
130
}
129
131
130
- vector<float > output_sls_ref (batch_size * embedding_dim);
131
- vector<float > output_slws_ref (output_sls_ref.size ()),
132
+ vector<OutType > output_sls_ref (batch_size * embedding_dim);
133
+ vector<OutType > output_slws_ref (output_sls_ref.size ()),
132
134
output_sls (output_sls_ref.size ()), output_slws (output_sls_ref.size ());
133
135
134
136
constexpr int NUM_WARMUP = 10 ;
@@ -149,7 +151,7 @@ int run_benchmark(
149
151
has_weight_options.push_back (true );
150
152
}
151
153
for (bool has_weight : has_weight_options) {
152
- vector<float >& output_ref = has_weight ? output_slws_ref : output_sls_ref;
154
+ vector<OutType >& output_ref = has_weight ? output_slws_ref : output_sls_ref;
153
155
154
156
bool success = false , success_ref = false ;
155
157
@@ -179,17 +181,19 @@ int run_benchmark(
179
181
output_ref.data ());
180
182
}
181
183
182
- vector<float >& output = has_weight ? output_slws : output_sls;
184
+ vector<OutType >& output = has_weight ? output_slws : output_sls;
183
185
vector<bool > flush_cache_options;
184
186
flush_cache_options.push_back (false );
185
187
if (!stress_multi_threading) {
186
188
flush_cache_options.push_back (true );
187
189
}
188
190
189
- auto kernel_32 = GenerateEmbeddingSpMDM<uint8_t , int32_t >(
190
- embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
191
- auto kernel_64 = GenerateEmbeddingSpMDM<uint8_t , int64_t >(
192
- embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
191
+ auto kernel_32 =
192
+ GenerateEmbeddingSpMDM<uint8_t , int32_t , std::int32_t , OutType>(
193
+ embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
194
+ auto kernel_64 =
195
+ GenerateEmbeddingSpMDM<uint8_t , int64_t , std::int32_t , OutType>(
196
+ embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0 );
193
197
194
198
#ifdef _OPENMP
195
199
#pragma omp barrier
@@ -255,9 +259,26 @@ int run_benchmark(
255
259
false && " ERROR: refernce impl and JIT imp did not both succeed" );
256
260
} else if (success) {
257
261
for (size_t i = 0 ; i < output.size (); ++i) {
258
- assert (fabs (output[i] - output_ref[i]) < 1e-3 );
259
- if (fabs (output[i] - output_ref[i]) >= 1e-3 ) {
260
- cout << i << " " << output[i] << " " << output_ref[i] << endl;
262
+ float tmp1 = 0 ;
263
+ float tmp2 = 0 ;
264
+ if (std::is_same<OutType, float >::value) {
265
+ tmp1 = output[i];
266
+ tmp2 = output_ref[i];
267
+ } else if (std::is_same<OutType, uint16_t >::value) {
268
+ if (is_bf16_out) {
269
+ tmp1 = cpu_bf162float (output[i]);
270
+ tmp2 = cpu_bf162float (output_ref[i]);
271
+ } else {
272
+ tmp1 = cpu_half2float (output[i]);
273
+ tmp2 = cpu_half2float (output_ref[i]);
274
+ }
275
+ } else {
276
+ assert (false && " ERROR: unsupported output type" );
277
+ cout << " ERROR: unsupported output type" << endl;
278
+ }
279
+ assert (fabs (tmp1 - tmp2) < 1e-3 );
280
+ if (fabs (tmp1 - tmp2) >= 1e-3 ) {
281
+ cout << i << " " << tmp1 << " " << tmp2 << endl;
261
282
}
262
283
}
263
284
}
@@ -267,6 +288,19 @@ int run_benchmark(
267
288
#pragma omp barrier
268
289
#endif
269
290
if (fbgemm_get_thread_num () == 0 ) {
291
+ if (std::is_same<OutType, float >::value) {
292
+ cout << " out type fp32" ;
293
+ } else if (std::is_same<OutType, uint16_t >::value) {
294
+ if (is_bf16_out) {
295
+ cout << " out type bf16" ;
296
+ } else {
297
+ cout << " out type fp16" ;
298
+ }
299
+ } else {
300
+ assert (false && " ERROR: unsupported output type" );
301
+ cout << " ERROR: unsupported output type" << endl;
302
+ }
303
+
270
304
if (has_weight) {
271
305
cout << setw (16 ) << " SLW(WEIGHTED) " ;
272
306
} else {
@@ -332,7 +366,8 @@ int main() {
332
366
#ifdef _OPENMP
333
367
#pragma omp parallel if (stress_multi_threading)
334
368
#endif
335
- run_benchmark (
369
+ #if defined(OUT_TYPE_FLOAT16)
370
+ run_benchmark<float16>(
336
371
batch_size,
337
372
num_rows,
338
373
embedding_dim,
@@ -341,22 +376,46 @@ int main() {
341
376
false ,
342
377
false ,
343
378
stress_multi_threading);
344
-
379
+ #else
380
+ run_benchmark<float >(
381
+ batch_size,
382
+ num_rows,
383
+ embedding_dim,
384
+ average_len,
385
+ false ,
386
+ false ,
387
+ false ,
388
+ stress_multi_threading);
389
+ #endif
345
390
if (stress_multi_threading) {
346
391
return 0 ;
347
392
}
348
393
349
394
cout << " 64 bit indices with prefetching, " ;
350
- run_benchmark (
395
+ #if defined(OUT_TYPE_FLOAT16)
396
+ run_benchmark<float16>(
351
397
batch_size, num_rows, embedding_dim, average_len, false , false , true );
352
-
398
+ #else
399
+ run_benchmark<float >(
400
+ batch_size, num_rows, embedding_dim, average_len, false , false , true );
401
+ #endif
353
402
cout << " 32 bit indices, " ;
354
- run_benchmark (
403
+ #if defined(OUT_TYPE_FLOAT16)
404
+ run_benchmark<float16>(
405
+ batch_size, num_rows, embedding_dim, average_len, false , true );
406
+ #else
407
+ run_benchmark<float >(
355
408
batch_size, num_rows, embedding_dim, average_len, false , true );
409
+ #endif
356
410
357
411
cout << " 32 bit indices with prefetching, " ;
358
- run_benchmark (
412
+ #if defined(OUT_TYPE_FLOAT16)
413
+ run_benchmark<float16>(
359
414
batch_size, num_rows, embedding_dim, average_len, false , true , true );
415
+ #else
416
+ run_benchmark<float >(
417
+ batch_size, num_rows, embedding_dim, average_len, false , true , true );
418
+ #endif
360
419
361
420
// running with normalize by lengths
362
421
// run_benchmark(batch_size, num_rows, embedding_dim, average_len,
0 commit comments