Skip to content

Commit b14408b

Browse files
authored
feat: TRTLLM FMHAv2 backend for ctx attention (#2142)
<!-- .github/pull_request_template.md --> ## 📌 Description Porting over the [trtllm fmhav2 library](https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/kernels/fmha_v2) to support prefill cases. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * INT8 quantization and FP8 (E4M3/E5M2) conversion utilities, plus broad packed 8/16‑bit output paths. * Hopper GMMA/TMA optimizations and SM90 GMMA/IGMMA helpers for high‑performance kernels. * Extensive FMHA v2 tiling/load/store primitives (Q/K/V/O), TMA descriptor management, and paged KV cache. * **Enhanced Support** * Alibi positional-bias params, BF16/mixed-precision conversions, causal/sliding-window masks and multi‑token prediction. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent df5c2e4 commit b14408b

File tree

77 files changed

+55337
-0
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

77 files changed

+55337
-0
lines changed

csrc/fmha_v2/convert.cu

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
3+
* reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
4+
*
5+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6+
* property and proprietary rights in and to this material, related
7+
* documentation and any modifications thereto. Any use, reproduction,
8+
* disclosure or distribution of this material and related documentation
9+
* without an express license agreement from NVIDIA CORPORATION or
10+
* its affiliates is strictly prohibited.
11+
*/
12+
13+
#include <fmha/numeric_types.h>
14+
#include <fmha/utils.h>
15+
#include <stdint.h>
16+
17+
////////////////////////////////////////////////////////////////////////////////////////////////////
18+
19+
// Converts int32 accumulator values to int8 with scaling: out = sat_s8(round(in * scale)).
//
// Each thread handles 4 elements per iteration via a 16-byte vector load and a
// 4-byte vector store, so `n` is expected to be a multiple of 4 (a trailing
// remainder of n % 4 elements would be skipped) and src/dst suitably aligned
// for int4/uint32_t accesses -- TODO confirm with callers.
// Grid-stride loop: any launch configuration covers all n / 4 vectors.
__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) {
  // Total number of threads in the grid (the stride of the grid-stride loop).
  size_t step = (size_t)gridDim.x * blockDim.x;

  // Iterate over the elements. Widen to size_t BEFORE multiplying: the product
  // blockIdx.x * blockDim.x can overflow 32-bit unsigned arithmetic on very
  // large grids (the `step` computation above already widens for this reason).
  for (size_t ii = (size_t)blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
    // Load 4 int32 values at once.
    int4 tmp = reinterpret_cast<int4 const*>(src)[ii];

    // Convert to float and apply the quantization scale.
    float x = static_cast<float>(tmp.x) * scale;
    float y = static_cast<float>(tmp.y) * scale;
    float z = static_cast<float>(tmp.z) * scale;
    float w = static_cast<float>(tmp.w) * scale;

    // Round to nearest (even) and saturate to the int8 range [-128, 127].
    uint32_t a;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(a) : "f"(x));
    uint32_t b;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(b) : "f"(y));
    uint32_t c;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(c) : "f"(z));
    uint32_t d;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(d) : "f"(w));

    // Pack the 4 int8 results into a single 32-bit word.
    char4 out;
    out.x = reinterpret_cast<int8_t const&>(a);
    out.y = reinterpret_cast<int8_t const&>(b);
    out.z = reinterpret_cast<int8_t const&>(c);
    out.w = reinterpret_cast<int8_t const&>(d);

    // Store the packed word.
    reinterpret_cast<uint32_t*>(dst)[ii] = reinterpret_cast<uint32_t const&>(out);
  }
}
55+
56+
////////////////////////////////////////////////////////////////////////////////////////////////////
57+
58+
// Launches the int32 -> int8 quantization kernel over an (s, b, h, d) tensor.
// The total element count must be a multiple of 4 (the kernel is vectorized).
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d,
                                  float scale) {
  // Widen before multiplying so the element count cannot overflow int.
  size_t const num_elements = static_cast<size_t>(s) * b * h * d;
  convert_int32_to_int8_kernel<<<512, 256>>>(dst, src, num_elements, scale);
}
63+
64+
////////////////////////////////////////////////////////////////////////////////////////////////////
65+
66+
// Packs four scaled floats into one machine word holding 4 elements of type T.
// Only the explicit specializations below are provided; instantiating any other
// T is a link-time error by design.
template <typename T>
__device__ inline typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type pack_float4(
    float4 const& f);

////////////////////////////////////////////////////////////////////////////////////////////////////

