From 537aeb39cf7370bb51c199ea0e6c1e1e5263684a Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Sat, 17 Aug 2024 09:43:46 -0700
Subject: [PATCH] Minor CK FP8 Tuning Improvements (#2987)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/80

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2987

This diff makes a few small changes to improve CK FP8 performance based on recent improvements to ROCm and CK that have landed. We specifically use the large kernel added in D60996231 more liberally, as it performs quite well, and re-enable some files in the CK Profiler that can now compile. I also add some performance flags (D61266671) that are currently only enabled for CK.

The latest Llama benchmarks after this change are available [here](https://docs.google.com/spreadsheets/d/1GD44u4Sud_6T9iq_SJvYmSn8tx0bRd9niZPy2gVHK-s/edit?gid=482861329#gid=482861329).

We also include a fix for certain small K values that don't work with newer versions of the CK pipeline.

Reviewed By: jianyuh, zjing14

Differential Revision: D61285882

fbshipit-source-id: 55113e76026b04cda63a4324bb2c5eb5c242ecb7
---
 .../ck_extensions/fp8_rowwise_gemm.hip        |  7 +-
 ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 72 +++++++++++++++++++
 .../kernels/fp8_rowwise_kernel_manifest.h     |  9 +++
 3 files changed, 86 insertions(+), 2 deletions(-)
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip

diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
index 4cd23404fb..f4c22e1e03 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip
@@ -157,14 +157,17 @@ RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
   } else if (M < 64) {
     // Fallback to generic small batch kernel if we cant find a good match.
     return fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2;
-  } else if ((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) {
+  } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) && K >= 1024) {
     // Kernel that is optimized for larger batch sizes but otherwise small
     // tensors.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5;
+  } else if (K < 1024) {
+    // Special case for small K.
+    return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
   } else if (M < 1024) {
     // Kernel for generic medium batch sizes.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-  } else if (M > 4096 && N > 4096 && K > 2048) {
+  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
     // Kernel for very large gemm
     return fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
   } else {
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip
new file mode 100644
index 0000000000..8b3d5abdda
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  // A kernel that seems to work well on mid-sized tensors.
+
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (K % 128 != 0);
+
+  // Dispatch based on whether padding is needed or not.
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        32,
+        128,
+        128,
+        32,
+        32,
+        1,
+        2,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::KPadding>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        32,
+        128,
+        128,
+        32,
+        32,
+        1,
+        2,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<8, 8, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
+        XQ, WQ, x_scale, w_scale, Y);
+  }
+}
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h
index 9c01d2b3da..2b07530175 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_kernel_manifest.h
@@ -206,6 +206,15 @@ fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor w_scale,
     at::Tensor Y);
 
+// Another variant for larger batch size support.
+at::Tensor
+fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
 // Kernel that seems optimal for highly compute bound problems.
 at::Tensor
 fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
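For reviewers who want to see the tuning change in isolation, the sketch below restates the updated `rowwise_heuristic_dispatch` thresholds from the first hunk as a small standalone C++ program that returns descriptive strings instead of kernel function pointers. The helper name `sketch_dispatch`, the string labels, and the example shapes are illustrative assumptions rather than FBGEMM code, and the branches that sit above the `M < 64` case in the real file are omitted.

```cpp
// Minimal standalone sketch (assumed names, not FBGEMM code): the branch
// ordering and thresholds from the patched rowwise_heuristic_dispatch,
// reduced to strings so the tuning changes can be exercised in isolation.
#include <cstdio>
#include <string>

std::string sketch_dispatch(int M, int N, int K) {
  if (M < 64) {
    return "64x16x16x128 interwave_v2 (small batch fallback)";
  } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) ||
              (K <= 2048 && N <= 8192)) &&
             K >= 1024) {
    // The "&& K >= 1024" guard is new: it keeps small-K shapes away from the
    // intrawave_v5 pipeline that breaks on newer CK versions.
    return "256x128x128x128 intrawave_v5 (large batch, small tensors)";
  } else if (K < 1024) {
    // New branch: small-K shapes are routed to an interwave_v1 instance.
    return "256x128x128x128 interwave_v1 (small K special case)";
  } else if (M < 1024) {
    return "256x128x128x128 intrawave_v3 (medium batch)";
  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
    // Relaxed from "M > 4096 && N > 4096 && K > 2048", so the large kernel
    // from D60996231 now covers any GEMM with all dimensions >= 1024.
    return "256x256x256x128 intrawave_v3 (very large gemm)";
  }
  return "default kernel";
}

int main() {
  // Previously fell through to the default branch (4096 > 4096 is false);
  // now selects the very large gemm kernel.
  std::printf("%s\n", sketch_dispatch(4096, 4096, 4096).c_str());
  // K = 512 now hits the new small-K branch instead of intrawave_v5.
  std::printf("%s\n", sketch_dispatch(256, 4096, 512).c_str());
  return 0;
}
```

The two printed shapes show the only behavioral differences introduced by this diff's threshold changes; every other branch in the chain selects the same kernel as before.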