Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,273 @@ __global__ void append_speculate_cache_neox_rope_kernel(
}
}

template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool IsFP8 = false>
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
// gqa_group_size, head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, gqa_group_size,
// block_size, head_size // 2]
uint8_t* __restrict__ value_cache, // [num_blocks, gqa_group_size,
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
T* __restrict__ cache_k_scale,
T* __restrict__ cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int gqa_group_size,
const bool rope_3d,
const float rms_norm_eps) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;

const int bid = batch_id_per_token[token_id];

const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid] + token_id - start_token_idx;
if (write_seq_id == 0) return;
const int* block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;

int cache_offset;
if (head_idx < num_heads) {
cache_offset = 0;
} else if (head_idx < num_heads + 2 * gqa_group_size) {
cache_offset = block_idx * gqa_group_size * block_size + (head_idx - num_heads) % gqa_group_size * block_size + block_offset;
}
T *cache_k_scale_now = cache_k_scale + cache_offset;
T *cache_v_scale_now = cache_v_scale + cache_offset;

float thread_m2 = 0.0f;
float warp_m2 = 0.0f;

if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;

LoadT src_vec;
LoadBiasT bias_vec;
LoadOutScaleT out_scale_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + token_id * hidden_size;
T* qkv_out_now = qkv_out + token_id * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);

// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
bias_vec[2 * i] =
static_cast<T>(tmp1);
bias_vec[2 * i + 1] =
static_cast<T>(tmp2);
}
// qk norm
if (q_norm_weight) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadOutScaleT q_norm_vec;
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
bias_vec[i] = static_cast<T>(static_cast<float>(bias_vec[i]) * row_inv_var * q_norm_vec[i]);
}
}
Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * gqa_group_size) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;

constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
using LoadEmbT = AlignedVector<float, 1>;
LoadKVResT cache_vec;
LoadT src_vec1, src_vec2;
LoadBiasT bias_vec1, bias_vec2;
LoadOutScaleT out_scale_vec1, out_scale_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;

const T* qkv_now = quant_qkv + token_id * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
T scale = T(1.0f);
const int k_head_idx = head_idx - num_heads;
const int v_head_idx = head_idx - num_heads - gqa_group_size;
if (head_idx < num_heads + gqa_group_size) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
}

float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]);
if (head_idx < num_heads + gqa_group_size) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
bias_vec1[0] =
static_cast<T>(tmp1);
bias_vec1[1] =
static_cast<T>(tmp2);
} else {
bias_vec1[0] = static_cast<T>(input_left);
bias_vec1[1] = static_cast<T>(input_right);
}

input_left = static_cast<float>(src_vec2[0]);
input_right = static_cast<float>(src_vec2[1]);
if (head_idx < num_heads + gqa_group_size) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
bias_vec2[0] =
static_cast<T>(tmp1);
bias_vec2[1] =
static_cast<T>(tmp2);
} else {
bias_vec2[0] = static_cast<T>(input_left);
bias_vec2[1] = static_cast<T>(input_right);
}
if (k_norm_weight) {
if (head_idx < num_heads + gqa_group_size) {
LoadOutScaleT k_norm_vec1, k_norm_vec2;
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
// qk norm
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);

for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
bias_vec1[i] = static_cast<T>(static_cast<float>(bias_vec1[i]) * row_inv_var * k_norm_vec1[i]);
bias_vec2[i] = static_cast<T>(static_cast<float>(bias_vec2[i]) * row_inv_var * k_norm_vec2[i]);
}
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(bias_vec1[i]));
local_max = __hmax(local_max, __habs(bias_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}

scale = __hdiv(448, local_max);

if (lane_id == 0) {
if (head_idx < num_heads + gqa_group_size) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}

#pragma unroll
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec1[i], max_bound, min_bound);
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec2[i], max_bound, min_bound);
}
if (head_idx < num_heads + gqa_group_size) {
const int start_block_16 =
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
const uint32_t tgt_cache_idx =
block_idx * gqa_group_size * block_size * HeadDim +
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
} else {
const uint32_t base_tgt_cache_idx =
block_idx * gqa_group_size * HeadDim * block_size +
kv_head_idx * HeadDim * block_size +
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
block_offset % 8 / 2 * 4 // per 4
+ block_offset % 16 / 8 * 2 // per 2
+ block_offset % 2; // per 1
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
value_cache[tgt_cache_idx1] = cache_vec[0];
value_cache[tgt_cache_idx2] = cache_vec[1];
value_cache[tgt_cache_idx3] = cache_vec[2];
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
}

template <typename T,
int VecSize = 4,
int RoundType = 0,
Expand Down
Loading
Loading