// FP16: four half values packed into two 32-bit words (uint2).
template <>
__device__ inline uint2 pack_float4<fmha::fp16_t>(float4 const& f) {
  return fmha::float4_to_half4(f.x, f.y, f.z, f.w);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// BF16: four bfloat16 values packed into two 32-bit words (uint2).
template <>
__device__ inline uint2 pack_float4<fmha::bf16_t>(float4 const& f) {
  return fmha::float4_to_16bit_x4<fmha::bf16_t>(f.x, f.y, f.z, f.w);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// FP8 E4M3: four 8-bit values packed into one 32-bit word.
template <>
__device__ inline uint32_t pack_float4<fmha::e4m3_t>(float4 const& f) {
  return fmha::float4_to_e4m3x4(f.x, f.y, f.z, f.w);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// FP8 E5M2: four 8-bit values packed into one 32-bit word.
template <>
__device__ inline uint32_t pack_float4<fmha::e5m2_t>(float4 const& f) {
  return fmha::float4_to_e5m2x4(f.x, f.y, f.z, f.w);
}
96+
97+
////////////////////////////////////////////////////////////////////////////////////////////////////
98+
99+
// Converts FP32 values to T (FP16 / BF16 / E4M3 / E5M2), scaling each value
// first: out = T(in * scale).
//
// Each thread handles 4 elements per iteration (16-byte vector load, one packed
// word store), so `n` is expected to be a multiple of 4; a trailing remainder of
// n % 4 elements would be skipped. Grid-stride loop: works with any launch config.
template <typename T>
__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
  // The packed output word that holds 4 elements of T.
  using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;

  // Total number of threads in the grid (the stride of the grid-stride loop).
  size_t step = (size_t)gridDim.x * blockDim.x;

  // Iterate over the elements. Widen to size_t BEFORE multiplying: the product
  // blockIdx.x * blockDim.x can overflow 32-bit unsigned arithmetic on very
  // large grids (the `step` computation above already widens for this reason).
  for (size_t ii = (size_t)blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
    // Load 4 floats.
    float4 tmp = reinterpret_cast<float4 const*>(src)[ii];
    // Scale.
    tmp.x *= scale;
    tmp.y *= scale;
    tmp.z *= scale;
    tmp.w *= scale;
    // Convert and pack into a single word of 4 Ts.
    auto out = pack_float4<T>(tmp);

    // Store the packed word.
    reinterpret_cast<Dst*>(dst)[ii] = reinterpret_cast<Dst const&>(out);
  }
}
122+
123+
// Converts packed T values (FP8 E4M3/E5M2, ...) back to FP32 with scaling:
// out = float(in) * scale.
//
// Each thread handles 4 elements per iteration (one packed-word load, 16-byte
// vector store), so `n` is expected to be a multiple of 4; a trailing remainder
// of n % 4 elements would be skipped. Grid-stride loop: works with any launch config.
template <typename T>
__global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
  // The packed input word that holds 4 elements of T.
  using Src = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;

  // View one packed word as 4 individual elements of T.
  union {
    Src raw;
    T elt[4];
  } data;

  // Total number of threads in the grid (the stride of the grid-stride loop).
  size_t step = (size_t)gridDim.x * blockDim.x;

  // Iterate over the elements. Widen to size_t BEFORE multiplying: the product
  // blockIdx.x * blockDim.x can overflow 32-bit unsigned arithmetic on very
  // large grids (the `step` computation above already widens for this reason).
  for (size_t ii = (size_t)blockIdx.x * blockDim.x + threadIdx.x; ii < n / 4; ii += step) {
    // Load 4 packed T values at once. (The original comment said "4 floats";
    // this load is actually one word of 4 T elements.)
    data.raw = reinterpret_cast<Src const*>(src)[ii];
    float4 out;
    // Convert each element to float and scale.
    out.x = float(data.elt[0]) * scale;
    out.y = float(data.elt[1]) * scale;
    out.z = float(data.elt[2]) * scale;
    out.w = float(data.elt[3]) * scale;

    // Store 4 floats.
    reinterpret_cast<float4*>(dst)[ii] = reinterpret_cast<float4 const&>(out);
  }
}
150+
151+
////////////////////////////////////////////////////////////////////////////////////////////////////
152+
153+
// Converts an (s, b, h, d) tensor of FP32 values to FP16.
// The scale factor is not exposed for FP16 output (fixed at 1).
void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d) {
  // Widen before multiplying so the element count cannot overflow int.
  size_t const num_elements = static_cast<size_t>(s) * b * h * d;
  convert_fp32_to_T_kernel<fmha::fp16_t><<<512, 256>>>(dst, src, num_elements, 1.f);
}
158+
159+
////////////////////////////////////////////////////////////////////////////////////////////////////
160+
161+
// Converts an (s, b, h, d) tensor of FP32 values to BF16.
// The scale factor is not exposed for BF16 output (fixed at 1).
void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d) {
  // Widen before multiplying so the element count cannot overflow int.
  size_t const num_elements = static_cast<size_t>(s) * b * h * d;
  convert_fp32_to_T_kernel<fmha::bf16_t><<<512, 256>>>(dst, src, num_elements, 1.f);
}
166+
167+
////////////////////////////////////////////////////////////////////////////////////////////////////
168+
169+
// Converts n FP32 values to FP8 E4M3, multiplying each value by scale_o before
// conversion. n is expected to be a multiple of 4 (vectorized kernel).
void run_conversion_fp32_to_e4m3(void* dst, void const* src, size_t n, float scale_o) {
  convert_fp32_to_T_kernel<fmha::e4m3_t><<<512, 256>>>(dst, src, n, scale_o);
}
172+
173+
////////////////////////////////////////////////////////////////////////////////////////////////////
174+
175+
// Converts n FP8 E4M3 values back to FP32, multiplying each result by scale_o.
// n is expected to be a multiple of 4 (vectorized kernel).
void run_conversion_e4m3_to_fp32(void* dst, void const* src, size_t n, float scale_o) {
  convert_T_to_fp32_kernel<fmha::e4m3_t><<<512, 256>>>(dst, src, n, scale_o);
}
178+
179+
////////////////////////////////////////////////////////////////////////////////////////////////////
180+
181+
// Converts an (s, b, h, d) tensor of FP32 values to FP8 E4M3, scaling by scale_o.
// Widens to size_t before multiplying so the element count cannot overflow int
// (consistent with the other run_conversion_* wrappers, which all compute
// `(size_t)s * b * h * d`; the original multiplied in int here).
void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d,
                                 float scale_o) {
  run_conversion_fp32_to_e4m3(dst, src, static_cast<size_t>(s) * b * h * d, scale_o);
}
185+
186+
////////////////////////////////////////////////////////////////////////////////////////////////////
187+
188+
// Converts n FP32 values to FP8 E5M2, multiplying each value by scale_o before
// conversion. n is expected to be a multiple of 4 (vectorized kernel).
void run_conversion_fp32_to_e5m2(void* dst, void const* src, size_t n, float scale_o) {
  convert_fp32_to_T_kernel<fmha::e5m2_t><<<512, 256>>>(dst, src, n, scale_o);
}
191+
192+
////////////////////////////////////////////////////////////////////////////////////////////////////
193+
194+
// Converts n FP8 E5M2 values back to FP32, multiplying each result by scale_o.
// n is expected to be a multiple of 4 (vectorized kernel).
void run_conversion_e5m2_to_fp32(void* dst, void const* src, size_t n, float scale_o) {
  convert_T_to_fp32_kernel<fmha::e5m2_t><<<512, 256>>>(dst, src, n, scale_o);
}

