Minor CK FP8 Tuning Improvements (#2987)
Summary:
X-link: facebookresearch/FBGEMM#80

Pull Request resolved: #2987

This diff makes a few small changes to improve CK FP8 performance, building on recent improvements that have landed in ROCm and CK.

Specifically, we use the large kernel added in D60996231 more liberally, since it performs quite well, and re-enable some files in the CK Profiler that can now compile.

I also add some performance flags that, as of this diff, are enabled only for CK: D61266671

The latest llama benchmarks after this change are available [here](https://docs.google.com/spreadsheets/d/1GD44u4Sud_6T9iq_SJvYmSn8tx0bRd9niZPy2gVHK-s/edit?gid=482861329#gid=482861329).

We also include a fix for certain small K values that don't work with newer versions of the CK pipeline.

Reviewed By: jianyuh, zjing14

Differential Revision: D61285882

fbshipit-source-id: 55113e76026b04cda63a4324bb2c5eb5c242ecb7
jwfromm authored and facebook-github-bot committed Aug 17, 2024
1 parent 82e00b1 commit 537aeb3
Showing 3 changed files with 86 additions and 2 deletions.
@@ -157,14 +157,17 @@ RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
   } else if (M < 64) {
     // Fallback to generic small batch kernel if we can't find a good match.
     return fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2;
-  } else if ((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) {
+  } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) && K >= 1024) {
     // Kernel that is optimized for larger batch sizes but otherwise small
     // tensors.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5;
+  } else if (K < 1024) {
+    // Special case for small K.
+    return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
   } else if (M < 1024) {
     // Kernel for generic medium batch sizes.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-  } else if (M > 4096 && N > 4096 && K > 2048) {
+  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
     // Kernel for very large gemm
     return fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
   } else {
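The shape heuristic in the diff above can be sketched as a standalone function. This is an illustrative sketch only, not the FBGEMM implementation: it covers just the branches visible in the hunk, uses abbreviated string labels in place of the real kernel function pointers, and omits the earlier branches of rowwise_heuristic_dispatch that the hunk does not show.

```cpp
#include <cassert>
#include <string>

// Hedged sketch of the updated heuristic: returns a short label describing
// which kernel class would be dispatched for a given (M, N, K) GEMM shape.
std::string dispatch_label(int M, int N, int K) {
  if (M < 64) {
    // Generic small-batch fallback.
    return "small_batch";
  } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) ||
              (K <= 2048 && N <= 8192)) &&
             K >= 1024) {
    // Larger batch but otherwise small tensors; the new K >= 1024 guard
    // keeps small-K shapes out of this branch.
    return "large_batch_small_tensor";
  } else if (K < 1024) {
    // New special case for small K.
    return "small_k";
  } else if (M < 1024) {
    // Generic medium batch sizes.
    return "medium_batch";
  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
    // Very large GEMM; threshold loosened from M > 4096 && N > 4096.
    return "very_large_gemm";
  }
  return "default";
}
```

Note how the added `K >= 1024` guard and the new `K < 1024` branch interact: any shape that previously matched the small-tensor branch but has small K now falls through to the dedicated small-K kernel instead.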
@@ -0,0 +1,72 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  // A kernel that seems to work well on mid-sized tensors.

  // Check if this input needs to be padded.
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
  int N = WQ.size(0);
  int K = WQ.size(1);
  bool pad = (K % 128 != 0);

  // Dispatch based on whether padding is needed or not.
  if (pad) {
    using DeviceGemmInstance = DeviceGemmHelper<
        128,
        32,
        128,
        128,
        32,
        32,
        1,
        2,
        S<8, 16, 1>,
        S<8, 16, 1>,
        S<1, 16, 1, 8>,
        S<8, 8, 1>,
        1,
        1,
        ck::BlockGemmPipelineScheduler::Interwave,
        ck::BlockGemmPipelineVersion::v2,
        ck::tensor_operation::device::GemmSpecialization::KPadding>;
    // Run kernel instance.
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y);
  } else {
    using DeviceGemmInstance = DeviceGemmHelper<
        128,
        32,
        128,
        128,
        32,
        32,
        1,
        2,
        S<8, 16, 1>,
        S<8, 16, 1>,
        S<1, 16, 1, 8>,
        S<8, 8, 1>,
        1,
        1,
        ck::BlockGemmPipelineScheduler::Interwave,
        ck::BlockGemmPipelineVersion::v2,
        ck::tensor_operation::device::GemmSpecialization::Default>;
    // Run kernel instance.
    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
        XQ, WQ, x_scale, w_scale, Y);
  }
}
@@ -206,6 +206,15 @@ fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor w_scale,
     at::Tensor Y);
 
+// Another variant of larger batch size support.
+at::Tensor
+fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y);
+
 // Kernel that seems optimal for highly compute bound problems.
 at::Tensor
 fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3(