diff --git a/bench/EmbeddingSpMDM8BitBenchmark.cc b/bench/EmbeddingSpMDM8BitBenchmark.cc index 573062f6d9..2f69538562 100644 --- a/bench/EmbeddingSpMDM8BitBenchmark.cc +++ b/bench/EmbeddingSpMDM8BitBenchmark.cc @@ -60,6 +60,7 @@ static vector> GetInputs_() { vector benchmarkTimes; +template int run_benchmark( int batch_size, int num_rows, @@ -68,7 +69,8 @@ int run_benchmark( bool normalize_by_lengths, bool use_32_bit_indices = false, bool prefetch = false, - bool stress_multi_threading = false) { + bool stress_multi_threading = false, + bool is_bf16_out = false) { // Create embedding table default_random_engine generator; normal_distribution embedding_distribution; @@ -127,8 +129,8 @@ int run_benchmark( weights[i] = embedding_distribution(generator); } - vector output_sls_ref(batch_size * embedding_dim); - vector output_slws_ref(output_sls_ref.size()), + vector output_sls_ref(batch_size * embedding_dim); + vector output_slws_ref(output_sls_ref.size()), output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size()); constexpr int NUM_WARMUP = 10; @@ -149,7 +151,7 @@ int run_benchmark( has_weight_options.push_back(true); } for (bool has_weight : has_weight_options) { - vector& output_ref = has_weight ? output_slws_ref : output_sls_ref; + vector& output_ref = has_weight ? output_slws_ref : output_sls_ref; bool success = false, success_ref = false; @@ -179,17 +181,19 @@ int run_benchmark( output_ref.data()); } - vector& output = has_weight ? output_slws : output_sls; + vector& output = has_weight ? output_slws : output_sls; vector flush_cache_options; flush_cache_options.push_back(false); if (!stress_multi_threading) { flush_cache_options.push_back(true); } - auto kernel_32 = GenerateEmbeddingSpMDM( - embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); - auto kernel_64 = GenerateEmbeddingSpMDM( - embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); + auto kernel_32 = + GenerateEmbeddingSpMDM( + embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); + auto kernel_64 = + GenerateEmbeddingSpMDM( + embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0); #ifdef _OPENMP #pragma omp barrier @@ -255,9 +259,26 @@ int run_benchmark( false && "ERROR: refernce impl and JIT imp did not both succeed"); } else if (success) { for (size_t i = 0; i < output.size(); ++i) { - assert(fabs(output[i] - output_ref[i]) < 1e-3); - if (fabs(output[i] - output_ref[i]) >= 1e-3) { - cout << i << " " << output[i] << " " << output_ref[i] << endl; + float tmp1 = 0; + float tmp2 = 0; + if (std::is_same::value) { + tmp1 = output[i]; + tmp2 = output_ref[i]; + } else if (std::is_same::value) { + if (is_bf16_out) { + tmp1 = cpu_bf162float(output[i]); + tmp2 = cpu_bf162float(output_ref[i]); + } else { + tmp1 = cpu_half2float(output[i]); + tmp2 = cpu_half2float(output_ref[i]); + } + } else { + assert(false && "ERROR: unsupported output type"); + cout << "ERROR: unsupported output type" << endl; + } + assert(fabs(tmp1 - tmp2) < 1e-3); + if (fabs(tmp1 - tmp2) >= 1e-3) { + cout << i << " " << tmp1 << " " << tmp2 << endl; } } } @@ -267,6 +288,19 @@ int run_benchmark( #pragma omp barrier #endif if (fbgemm_get_thread_num() == 0) { + if (std::is_same::value) { + cout << "out type fp32"; + } else if (std::is_same::value) { + if (is_bf16_out) { + cout << "out type bf16"; + } else { + cout << "out type fp16"; + } + } else { + assert(false && "ERROR: unsupported output type"); + cout << "ERROR: unsupported output type" << endl; + } + if (has_weight) { cout << setw(16) << "SLW(WEIGHTED) "; } else { @@ -332,7 +366,8 @@ int main() { #ifdef _OPENMP #pragma omp parallel if (stress_multi_threading) #endif - run_benchmark( +#if defined(OUT_TYPE_FLOAT16) + run_benchmark( batch_size, num_rows, embedding_dim, @@ -341,22 +376,46 @@ int main() { false, false, stress_multi_threading); - +#else + run_benchmark( + batch_size, + num_rows, + embedding_dim, + average_len, + false, + false, + false, + stress_multi_threading); +#endif if (stress_multi_threading) { return 0; } cout << "64 bit indices with prefetching, "; - run_benchmark( +#if defined(OUT_TYPE_FLOAT16) + run_benchmark( batch_size, num_rows, embedding_dim, average_len, false, false, true); - +#else + run_benchmark( + batch_size, num_rows, embedding_dim, average_len, false, false, true); +#endif cout << "32 bit indices, "; - run_benchmark( +#if defined(OUT_TYPE_FLOAT16) + run_benchmark( + batch_size, num_rows, embedding_dim, average_len, false, true); +#else + run_benchmark( batch_size, num_rows, embedding_dim, average_len, false, true); +#endif cout << "32 bit indices with prefetching, "; - run_benchmark( +#if defined(OUT_TYPE_FLOAT16) + run_benchmark( batch_size, num_rows, embedding_dim, average_len, false, true, true); +#else + run_benchmark( + batch_size, num_rows, embedding_dim, average_len, false, true, true); +#endif // running with normalize by lengths // run_benchmark(batch_size, num_rows, embedding_dim, average_len,