csrc/fmha_v2/fmha/alibi_params.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights
3+
* reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
4+
*
5+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6+
* property and proprietary rights in and to this material, related
7+
* documentation and any modifications thereto. Any use, reproduction,
8+
* disclosure or distribution of this material and related documentation
9+
* without an express license agreement from NVIDIA CORPORATION or
10+
* its affiliates is strictly prohibited.
11+
*/
12+
13+
#pragma once
14+
15+
namespace fmha {
16+
17+
// Parameters for the ALiBi (Attention with Linear Biases) positional bias.
struct AlibiParams {
  // Returns the largest power of two that is <= x, via bit-smearing.
  // NOTE(review): meaningful for x > 0; returns 0 for x == 0.
  constexpr static int round_down_to_power_two(int x) {
    // Propagate the highest set bit into every lower position
    // (shifts 1, 2, 4, 8, 16 cover all 32 bits) ...
    for (int shift = 1; shift < 32; shift <<= 1) {
      x = x | (x >> shift);
    }
    // ... then keep only the highest bit.
    return x - (x >> 1);
  }

  AlibiParams() = default;

  // h: total number of attention heads. The bias slope factor is derived from
  // the power of two closest below h.
  AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
    h_pow_2 = round_down_to_power_two(h);
    alibi_neg4_div_h = -4.0f / h_pow_2;
  }

  // Tensor-parallel variant: slopes are computed from the global head count
  // h * tp_size, while the offsets locate this rank's shard of heads / positions.
  AlibiParams(int h, int s, int tp_size, int rank, float scale_after_alibi = 1.f)
      : AlibiParams(h * tp_size, scale_after_alibi) {
    head_idx_offset = h * rank;
    sequence_pos_offset = s * rank;
  }

  // Head count rounded down to a power of two.
  int h_pow_2{};
  // Precomputed -4 / h_pow_2 used by the slope computation.
  float alibi_neg4_div_h{};
  // Extra scale applied after the alibi bias.
  float scale_after_alibi{};
  // Could be simplified to `int rank` derive the others as `num_heads * rank, s * rank` at
  // runtime, but this makes assumptions about the layout downstream
  // (e.g. downstream may only split across the head dimension, so s would be the full sequence)
  int head_idx_offset = 0;
  int sequence_pos_offset = 0;
};
49+
50+
} // namespace fmha

0 commit comments

Comments
 (0)