Integrate CK PR1453 for improving fp8 gemm (#2971)
Summary:
Integrate ROCm/composable_kernel#1453 and disable the conditional OOB check, which improves both memory-bound and compute-bound cases.

* ~~Add a flag expecting better register allocation (perf. regression with ROCm 6.2+)~~ Moved all flag changes to D61204625
* Enable bf16 atomic_add (D60544251)
* Add a 256x256x128 tile for fp8 gemm
* Optimize the OOB check strategy, reducing conditional mask usage and the number of gemm_multiply_multiply instances (a sketch of the simplified padding dispatch follows this list)
* Disable the conditional OOB check
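
Across the rowwise kernel files below, the padding change follows one pattern: instead of comparing M, N, and K against the tile shape and falling back to a fully padded instance, each kernel now checks only K and selects a KPadding specialization, with M and N handled by the kernel's own out-of-bounds logic rather than by padded instances. A minimal standalone sketch of the before/after dispatch, using hypothetical enum and function names rather than the real DeviceGemmHelper template:

```cpp
#include <cstdint>

// Hypothetical stand-ins for the CK GemmSpecialization values used in this diff.
enum class GemmSpecialization { Default, KPadding, MNKPadding };

// Before: any of M, N, K off the tile grid forced a fully padded instance.
inline GemmSpecialization pick_specialization_old(
    int64_t M, int64_t N, int64_t K,
    int64_t tileM, int64_t tileN, int64_t tileK) {
  bool pad = (M % tileM != 0) || (N % tileN != 0) || (K % tileK != 0);
  return pad ? GemmSpecialization::MNKPadding : GemmSpecialization::Default;
}

// After: only K is checked; a K-padded instance is used when K is not a
// multiple of the K block size, and the default instance otherwise.
inline GemmSpecialization pick_specialization_new(int64_t K, int64_t tileK) {
  return (K % tileK != 0) ? GemmSpecialization::KPadding
                          : GemmSpecialization::Default;
}
```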

Pull Request resolved: #2971

X-link: facebookresearch/FBGEMM#68

Reviewed By: danzimm

Differential Revision: D60996231

fbshipit-source-id: 9b1e9841bfef7b37139c9095a71625b28177c004
zjing14 authored and facebook-github-bot committed Aug 13, 2024
1 parent 4ae45b7 commit 440cbb0
Showing 11 changed files with 122 additions and 44 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/experimental/gen_ai/src/gemm/ck_extensions.hip
@@ -175,6 +175,7 @@ at::Tensor bf16_gemm_impl(at::Tensor A, at::Tensor B, std::optional<at::Tensor>
StrideB,
DStrideArray,
StrideC,
1,
a_element_op,
b_element_op,
c_element_op);
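
(The bare "1," added to the argument list here, and again in f8f8bf16_rowwise_impl below, is presumably the split-K/KBatch factor that the updated CK argument builder expects after ROCm/composable_kernel#1453; a value of 1 keeps split-K disabled.)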
@@ -164,6 +164,9 @@ RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
} else if (M < 1024) {
// Kernel for generic medium batch sizes.
return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
} else if (M > 4096 && N > 4096 && K > 2048) {
// Kernel for very large gemm
return fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
} else {
// Fallback large kernel.
return fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
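
For context, a condensed sketch of the size tiers visible in this hunk (the names here are hypothetical; the real dispatcher returns kernel function pointers rather than tile descriptors, and the branches for smaller M are elided):

```cpp
// Sketch only: restates the M/N/K thresholds shown in the hunk above.
struct TileShape { int block, m, n, k; };

inline TileShape pick_tile(int M, int N, int K) {
  if (M < 1024) {
    return {256, 128, 128, 128};  // generic medium batch sizes
  } else if (M > 4096 && N > 4096 && K > 2048) {
    return {256, 256, 256, 128};  // new branch: very large gemm
  } else {
    return {256, 224, 256, 128};  // fallback large kernel
  }
}
```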
@@ -21,7 +21,7 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 128 != 0) || (N % 128 != 0) || (K % 128 != 0);
bool pad = (K % 128 != 0);

// Dispatch based on whether padding is needed or not.
if (pad) {
@@ -41,9 +41,11 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v1>;
ck::BlockGemmPipelineVersion::v1,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
@@ -64,6 +66,7 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave
ck::BlockGemmPipelineVersion::v1,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
@@ -21,7 +21,7 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 128 != 0) || (N % 128 != 0) || (K % 128 != 0);
bool pad = (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
@@ -40,9 +40,11 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3>;
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
@@ -63,6 +65,7 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
@@ -21,7 +21,7 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 128 != 0) || (N % 128 != 0) || (K % 128 != 0);
bool pad = (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
@@ -40,9 +40,11 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v5>;
ck::BlockGemmPipelineVersion::v5,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
@@ -63,6 +65,7 @@ fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
ck::BlockGemmPipelineVersion::v5,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
@@ -21,7 +21,7 @@ fp8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 128 != 0) || (N % 128 != 0) || (K % 64 != 0);
bool pad = (K % 64 != 0);

// Dispatch based on whether padding is needed or not.
if (pad) {
@@ -41,9 +41,11 @@ fp8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v4>;
ck::BlockGemmPipelineVersion::v4,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
@@ -64,6 +66,7 @@ fp8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_
ck::BlockGemmPipelineVersion::v4,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
@@ -21,7 +21,6 @@ fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave
int N = WQ.size(0);
int K = WQ.size(1);

bool mnpad = (M % 224 != 0) || (N % 256 != 0);
bool kpad = K % 128 != 0;

if (kpad) {
@@ -42,29 +41,7 @@ fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave
2,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::MNKPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else if (mnpad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
224,
256,
128,
16,
16,
7,
8,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
2,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::MNPadding>;
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// A kernel that seems to work well on mid sized tensors.

// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (K % 128 != 0);

// Dispatch based on whether padding is needed or not.
if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
256,
128,
16,
16,
8,
8,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
2,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
256,
128,
16,
16,
8,
8,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
2,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
@@ -21,7 +21,7 @@ fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 256 != 0) || (N % 256 != 0) || (K % 64 != 0);
bool pad = (K % 64 != 0);

// Dispatch based on whether padding is needed or not.
if (pad) {
@@ -41,9 +41,11 @@ fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v4>;
ck::BlockGemmPipelineVersion::v4,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
@@ -64,6 +66,7 @@ fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_
ck::BlockGemmPipelineVersion::v4,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
@@ -190,6 +190,7 @@ at::Tensor f8f8bf16_rowwise_impl(
StrideB,
std::array<ck::index_t, NumDTensor>{0, 0},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
@@ -205,3 +205,12 @@ fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

// Kernel that seems optimal for highly compute bound problems.
at::Tensor
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);
