Commit 600f38e

Author: pytorchbot
Committed: 2024-08-08 nightly release (cefcad3)
1 parent 3916909

File tree

31 files changed, +1442 -207 lines changed


fbgemm_gpu/docs/src/nitpick.ignore

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ cpp:identifier Tensor
 cpp:identifier TensorQuantizationParams
 cpp:identifier uint32_t
 cpp:identifier uint8_t
+cpp:identifier cudaStream_t
+cpp:identifier cudaError_t

 py:class BoundsCheckMode
 py:class c_size_t

fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def get_llama_shapes() -> List[Tuple[int, int, int]]:
     # Helper function that returns a list of shapes relevant to llama.

     llama_shapes = []
-    for M in [1, 16384]:
+    for M in [1, 16, 32, 64, 96, 128, 16384]:
         # Add shapes for llama 70B
         llama_shapes += [
             (M, 1280, 8192),

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ def rotating_buffer_fn(fn, args_list, copy_cnt):
         # so divide time accordingly
         return triton.testing.do_bench_cudagraph(
             lambda: rotating_buffer_fn(self.compute, args_list, copy_cnt + 1),
-            rep=500,
+            rep=200,
         ) / (copy_cnt + 1)

     def benchmark(
@@ -259,7 +259,7 @@ def compute(self, xq, wq, x_scale, w_scale, dummy_scale):
             use_fast_accum=True,
         )
         # Apply separate rowwise scaling.
-        output = output * x_scale[:, None] * w_scale[None, :]
+        output = scale_fp8_row(output, x_scale, w_scale)
         return output

     def quantize_and_compute(self, x, w):
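
The second hunk swaps the explicit broadcast multiply for the fused scale_fp8_row call; the arithmetic is unchanged, since each output element is still scaled by its row's activation scale and its column's weight scale. A minimal PyTorch sketch of that reference math (rowwise_scale_reference below is an illustrative helper, not part of the benchmark):

```python
import torch

def rowwise_scale_reference(
    output: torch.Tensor,   # [M, N] GEMM result
    x_scale: torch.Tensor,  # [M] per-row activation scales
    w_scale: torch.Tensor,  # [N] per-column weight scales
) -> torch.Tensor:
    # The broadcasted product the old code computed explicitly; scale_fp8_row
    # is expected to produce the same values with the scaling fused into one kernel.
    return output * x_scale[:, None] * w_scale[None, :]

# Shape check only; values are arbitrary.
out = torch.randn(4, 8)
assert rowwise_scale_reference(out, torch.rand(4), torch.rand(8)).shape == (4, 8)
```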

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_gemm.hip

Lines changed: 78 additions & 11 deletions
@@ -45,24 +45,79 @@ static const std::unordered_map<
     IntTupleHash>
     rowwise_lookup_dispatch = {
         // LLama 70B Decode shapes.
-        {{1, 1280, 8192},
+        // Support for decode across batch sizes for [1280, 8192]
+        {{16, 1280, 8192},
+         fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+        {{32, 1280, 8192},
          fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
-        {{1, 8192, 1024},
+        {{64, 1280, 8192},
+         fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+        {{128, 1280, 8192},
+         fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+        // Support for decode across batch sizes for [8192, 1024]
+        {{16, 8192, 1024},
          fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
-        {{1, 7168, 8192},
+        {{32, 8192, 1024},
+         fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
+        {{64, 8192, 1024},
+         fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+        {{128, 8192, 1024},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        // Support for decode across batch sizes for [7168, 8192]
+        {{16, 7168, 8192},
+         fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+        {{32, 7168, 8192},
          fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
-        {{1, 8192, 3584},
+        {{64, 7168, 8192},
+         fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+        {{128, 7168, 8192},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        // Support for decode across batch sizes for [8192, 3584]
+        {{16, 8192, 3584},
          fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+        {{32, 8192, 3584},
+         fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
+        {{64, 8192, 3584},
+         fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+        {{128, 8192, 3584},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
         // Llama 405B Decode Shapes.
-        {{1, 13312, 6656},
+        // Support for decode across batch sizes for [13312, 6656].
+        {{16, 13312, 6656},
          fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
-        {{1, 13312, 16384},
-        //fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
+        {{32, 13312, 6656},
+         fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2},
+        {{64, 13312, 6656},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        {{128, 13312, 6656},
+         fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        // Support for decode across batch sizes for [13312, 16384].
+        {{16, 13312, 16384},
          fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
-        {{1, 16384, 6656},
+        {{32, 13312, 16384},
+         fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+        {{64, 13312, 16384},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        {{128, 13312, 16384},
+         fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        // Support for decode across batch sizes for [16384, 6656].
+        {{16, 16384, 6656},
          fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
-        {{1, 16384, 16384},
-         fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
+        {{32, 16384, 6656},
+         fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2},
+        {{64, 16384, 6656},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        {{128, 16384, 6656},
+         fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        // Support for decode across batch sizes for [16384, 16384].
+        {{16, 16384, 16384},
+         fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
+        {{32, 16384, 16384},
+         fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2},
+        {{64, 16384, 16384},
+         fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+        {{128, 16384, 16384},
+         fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
         // EMU 1.6 Shapes.
         {{1536, 3584, 3584},
          fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1},
@@ -117,8 +172,20 @@ RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {

 RowwiseKernel rowwise_dispatch(int M, int N, int K) {
   // For a given shape, either find the best kernel via lookup or heuristic.
+  // For many small M shapes, we bucket them to the next largest kernel.
+  // This is fine since kernels are padded anyway.
+  int padded_m = M;
+  if (M <= 16) {
+    padded_m = 16;
+  } else if (M <= 32) {
+    padded_m = 32;
+  } else if (M <= 64) {
+    padded_m = 64;
+  } else if (M <= 128) {
+    padded_m = 128;
+  }
   // First check if this shape is available in the direct lookup.
-  auto it = rowwise_lookup_dispatch.find({M, N, K});
+  auto it = rowwise_lookup_dispatch.find({padded_m, N, K});
   // If we found an optimal kernel, use it.
   if (it != rowwise_lookup_dispatch.end()) {
     return it->second;
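
rowwise_dispatch now rounds small M values up to the nearest tuned bucket before consulting rowwise_lookup_dispatch, and anything without an entry still falls through to rowwise_heuristic_dispatch. A minimal Python sketch of that flow, assuming a tiny illustrative table (bucket_m, dispatch, and rowwise_lookup are stand-ins for the C++ above, and the kernel names are abbreviated):

```python
def bucket_m(m: int) -> int:
    # Mirror of the padded_m logic: round small M up to the next tuned bucket.
    for bucket in (16, 32, 64, 128):
        if m <= bucket:
            return bucket
    return m  # larger M values are looked up as-is

# Illustrative lookup keyed on (padded_m, N, K); the real map stores CK kernel functions.
rowwise_lookup = {
    (16, 1280, 8192): "fp8_rowwise_128x16x32x128_..._interwave_v2",
    (128, 1280, 8192): "fp8_rowwise_128x16x32x128_..._interwave_v2",
}

def dispatch(m: int, n: int, k: int) -> str:
    # Bucketed direct lookup first, heuristic fallback otherwise.
    return rowwise_lookup.get((bucket_m(m), n, k), "rowwise_heuristic_dispatch")

assert dispatch(1, 1280, 8192) == "fp8_rowwise_128x16x32x128_..._interwave_v2"   # M=1 -> bucket 16
assert dispatch(96, 1280, 8192) == "fp8_rowwise_128x16x32x128_..._interwave_v2"  # M=96 -> bucket 128
assert dispatch(200, 1280, 8192) == "rowwise_heuristic_dispatch"                 # no tuned entry
```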

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip

Lines changed: 50 additions & 19 deletions
@@ -16,23 +16,54 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor w_scale,
     at::Tensor Y) {
   // The smallest kernel we have available. Works well for memory bound shapes.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      128,
-      16,
-      32,
-      128,
-      16,
-      16,
-      1,
-      1,
-      S<8, 16, 1>,
-      S<8, 16, 1>,
-      S<1, 16, 1, 8>,
-      S<4, 4, 1>,
-      1,
-      1,
-      ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);
+
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        16,
+        32,
+        128,
+        16,
+        16,
+        1,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        16,
+        32,
+        128,
+        16,
+        16,
+        1,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+
+  }
 }
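
The padding check above decides which template instance runs: only when M, N, and K are exact multiples of this kernel's 16x32x128 tiling can the unpadded GemmSpecialization::Default instance be selected; the same pattern, with its own tile sizes, appears in the kernel files below. A minimal sketch of the decision (needs_padding is an illustrative stand-in for the inline check):

```python
def needs_padding(m: int, n: int, k: int,
                  m_tile: int = 16, n_tile: int = 32, k_tile: int = 128) -> bool:
    # Same condition as the hunk: any dimension that is not a multiple of its
    # tile forces the padded instance; otherwise the Default (unpadded)
    # GemmSpecialization can be used.
    return (m % m_tile != 0) or (n % n_tile != 0) or (k % k_tile != 0)

assert not needs_padding(16, 1280, 8192)  # exact multiples: unpadded instance
assert needs_padding(1, 1280, 8192)       # M=1 is padded up to the 16-row tile
```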

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip

Lines changed: 48 additions & 19 deletions
@@ -16,23 +16,52 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v
     at::Tensor w_scale,
     at::Tensor Y) {
   // The smallest kernel we have available. Works well for memory bound shapes.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      128,
-      16,
-      32,
-      128,
-      16,
-      16,
-      1,
-      1,
-      S<8, 16, 1>,
-      S<8, 16, 1>,
-      S<1, 16, 1, 8>,
-      S<4, 4, 1>,
-      1,
-      1,
-      ck::BlockGemmPipelineScheduler::Intrawave,
-      ck::BlockGemmPipelineVersion::v2>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        16,
+        32,
+        128,
+        16,
+        16,
+        1,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v2>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else{
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        16,
+        32,
+        128,
+        16,
+        16,
+        1,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<4, 4, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Intrawave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  }
 }

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip

Lines changed: 50 additions & 20 deletions
@@ -15,24 +15,54 @@ fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  // The smallest kernel we have available. Works well for memory bound shapes.
-  using DeviceGemmInstance = DeviceGemmHelper<
-      128,
-      32,
-      16,
-      128,
-      16,
-      16,
-      1,
-      1,
-      S<8, 16, 1>,
-      S<8, 16, 1>,
-      S<1, 16, 1, 8>,
-      S<2, 2, 1>,
-      1,
-      1,
-      ck::BlockGemmPipelineScheduler::Interwave,
-      ck::BlockGemmPipelineVersion::v2>;
-  // Run kernel instance.
-  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  // A small kernel for small but not tiny shapes.
+
+  // Check if this input needs to be padded.
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = WQ.size(0);
+  int K = WQ.size(1);
+  bool pad = (M % 32 != 0) || (N % 16 != 0) || (K % 128 != 0);
+
+  if (pad) {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        32,
+        16,
+        128,
+        16,
+        16,
+        1,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<2, 2, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  } else {
+    using DeviceGemmInstance = DeviceGemmHelper<
+        128,
+        32,
+        16,
+        128,
+        16,
+        16,
+        1,
+        1,
+        S<8, 16, 1>,
+        S<8, 16, 1>,
+        S<1, 16, 1, 8>,
+        S<2, 2, 1>,
+        1,
+        1,
+        ck::BlockGemmPipelineScheduler::Interwave,
+        ck::BlockGemmPipelineVersion::v2,
+        ck::tensor_operation::device::GemmSpecialization::Default>;
+    // Run kernel instance.
+    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+  }
 }
