Skip to content

Commit 04e8aa7

Browse files
helloguo authored and facebook-github-bot committed
add EmbeddingSpMDM8BitBenchmarkOutTypeFloat16 (#2952)
Summary: X-link: facebookresearch/FBGEMM#70 Pull Request resolved: #2952 Add EmbeddingSpMDM8BitBenchmarkOutTypeFloat16 on ARM. Reviewed By: sryap Differential Revision: D60972344 fbshipit-source-id: 5b22831cd291a325db4ff143eaf36cbc69b74fc9
1 parent 01775eb commit 04e8aa7

File tree

1 file changed

+77
-18
lines changed

1 file changed

+77
-18
lines changed

bench/EmbeddingSpMDM8BitBenchmark.cc

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ static vector<vector<int>> GetInputs_() {
6060

6161
vector<double> benchmarkTimes;
6262

63+
template <typename OutType>
6364
int run_benchmark(
6465
int batch_size,
6566
int num_rows,
@@ -68,7 +69,8 @@ int run_benchmark(
6869
bool normalize_by_lengths,
6970
bool use_32_bit_indices = false,
7071
bool prefetch = false,
71-
bool stress_multi_threading = false) {
72+
bool stress_multi_threading = false,
73+
bool is_bf16_out = false) {
7274
// Create embedding table
7375
default_random_engine generator;
7476
normal_distribution<float> embedding_distribution;
@@ -127,8 +129,8 @@ int run_benchmark(
127129
weights[i] = embedding_distribution(generator);
128130
}
129131

130-
vector<float> output_sls_ref(batch_size * embedding_dim);
131-
vector<float> output_slws_ref(output_sls_ref.size()),
132+
vector<OutType> output_sls_ref(batch_size * embedding_dim);
133+
vector<OutType> output_slws_ref(output_sls_ref.size()),
132134
output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());
133135

134136
constexpr int NUM_WARMUP = 10;
@@ -149,7 +151,7 @@ int run_benchmark(
149151
has_weight_options.push_back(true);
150152
}
151153
for (bool has_weight : has_weight_options) {
152-
vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
154+
vector<OutType>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
153155

154156
bool success = false, success_ref = false;
155157

@@ -179,17 +181,19 @@ int run_benchmark(
179181
output_ref.data());
180182
}
181183

182-
vector<float>& output = has_weight ? output_slws : output_sls;
184+
vector<OutType>& output = has_weight ? output_slws : output_sls;
183185
vector<bool> flush_cache_options;
184186
flush_cache_options.push_back(false);
185187
if (!stress_multi_threading) {
186188
flush_cache_options.push_back(true);
187189
}
188190

189-
auto kernel_32 = GenerateEmbeddingSpMDM<uint8_t, int32_t>(
190-
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
191-
auto kernel_64 = GenerateEmbeddingSpMDM<uint8_t, int64_t>(
192-
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
191+
auto kernel_32 =
192+
GenerateEmbeddingSpMDM<uint8_t, int32_t, std::int32_t, OutType>(
193+
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
194+
auto kernel_64 =
195+
GenerateEmbeddingSpMDM<uint8_t, int64_t, std::int32_t, OutType>(
196+
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
193197

194198
#ifdef _OPENMP
195199
#pragma omp barrier
@@ -255,9 +259,26 @@ int run_benchmark(
255259
false && "ERROR: refernce impl and JIT imp did not both succeed");
256260
} else if (success) {
257261
for (size_t i = 0; i < output.size(); ++i) {
258-
assert(fabs(output[i] - output_ref[i]) < 1e-3);
259-
if (fabs(output[i] - output_ref[i]) >= 1e-3) {
260-
cout << i << " " << output[i] << " " << output_ref[i] << endl;
262+
float tmp1 = 0;
263+
float tmp2 = 0;
264+
if (std::is_same<OutType, float>::value) {
265+
tmp1 = output[i];
266+
tmp2 = output_ref[i];
267+
} else if (std::is_same<OutType, uint16_t>::value) {
268+
if (is_bf16_out) {
269+
tmp1 = cpu_bf162float(output[i]);
270+
tmp2 = cpu_bf162float(output_ref[i]);
271+
} else {
272+
tmp1 = cpu_half2float(output[i]);
273+
tmp2 = cpu_half2float(output_ref[i]);
274+
}
275+
} else {
276+
assert(false && "ERROR: unsupported output type");
277+
cout << "ERROR: unsupported output type" << endl;
278+
}
279+
assert(fabs(tmp1 - tmp2) < 1e-3);
280+
if (fabs(tmp1 - tmp2) >= 1e-3) {
281+
cout << i << " " << tmp1 << " " << tmp2 << endl;
261282
}
262283
}
263284
}
@@ -267,6 +288,19 @@ int run_benchmark(
267288
#pragma omp barrier
268289
#endif
269290
if (fbgemm_get_thread_num() == 0) {
291+
if (std::is_same<OutType, float>::value) {
292+
cout << "out type fp32";
293+
} else if (std::is_same<OutType, uint16_t>::value) {
294+
if (is_bf16_out) {
295+
cout << "out type bf16";
296+
} else {
297+
cout << "out type fp16";
298+
}
299+
} else {
300+
assert(false && "ERROR: unsupported output type");
301+
cout << "ERROR: unsupported output type" << endl;
302+
}
303+
270304
if (has_weight) {
271305
cout << setw(16) << "SLW(WEIGHTED) ";
272306
} else {
@@ -332,7 +366,8 @@ int main() {
332366
#ifdef _OPENMP
333367
#pragma omp parallel if (stress_multi_threading)
334368
#endif
335-
run_benchmark(
369+
#if defined(OUT_TYPE_FLOAT16)
370+
run_benchmark<float16>(
336371
batch_size,
337372
num_rows,
338373
embedding_dim,
@@ -341,22 +376,46 @@ int main() {
341376
false,
342377
false,
343378
stress_multi_threading);
344-
379+
#else
380+
run_benchmark<float>(
381+
batch_size,
382+
num_rows,
383+
embedding_dim,
384+
average_len,
385+
false,
386+
false,
387+
false,
388+
stress_multi_threading);
389+
#endif
345390
if (stress_multi_threading) {
346391
return 0;
347392
}
348393

349394
cout << "64 bit indices with prefetching, ";
350-
run_benchmark(
395+
#if defined(OUT_TYPE_FLOAT16)
396+
run_benchmark<float16>(
351397
batch_size, num_rows, embedding_dim, average_len, false, false, true);
352-
398+
#else
399+
run_benchmark<float>(
400+
batch_size, num_rows, embedding_dim, average_len, false, false, true);
401+
#endif
353402
cout << "32 bit indices, ";
354-
run_benchmark(
403+
#if defined(OUT_TYPE_FLOAT16)
404+
run_benchmark<float16>(
405+
batch_size, num_rows, embedding_dim, average_len, false, true);
406+
#else
407+
run_benchmark<float>(
355408
batch_size, num_rows, embedding_dim, average_len, false, true);
409+
#endif
356410

357411
cout << "32 bit indices with prefetching, ";
358-
run_benchmark(
412+
#if defined(OUT_TYPE_FLOAT16)
413+
run_benchmark<float16>(
359414
batch_size, num_rows, embedding_dim, average_len, false, true, true);
415+
#else
416+
run_benchmark<float>(
417+
batch_size, num_rows, embedding_dim, average_len, false, true, true);
418+
#endif
360419

361420
// running with normalize by lengths
362421
// run_benchmark(batch_size, num_rows, embedding_dim, average_len,

0 commit comments

Comments (0)