diff --git a/custom_ops/metax_ops/cache_kv_with_rope.cu b/custom_ops/metax_ops/cache_kv_with_rope.cu index 0f3e9a54e3b..ce748f7419e 100644 --- a/custom_ops/metax_ops/cache_kv_with_rope.cu +++ b/custom_ops/metax_ops/cache_kv_with_rope.cu @@ -17,238 +17,417 @@ #include #include "helper.h" -template -struct Converter; - -template <> -struct Converter<__half> { - // __half -> float - __device__ static float to_float(__half val) { return __half2float(val); } - // float -> __half - __device__ static __half from_float(float val) { - return __float2half_rn(val); - } - // int -> __half - __device__ static __half from_int(float val) { return __int2half_rn(val); } -}; - -template <> -struct Converter<__nv_bfloat16> { - // __nv_bfloat16 -> float - __device__ static float to_float(__nv_bfloat16 val) { - return __bfloat162float(val); - } - // float -> __nv_bfloat16 - __device__ static __nv_bfloat16 from_float(float val) { - return __float2bfloat16_rn(val); - } - // int -> __nv_bfloat16 - __device__ static __nv_bfloat16 from_int(int val) { - return __int2bfloat16_rn(val); - } -}; - struct CacheKVWithRopeParams { + int64_t linear_elem_num; + int linear_stride; int head_dim; int block_size; int block_num; int cache_stride; int token_stride; - int head_stride; int q_stride; int kv_stride; - int q_head_offset; int k_head_offset; int v_head_offset; int q_head_num; int kv_head_num; + int rotary_stride; // 1 * S * 1 * D / 2 or 1 * S * 1 * D(if neox) + int batch_rotary_stride; // 2 * rotary_stride }; -template -__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr, - const T* rotary_cos_ptr, - const T* rotary_sin_ptr, - const int load_idx, - const int store_idx, - const int cache_store_idx, - const int rot_base_idx, - T* caches, - T* out) { - using VecT = AlignedVector; - - VecT qk_vec; +template +__device__ __forceinline__ void RotateQKVec( + const T* __restrict__ qkv_ptr, + const float* __restrict__ rotary_embs_ptr, + const int load_idx, + const int store_idx, + const int cache_store_idx, + T* __restrict__ caches, + T* __restrict__ out) { + using VecQKV = AlignedVector; + const int SIN_OFFSET = VecSize / 2; + + VecQKV qk_vec, qk_out_vec; Load(qkv_ptr + load_idx, &qk_vec); - VecT rot_half_vec; - int flag; -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - flag = 1 - 2 * (i % 2); - rot_half_vec[i] = -qk_vec[i + flag] * Converter::from_int(flag); - } - VecT cos_vec, sin_vec; - Load(rotary_cos_ptr + rot_base_idx, &cos_vec); - Load(rotary_sin_ptr + rot_base_idx, &sin_vec); + #pragma unroll - for (int i = 0; i < VecSize; ++i) { - T result = qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i]; - *(out + store_idx + i) = result; + for (int i = 0; i < VecSize; i += 2) { + float q0 = static_cast(qk_vec[i]); + float q1 = static_cast(qk_vec[i + 1]); - if (WriteCache) { - *(caches + cache_store_idx + i) = result; - } - } -} + float cos_val = rotary_embs_ptr[i >> 1]; + float sin_val = rotary_embs_ptr[SIN_OFFSET + (i >> 1)]; -template -__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr, - const float* rotary_cos_ptr, - const float* rotary_sin_ptr, - const int load_idx, - const int store_idx, - const int cache_store_idx, - const int rot_base_idx, - T* caches, - T* out) { - using VecT = AlignedVector; - using VecF = AlignedVector; - auto to_float = [] __device__(T val) -> float { - return Converter::to_float(val); - }; - auto from_float = [] __device__(float val) -> T { - return Converter::from_float(val); - }; - - VecT qk_vec; - Load(qkv_ptr + load_idx, &qk_vec); - VecF rot_half_vec; - int flag; 
-#pragma unroll - for (int i = 0; i < VecSize; ++i) { - flag = 1 - 2 * (i % 2); - rot_half_vec[i] = -to_float(qk_vec[i + flag]) * static_cast(flag); + qk_out_vec[i] = static_cast(q0 * cos_val - q1 * sin_val); + qk_out_vec[i + 1] = static_cast(q1 * cos_val + q0 * sin_val); } - VecF cos_vec, sin_vec; - Load(rotary_cos_ptr + rot_base_idx, &cos_vec); - Load(rotary_sin_ptr + rot_base_idx, &sin_vec); -#pragma unroll - for (int i = 0; i < VecSize; ++i) { - T result = from_float(to_float(qk_vec[i]) * cos_vec[i] + - rot_half_vec[i] * sin_vec[i]); - *(out + store_idx + i) = result; - if (WriteCache) { - *(caches + cache_store_idx + i) = result; - } + + Store(qk_out_vec, out + store_idx); + if constexpr (WriteCache) { + Store(qk_out_vec, caches + cache_store_idx); } } -template -__device__ __forceinline__ void StoreValue(const T* qkv_ptr, +template +__device__ __forceinline__ void StoreValue(const T* __restrict__ qkv_ptr, const int load_idx, const int store_idx, const int cache_store_idx, - T* caches, - T* out) { + T* __restrict__ caches, + T* __restrict__ out) { using VecT = AlignedVector; VecT v_vec; Load(qkv_ptr + load_idx, &v_vec); Store(v_vec, out + store_idx); - Store(v_vec, caches + cache_store_idx); + if constexpr (WriteCache) { + Store(v_vec, caches + cache_store_idx); + } +} + +template +__device__ __forceinline__ void RotateQKVecNeox( + const T* __restrict__ qkv_ptr, + const float* __restrict__ rotary_embs_ptr, + const int left_load_idx, + const int right_load_idx, + const int left_store_idx, + const int right_store_idx, + const int left_cache_store_idx, + const int right_cache_store_idx, + T* __restrict__ caches, + T* __restrict__ out) { + using VecQKV = AlignedVector; + constexpr int SIN_OFFSET = VecSize; + + VecQKV left_vec, right_vec, left_out_vec, right_out_vec; + + Load(qkv_ptr + left_load_idx, &left_vec); + Load(qkv_ptr + right_load_idx, &right_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + float l_val = static_cast(left_vec[i]); + float r_val = static_cast(right_vec[i]); + + float cos_val = rotary_embs_ptr[i]; + float sin_val = rotary_embs_ptr[SIN_OFFSET + i]; + + left_out_vec[i] = static_cast(l_val * cos_val - r_val * sin_val); + right_out_vec[i] = static_cast(r_val * cos_val + l_val * sin_val); + } + + Store(left_out_vec, out + left_store_idx); + Store(right_out_vec, out + right_store_idx); + + if constexpr (WriteCache) { + Store(left_out_vec, caches + left_cache_store_idx); + Store(right_out_vec, caches + right_cache_store_idx); + } +} + +template +__device__ __forceinline__ void StoreValueNeox(const T* __restrict__ qkv_ptr, + const int left_load_idx, + const int right_load_idx, + const int left_store_idx, + const int right_store_idx, + const int left_cache_store_idx, + const int right_cache_store_idx, + T* __restrict__ caches, + T* __restrict__ out) { + using VecT = AlignedVector; + VecT left_v_vec, right_v_vec; + Load(qkv_ptr + left_load_idx, &left_v_vec); + Load(qkv_ptr + right_load_idx, &right_v_vec); + Store(left_v_vec, out + left_store_idx); + Store(right_v_vec, out + right_store_idx); + if constexpr (WriteCache) { + Store(left_v_vec, caches + left_cache_store_idx); + Store(right_v_vec, caches + right_cache_store_idx); + } } -template -__global__ void DispatchCacheKVWithRopeVecKernel(const T* qkv, - T* caches_k, - T* caches_v, - const int* block_tables, - const WeightType* rotary_cos, - const WeightType* rotary_sin, - const int* cu_seqlens_q, - const int* batch_ids_q, - CacheKVWithRopeParams param, - T* q_out, - T* k_out, - T* v_out) { - const int 
token_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int head_idx = blockIdx.y * blockDim.y + threadIdx.y; - const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * VecSize; - - int load_idx, store_idx, cache_store_idx; - int rot_idx = token_idx * param.head_dim + head_dim_idx; - - const int batch_idx = *(batch_ids_q + token_idx); - const int inter_batch_token_offset = token_idx - *(cu_seqlens_q + batch_idx); - const int inter_batch_block_idx = inter_batch_token_offset / param.block_size; - const int inter_block_offset = inter_batch_token_offset % param.block_size; - const int block_idx = - *(block_tables + batch_idx * param.block_num + inter_batch_block_idx); - - assert(block_idx != -1); - - if (head_dim_idx < param.head_dim) { - if (head_idx < param.q_head_num) { // q - load_idx = token_idx * param.token_stride + - (head_idx + param.q_head_offset) * param.head_stride + - head_dim_idx; - store_idx = - token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx; - RotateQKVec(qkv, - rotary_cos, - rotary_sin, - load_idx, - store_idx, - -1, - rot_idx, - static_cast(nullptr), - q_out); +struct CacheKVIndices { + // Thread-block indices + int token_idx; + int head_idx; + int head_dim_idx; + + // RoPE rotation indices + int rotary_cos_idx; + int rotary_sin_idx; + + // Global memory load/store indices + int load_idx[3]; // q, k, v + int store_idx[3]; // q, kv + + // KV cache store indices (computed from the template parameters, but still plain int) + int cache_store_idx; + int right_cache_store_idx; +}; + +// Helper function: compute all indices +template +__device__ void GetIndices(int64_t linear_index, + const int half_head_dim, + const int* __restrict__ batch_ids_per_token, + const int* __restrict__ global_batch_ids, + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seqlens_q, + const int* __restrict__ block_tables, + const CacheKVWithRopeParams& param, + CacheKVIndices& indices) { + // ********** 1. linear index -> 3D index ********** + if constexpr (NeoxStyle) { + int linear_stride_half = (param.linear_stride >> 1); + int head_dim_half = (param.head_dim >> 1); + indices.token_idx = linear_index / linear_stride_half; + indices.head_idx = (linear_index % linear_stride_half) / head_dim_half; + indices.head_dim_idx = linear_index % head_dim_half; + } else { + indices.token_idx = linear_index / param.linear_stride; + indices.head_idx = (linear_index % param.linear_stride) / param.head_dim; + indices.head_dim_idx = linear_index % param.head_dim; + } + + // ********** 2. QKV Load Index ********** + indices.load_idx[0] = indices.token_idx * param.token_stride + + indices.head_idx * param.head_dim + + indices.head_dim_idx; + indices.load_idx[1] = + indices.load_idx[0] + param.k_head_offset * param.head_dim; + indices.load_idx[2] = + indices.load_idx[0] + param.v_head_offset * param.head_dim; + + // ********** 3. Batch and Seq Index ********** + const int local_batch_idx = *(batch_ids_per_token + indices.token_idx); + const int global_batch_idx = *(global_batch_ids + local_batch_idx); + const int inter_batch_token_offset = indices.token_idx + + *(seqlens_q + local_batch_idx) - + *(cu_seqlens_q + local_batch_idx); + + // ********** 4.
RoPE Index ********** + if constexpr (!NeoxStyle) { + indices.rotary_cos_idx = global_batch_idx * param.batch_rotary_stride + + inter_batch_token_offset * half_head_dim; + } else { + indices.rotary_cos_idx = global_batch_idx * param.batch_rotary_stride + + inter_batch_token_offset * param.head_dim; + } + + if constexpr (!NeoxStyle) { + indices.rotary_cos_idx += (indices.head_dim_idx >> 1); + } else { + indices.rotary_cos_idx += indices.head_dim_idx % half_head_dim; + } + indices.rotary_sin_idx = indices.rotary_cos_idx + param.rotary_stride; + + // ********** 5. QKV Store Index ********** + indices.store_idx[0] = indices.token_idx * param.q_stride + + indices.head_idx * param.head_dim + + indices.head_dim_idx; + indices.store_idx[1] = indices.token_idx * param.kv_stride + + indices.head_idx * param.head_dim + + indices.head_dim_idx; + indices.store_idx[2] = indices.store_idx[1]; + + // ********** 6. KV Cache Store Index (WriteCache only) ********** + indices.cache_store_idx = -1; + indices.right_cache_store_idx = -1; + + if constexpr (WriteCache) { + const int inter_batch_block_idx = + inter_batch_token_offset / param.block_size; + const int inter_block_offset = inter_batch_token_offset % param.block_size; + const int block_idx = *(block_tables + global_batch_idx * param.block_num + + inter_batch_block_idx); + + assert(block_idx != -1); + + indices.cache_store_idx = + block_idx * param.cache_stride + inter_block_offset * param.kv_stride + + indices.head_idx * param.head_dim + indices.head_dim_idx; + + if constexpr (NeoxStyle) { + indices.right_cache_store_idx = indices.cache_store_idx + half_head_dim; + } + } +} + +template +__device__ inline void preload_rotary( + const WeightType* __restrict__ rotary_embs, + const int rotary_cos_idx, + const int rotary_sin_idx, + float* __restrict__ rotary_embs_vec) { + using VecRotary = AlignedVector; + + VecRotary* rotary_cos_vec = reinterpret_cast(rotary_embs_vec); + VecRotary* rotary_sin_vec = + reinterpret_cast(rotary_embs_vec + RotaryVecSize); + + if constexpr (std::is_same_v) { + Load(rotary_embs + rotary_cos_idx, rotary_cos_vec); + Load(rotary_embs + rotary_sin_idx, rotary_sin_vec); + } else { +#pragma unroll + for (int i = 0; i < RotaryVecSize; ++i) { + (*rotary_cos_vec)[i] = + static_cast(__ldg(rotary_embs + rotary_cos_idx + i)); + (*rotary_sin_vec)[i] = + static_cast(__ldg(rotary_embs + rotary_sin_idx + i)); + } + } +} + +template +__global__ void DispatchCacheKVWithRopeVecKernel( + const T* __restrict__ qkv, + const WeightType* __restrict__ rotary_embs, + const int* __restrict__ batch_ids_per_token, + const int* __restrict__ global_batch_ids, + const int* __restrict__ cu_seqlens_q, + const int* __restrict__ seqlens_q, + T* __restrict__ caches_k, + T* caches_v, + const int* __restrict__ block_tables, + CacheKVWithRopeParams param, + T* __restrict__ q_out, + T* __restrict__ k_out, + T* __restrict__ v_out) { + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int half_head_dim = (param.head_dim >> 1); + int64_t max_elements = + NeoxStyle ? (param.linear_elem_num >> 1) : param.linear_elem_num; + + constexpr int VecRotarySize2 = NeoxStyle ?
VecSize * 2 : VecSize; + using VecRotary2 = AlignedVector; + VecRotary2 rotary_embs_vec; + float* rotary_embs_vec_ptr = reinterpret_cast(&rotary_embs_vec); + + // Grid Stride Loop + for (int64_t linear_index = global_thread_idx * VecSize, + step = (int64_t)gridDim.x * blockDim.x * VecSize; + linear_index < max_elements; + linear_index += step) { + // ********** Index computation ********** + CacheKVIndices indices; + GetIndices(linear_index, + half_head_dim, + batch_ids_per_token, + global_batch_ids, + cu_seqlens_q, + seqlens_q, + block_tables, + param, + indices); + + preload_rotary(rotary_embs, + indices.rotary_cos_idx, + indices.rotary_sin_idx, + rotary_embs_vec_ptr); + + if (indices.head_idx < param.q_head_num) { + // ********** 1. Rotate and store the Q vector ********** + if constexpr (!NeoxStyle) { + RotateQKVec(qkv, + rotary_embs_vec_ptr, + indices.load_idx[0], + indices.store_idx[0], + -1, + static_cast(nullptr), + q_out); + } else { + int right_load_idx = indices.load_idx[0] + half_head_dim; + int right_store_idx = indices.store_idx[0] + half_head_dim; + RotateQKVecNeox(qkv, + rotary_embs_vec_ptr, + indices.load_idx[0], + right_load_idx, + indices.store_idx[0], + right_store_idx, + -1, + -1, + static_cast(nullptr), + q_out); + } } - if (head_idx < param.kv_head_num) { // kv - load_idx = token_idx * param.token_stride + - (head_idx + param.k_head_offset) * param.head_stride + - head_dim_idx; - store_idx = token_idx * param.kv_stride + head_idx * param.head_dim + - head_dim_idx; - cache_store_idx = block_idx * param.cache_stride + - inter_block_offset * param.kv_stride + - head_idx * param.head_dim + head_dim_idx; - // printf("block_idx: %d inter_block_offset: %d cache_store_idx: %d - // param.cache_stride: %d\n", block_idx, inter_block_offset, - // cache_store_idx, param.cache_stride); - RotateQKVec(qkv, - rotary_cos, - rotary_sin, - load_idx, - store_idx, - cache_store_idx, - rot_idx, - caches_k, - k_out); - - load_idx = token_idx * param.token_stride + - (head_idx + param.v_head_offset) * param.head_stride + - head_dim_idx; - StoreValue( - qkv, load_idx, store_idx, cache_store_idx, caches_v, v_out); + if (indices.head_idx < param.kv_head_num) { + // ********** 2. Rotate, store, and cache the K vector ********** + if constexpr (!NeoxStyle) { + RotateQKVec(qkv, + rotary_embs_vec_ptr, + indices.load_idx[1], + indices.store_idx[1], + indices.cache_store_idx, + caches_k, + k_out); + } else { + int right_load_idx = indices.load_idx[1] + half_head_dim; + int right_store_idx = indices.store_idx[1] + half_head_dim; + RotateQKVecNeox(qkv, + rotary_embs_vec_ptr, + indices.load_idx[1], + right_load_idx, + indices.store_idx[1], + right_store_idx, + indices.cache_store_idx, + indices.right_cache_store_idx, + caches_k, + k_out); + } + + // ********** 3.
Pass-through, store, and cache the V vector ********** + if constexpr (!NeoxStyle) { + StoreValue(qkv, + indices.load_idx[2], + indices.store_idx[2], + indices.cache_store_idx, + caches_v, + v_out); + } else { + int right_load_idx = indices.load_idx[2] + half_head_dim; + int right_store_idx = indices.store_idx[2] + half_head_dim; + StoreValueNeox(qkv, + indices.load_idx[2], + right_load_idx, + indices.store_idx[2], + right_store_idx, + indices.cache_store_idx, + indices.right_cache_store_idx, + caches_v, + v_out); + } } } } -template +template void CacheKVWithRopeKernel( const paddle::Tensor& qkv, // token_num, head_num * head_dim - paddle::Tensor& + const paddle::Tensor& + rotary_embs, // [2, 1, max_seqlens, 1, half_head_dim(head_dim if neox)] + // or [bs, 2, 1, max_seqlens, 1, half_head_dim(head_dim if + // neox)] + const paddle::Tensor& batch_ids_per_token, // token_num + const paddle::Tensor& global_batch_ids, + const paddle::Tensor& cu_seqlens_q, // bs + 1 + const paddle::Tensor& seqlens_q, // bs + paddle::optional& caches_k, // max_block_num, block_size, kv_head_num, head_dim - paddle::Tensor& + paddle::optional& caches_v, // max_block_num, block_size, kv_head_num, head_dim - const paddle::Tensor& block_tables, // bs, block_num - const paddle::Tensor& rotary_cos, - const paddle::Tensor& rotary_sin, - const paddle::Tensor& cu_seqlens_q, // bs + 1 - const paddle::Tensor& batch_ids_q, // token_num + const paddle::optional& block_tables, // bs, block_num const int q_head_num, const int kv_head_num, const int head_dim, const int block_size, + const bool neox_style, paddle::Tensor& q_out, paddle::Tensor& k_out, paddle::Tensor& v_out) { @@ -256,68 +435,150 @@ void CacheKVWithRopeKernel( typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; - const int all_num_elements = qkv.numel(); + const int64_t linear_elem_num = + qkv.shape()[0] * std::max(q_head_num, kv_head_num) * head_dim; const int all_num_heads = q_head_num + 2 * kv_head_num; auto stream = qkv.stream(); - dim3 block_dims(1, 4, (head_dim + VecSize - 1) / VecSize); - dim3 grid_dims(all_num_elements / (all_num_heads * head_dim), // token - (std::max(q_head_num, kv_head_num) + block_dims.y - 1) / - block_dims.y, // head - (head_dim + (block_dims.z * VecSize) - 1) / - (block_dims.z * VecSize) // dim: load Vec at a time - ); + const int pack_size = neox_style ?
(VecSize * 2) : VecSize; + const int pack_num = linear_elem_num / pack_size; + const int block_dims = 128; + int grid_dims = 1; + GetNumBlocks(pack_num, &grid_dims); // printf("grid: (%d, %d, %d)\n", grid_dims.x, grid_dims.y, grid_dims.z); // printf("block: (%d, %d, %d)\n", block_dims.x, block_dims.y, block_dims.z); CacheKVWithRopeParams param; + param.linear_elem_num = linear_elem_num; + param.linear_stride = std::max(q_head_num, kv_head_num) * head_dim; param.head_dim = head_dim; param.block_size = block_size; - param.block_num = static_cast(block_tables.shape().back()); + if (block_tables) { + param.block_num = static_cast(block_tables.get_ptr()->shape().back()); + } else { + param.block_num = -1; + } param.cache_stride = block_size * kv_head_num * head_dim; param.token_stride = all_num_heads * head_dim; - param.head_stride = head_dim; param.q_stride = q_head_num * head_dim; param.kv_stride = kv_head_num * head_dim; - param.q_head_offset = 0; param.k_head_offset = q_head_num; param.v_head_offset = q_head_num + kv_head_num; param.q_head_num = q_head_num; param.kv_head_num = kv_head_num; + const auto rotary_embs_shape = rotary_embs.shape(); + if (!neox_style) { + if (rotary_embs_shape.size() == 5) { + param.rotary_stride = rotary_embs_shape[2] * head_dim / 2; + param.batch_rotary_stride = 0; + } else { + param.rotary_stride = rotary_embs_shape[3] * head_dim / 2; + param.batch_rotary_stride = 2 * param.rotary_stride; + } + } else { + if (rotary_embs_shape.size() == 5) { + param.rotary_stride = rotary_embs_shape[2] * head_dim; + param.batch_rotary_stride = 0; + } else { + param.rotary_stride = rotary_embs_shape[3] * head_dim; + param.batch_rotary_stride = 2 * param.rotary_stride; + } + } - if (qkv.dtype() == rotary_cos.dtype()) { - DispatchCacheKVWithRopeVecKernel - <<>>( - reinterpret_cast(qkv.data()), - reinterpret_cast(caches_k.data()), - reinterpret_cast(caches_v.data()), - reinterpret_cast(block_tables.data()), - reinterpret_cast(rotary_cos.data()), - reinterpret_cast(rotary_sin.data()), - reinterpret_cast(cu_seqlens_q.data()), - reinterpret_cast(batch_ids_q.data()), - param, - reinterpret_cast(q_out.data()), - reinterpret_cast(k_out.data()), - reinterpret_cast(v_out.data())); - } else if (rotary_cos.dtype() == paddle::DataType::FLOAT32) { - DispatchCacheKVWithRopeVecKernel - <<>>( - reinterpret_cast(qkv.data()), - reinterpret_cast(caches_k.data()), - reinterpret_cast(caches_v.data()), - reinterpret_cast(block_tables.data()), - reinterpret_cast(rotary_cos.data()), - reinterpret_cast(rotary_sin.data()), - reinterpret_cast(cu_seqlens_q.data()), - reinterpret_cast(batch_ids_q.data()), - param, - reinterpret_cast(q_out.data()), - reinterpret_cast(k_out.data()), - reinterpret_cast(v_out.data())); +#define APPLY_ROPE_AND_WRITE_CACHE(DATATYPE, DATA_T) \ + DispatchCacheKVWithRopeVecKernel \ + <<>>( \ + reinterpret_cast(qkv.data()), \ + reinterpret_cast(rotary_embs.data()), \ + reinterpret_cast(batch_ids_per_token.data()), \ + reinterpret_cast(global_batch_ids.data()), \ + reinterpret_cast(cu_seqlens_q.data()), \ + reinterpret_cast(seqlens_q.data()), \ + reinterpret_cast(caches_k.get_ptr()->data()), \ + reinterpret_cast(caches_v.get_ptr()->data()), \ + reinterpret_cast(block_tables.get_ptr()->data()), \ + param, \ + reinterpret_cast(q_out.data()), \ + reinterpret_cast(k_out.data()), \ + reinterpret_cast(v_out.data())); + +#define APPLY_ROPE(DATATYPE, DATA_T) \ + DispatchCacheKVWithRopeVecKernel \ + <<>>( \ + reinterpret_cast(qkv.data()), \ + reinterpret_cast(rotary_embs.data()), \ + 
reinterpret_cast(batch_ids_per_token.data()), \ + reinterpret_cast(global_batch_ids.data()), \ + reinterpret_cast(cu_seqlens_q.data()), \ + reinterpret_cast(seqlens_q.data()), \ + static_cast(nullptr), \ + static_cast(nullptr), \ + static_cast(nullptr), \ + param, \ + reinterpret_cast(q_out.data()), \ + reinterpret_cast(k_out.data()), \ + reinterpret_cast(v_out.data())); + +#define APPLY_ROPE_AND_WRITE_CACHE_NEOX(DATATYPE, DATA_T) \ + DispatchCacheKVWithRopeVecKernel \ + <<>>( \ + reinterpret_cast(qkv.data()), \ + reinterpret_cast(rotary_embs.data()), \ + reinterpret_cast(batch_ids_per_token.data()), \ + reinterpret_cast(global_batch_ids.data()), \ + reinterpret_cast(cu_seqlens_q.data()), \ + reinterpret_cast(seqlens_q.data()), \ + reinterpret_cast(caches_k.get_ptr()->data()), \ + reinterpret_cast(caches_v.get_ptr()->data()), \ + reinterpret_cast(block_tables.get_ptr()->data()), \ + param, \ + reinterpret_cast(q_out.data()), \ + reinterpret_cast(k_out.data()), \ + reinterpret_cast(v_out.data())); + +#define APPLY_ROPE_NEOX(DATATYPE, DATA_T) \ + DispatchCacheKVWithRopeVecKernel \ + <<>>( \ + reinterpret_cast(qkv.data()), \ + reinterpret_cast(rotary_embs.data()), \ + reinterpret_cast(batch_ids_per_token.data()), \ + reinterpret_cast(global_batch_ids.data()), \ + reinterpret_cast(cu_seqlens_q.data()), \ + reinterpret_cast(seqlens_q.data()), \ + static_cast(nullptr), \ + static_cast(nullptr), \ + static_cast(nullptr), \ + param, \ + reinterpret_cast(q_out.data()), \ + reinterpret_cast(k_out.data()), \ + reinterpret_cast(v_out.data())); + +#define DISPATCH_CASE(WRITE_CACHE_FUNC, APPLY_ROPE_FUNC) \ + if (caches_k && caches_v) { \ + if (qkv.dtype() == rotary_embs.dtype()) { \ + WRITE_CACHE_FUNC(DataType_, data_t) \ + } else if (rotary_embs.dtype() == paddle::DataType::FLOAT32) { \ + WRITE_CACHE_FUNC(float, float) \ + } else { \ + PD_THROW( \ + "qk dtype and rope dtype should be equal or rope dtype is float"); \ + } \ + } else { \ + if (qkv.dtype() == rotary_embs.dtype()) { \ + APPLY_ROPE_FUNC(DataType_, data_t) \ + } else if (rotary_embs.dtype() == paddle::DataType::FLOAT32) { \ + APPLY_ROPE_FUNC(float, float) \ + } else { \ + PD_THROW( \ + "qk dtype and rope dtype should be equal or rope dtype is float"); \ + } \ + } + + if (neox_style) { + DISPATCH_CASE(APPLY_ROPE_AND_WRITE_CACHE_NEOX, APPLY_ROPE_NEOX) } else { - PD_THROW("Unsupported qk dtype and rope dtype."); + DISPATCH_CASE(APPLY_ROPE_AND_WRITE_CACHE, APPLY_ROPE) } cudaError_t err = cudaGetLastError(); @@ -328,34 +589,38 @@ void CacheKVWithRopeKernel( std::vector CacheKVWithRope( const paddle::Tensor& qkv, // token_num, head_num * head_dim - paddle::Tensor& + const paddle::Tensor& rotary_embs, + const paddle::Tensor& batch_ids_per_token, + const paddle::Tensor& global_batch_ids, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& seqlens_q, + paddle::optional& caches_k, // max_block_num, block_size, kv_head_num, head_dim - paddle::Tensor& + paddle::optional& caches_v, // max_block_num, block_size, kv_head_num, head_dim - const paddle::Tensor& block_tables, // bs, block_num - const paddle::Tensor& rotary_cos, - const paddle::Tensor& rotary_sin, - const paddle::Tensor& cu_seqlens_q, // bs + 1 - const paddle::Tensor& batch_ids_q, // token_num + const paddle::optional& block_tables, // bs, block_num const int q_head_num, const int kv_head_num, const int head_dim, - const int block_size) { + const int block_size, + const int out_dims, + const bool neox_style) { auto qkv_shape = qkv.shape(); auto token_num = qkv_shape[0]; auto place = 
qkv.place(); auto dtype = qkv.dtype(); - common::DDim q_out_shape, kv_out_shape; - if (rotary_cos.shape().size() == 3) { - q_out_shape = {token_num, q_head_num, head_dim}; - kv_out_shape = {token_num, kv_head_num, head_dim}; + + paddle::Tensor q_out, k_out, v_out; + PD_CHECK(out_dims == 3 || out_dims == 4); + if (out_dims == 3) { + q_out = GetEmptyTensor({token_num, q_head_num, head_dim}, dtype, place); + k_out = GetEmptyTensor({token_num, kv_head_num, head_dim}, dtype, place); + v_out = GetEmptyTensor({token_num, kv_head_num, head_dim}, dtype, place); } else { - q_out_shape = {token_num, 1, q_head_num, head_dim}; - kv_out_shape = {token_num, 1, kv_head_num, head_dim}; + q_out = GetEmptyTensor({token_num, 1, q_head_num, head_dim}, dtype, place); + k_out = GetEmptyTensor({token_num, 1, kv_head_num, head_dim}, dtype, place); + v_out = GetEmptyTensor({token_num, 1, kv_head_num, head_dim}, dtype, place); } - auto q_out = GetEmptyTensor(q_out_shape, dtype, place); - auto k_out = GetEmptyTensor(kv_out_shape, dtype, place); - auto v_out = GetEmptyTensor(kv_out_shape, dtype, place); if (token_num == 0) { return {q_out, k_out, v_out}; @@ -373,48 +638,52 @@ std::vector CacheKVWithRope( "The last dimension (head_dim) of qkv must be an even number " "for RoPE, but got %d", head_dim); - PADDLE_ENFORCE_EQ(q_out.shape().back(), - rotary_cos.shape().back(), - "The last dimension of cos mismatches that of q, " - "expect %d but got %d", - q_out.shape().back(), - rotary_cos.shape().back()); + if (!neox_style) { + PADDLE_ENFORCE_EQ((q_out.shape().back() / 2), + rotary_embs.shape().back(), + "The last dimension of cos mismatches that half of q, " + "expect %d but got %d", + (q_out.shape().back() / 2), + rotary_embs.shape().back()); + } else { + PADDLE_ENFORCE_EQ((q_out.shape().back()), + rotary_embs.shape().back(), + "The last dimension of cos mismatches that head_dim, " + "expect %d but got %d", + (q_out.shape().back()), + rotary_embs.shape().back()); + } + + if (caches_k && caches_v) { + if (!block_tables) { + PD_THROW("block_tables should have value if writing into cache."); + } + } + +#define KERNEL_CASE(DTYPE) \ + case DTYPE: \ + CacheKVWithRopeKernel(qkv, \ + rotary_embs, \ + batch_ids_per_token, \ + global_batch_ids, \ + cu_seqlens_q, \ + seqlens_q, \ + caches_k, \ + caches_v, \ + block_tables, \ + q_head_num, \ + kv_head_num, \ + head_dim, \ + block_size, \ + neox_style, \ + q_out, \ + k_out, \ + v_out); \ + break; switch (dtype) { - case paddle::DataType::BFLOAT16: - CacheKVWithRopeKernel(qkv, - caches_k, - caches_v, - block_tables, - rotary_cos, - rotary_sin, - cu_seqlens_q, - batch_ids_q, - q_head_num, - kv_head_num, - head_dim, - block_size, - q_out, - k_out, - v_out); - break; - case paddle::DataType::FLOAT16: - CacheKVWithRopeKernel(qkv, - caches_k, - caches_v, - block_tables, - rotary_cos, - rotary_sin, - cu_seqlens_q, - batch_ids_q, - q_head_num, - kv_head_num, - head_dim, - block_size, - q_out, - k_out, - v_out); - break; + KERNEL_CASE(paddle::DataType::BFLOAT16) + KERNEL_CASE(paddle::DataType::FLOAT16) default: PD_THROW("Only support qk dtype of BF16 and F16"); } @@ -424,54 +693,59 @@ std::vector CacheKVWithRope( std::vector> CacheKVWithRopeInferShape( const std::vector& qkv_shape, - const std::vector& caches_k_shape, - const std::vector& caches_v_shape, - const std::vector& block_tables_shape, - const std::vector& cos_shape, - const std::vector& sin_shape, + const std::vector& rotary_embs_shape, + const std::vector& batch_ids_per_token_shape, + const std::vector& global_batch_ids_shape, 
const std::vector& cu_seqlens_q_shape, - const std::vector& batch_ids_q_shape) { + const std::vector& seqlens_q_shape, + const paddle::optional>& caches_k_shape, + const paddle::optional>& caches_v_shape, + const paddle::optional>& block_tables_shape) { return {qkv_shape, - caches_k_shape, - caches_v_shape, - block_tables_shape, - cos_shape, - sin_shape, + rotary_embs_shape, + batch_ids_per_token_shape, + global_batch_ids_shape, cu_seqlens_q_shape, - batch_ids_q_shape}; + seqlens_q_shape}; } std::vector CacheKVWithRopeInferDtype( const paddle::DataType& qkv_dtype, - const paddle::DataType& caches_k_dtype, - const paddle::DataType& caches_v_dtype, - const paddle::DataType& block_tables_dtype, - const paddle::DataType& cos_dtype, - const paddle::DataType& sin_dtype, + const paddle::DataType& rotary_embs_dtype, + const paddle::DataType& batch_ids_per_token_dtype, + const paddle::DataType& global_batch_ids_dtype, const paddle::DataType& cu_seqlens_q_dtype, - const paddle::DataType& batch_ids_q_dtype) { + const paddle::DataType& seqlens_q_dtype, + const paddle::optional& caches_k_dtype, + const paddle::optional& caches_v_dtype, + const paddle::optional& block_tables_dtype) { return {qkv_dtype, - caches_k_dtype, - caches_v_dtype, - block_tables_dtype, - cos_dtype, - sin_dtype, + rotary_embs_dtype, + batch_ids_per_token_dtype, + global_batch_ids_dtype, cu_seqlens_q_dtype, - batch_ids_q_dtype}; + seqlens_q_dtype}; } PD_BUILD_OP(cache_kv_with_rope) - .Inputs({"qkv", - "caches_k", - "caches_v", - "block_tables", - "rotary_cos", - "rotary_sin", - "cu_seqlen_q", - "batch_ids_q"}) + .Inputs({ + "qkv", + "rotary_embs", + "batch_ids_per_token", + "global_batch_ids", + "cu_seqlens_q", + "seqlens_q", + paddle::Optional("caches_k"), + paddle::Optional("caches_v"), + paddle::Optional("block_tables"), + }) .Outputs({"q_out", "k_out", "v_out"}) - .Attrs( - {"q_head_num:int", "kv_head_num:int", "head_dim:int", "block_size:int"}) + .Attrs({"q_head_num:int", + "kv_head_num:int", + "head_dim:int", + "block_size:int", + "out_dims:int", + "neox_style:bool"}) .SetKernelFn(PD_KERNEL(CacheKVWithRope)) .SetInferShapeFn(PD_INFER_SHAPE(CacheKVWithRopeInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(CacheKVWithRopeInferDtype)); diff --git a/custom_ops/metax_ops/split_merge_qkv.cu b/custom_ops/metax_ops/split_merge_qkv.cu new file mode 100644 index 00000000000..881563aca93 --- /dev/null +++ b/custom_ops/metax_ops/split_merge_qkv.cu @@ -0,0 +1,246 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "helper.h" + +template +__global__ void RunDispatchQKV(const T* __restrict__ src_ptr_0, + const T* __restrict__ src_ptr_1, + const int* __restrict__ meta_ptr, + const int group_num, + const int hidden_dims, + const int64_t max_elements, + T* __restrict__ dst_ptr_0, + T* __restrict__ dst_ptr_1) { + extern __shared__ int s_meta[]; + for (int i = threadIdx.x; i < group_num * 5; i += blockDim.x) { + s_meta[i] = meta_ptr[i]; + } + __syncthreads(); + + using VecT = AlignedVector; + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + int64_t step = (int64_t)gridDim.x * blockDim.x * VecSize; + + const T* src_ptrs[2] = {src_ptr_0, src_ptr_1}; + T* dst_ptrs[2] = {dst_ptr_0, dst_ptr_1}; + + for (int64_t linear_index = global_thread_idx * VecSize; + linear_index < max_elements; + linear_index += step) { + int token_idx = linear_index / hidden_dims; + int hidden_idx = linear_index % hidden_dims; + + int stage = 0, start = 0, qkv_start = 0; + for (int gidx = 0; gidx < group_num; ++gidx) { + int base = gidx * 5; + int g_start = s_meta[base + 3]; + int g_end = s_meta[base + 4]; + + if (token_idx >= g_start && token_idx < g_end) { + stage = s_meta[base + 0]; + start = s_meta[base + 1]; + qkv_start = g_start; + break; + } + } + + int local_token_idx = token_idx - qkv_start + start; + int64_t local_offset = (int64_t)local_token_idx * hidden_dims + hidden_idx; + + if constexpr (IsSplit) { + T* target_dst = dst_ptrs[stage]; + Load(src_ptr_0 + linear_index, + reinterpret_cast(target_dst + local_offset)); + } else { + const T* target_src = src_ptrs[stage]; + Load(target_src + local_offset, + reinterpret_cast(dst_ptr_0 + linear_index)); + } + } +} + +void SplitQKV(const paddle::Tensor& qkv, + const paddle::Tensor& hybrid_meta, + paddle::Tensor& prefill_qkv, + paddle::Tensor& decode_qkv) { + auto qkv_shape = qkv.shape(); + int token_num = qkv_shape[0]; + + if (token_num == 0) { + return; + } + + int64_t linear_elem_num = qkv.numel(); + int hidden_dims = static_cast(linear_elem_num / token_num); + auto dtype = qkv.dtype(); + auto group_num = hybrid_meta.shape()[0]; + auto stream = qkv.stream(); + + constexpr int pack_size = 4; + constexpr int block_dims = 128; + const int pack_num = linear_elem_num / pack_size; + int grid_dims = 1; + GetNumBlocks(pack_num, &grid_dims); + size_t shared_mem_size = group_num * 5 * sizeof(int); + + switch (dtype) { + case paddle::DataType::BFLOAT16: + RunDispatchQKV<__maca_bfloat16, pack_size, true> + <<>>( + reinterpret_cast( + qkv.data()), + static_cast(nullptr), + reinterpret_cast(hybrid_meta.data()), + group_num, + hidden_dims, + linear_elem_num, + reinterpret_cast<__maca_bfloat16*>( + prefill_qkv.data()), + reinterpret_cast<__maca_bfloat16*>( + decode_qkv.data())); + break; + case paddle::DataType::FLOAT16: + RunDispatchQKV<__half, pack_size, true> + <<>>( + reinterpret_cast(qkv.data()), + static_cast(nullptr), + reinterpret_cast(hybrid_meta.data()), + group_num, + hidden_dims, + linear_elem_num, + reinterpret_cast<__half*>(prefill_qkv.data()), + reinterpret_cast<__half*>(decode_qkv.data())); + break; + default: + PD_THROW("Only support qkv dtype of BF16 and F16"); + } +} + +void MergeQKV(const paddle::Tensor& prefill_out, + const paddle::Tensor& decode_out, + const paddle::Tensor& hybrid_meta, + paddle::Tensor& merged_out) { + auto merged_out_shape = merged_out.shape(); + int token_num = merged_out_shape[0]; + + if (token_num == 0) { + return; + } + + int64_t linear_elem_num = merged_out.numel(); + int hidden_dims =
static_cast(linear_elem_num / token_num); + auto dtype = merged_out.dtype(); + auto group_num = hybrid_meta.shape()[0]; + auto stream = merged_out.stream(); + + constexpr int pack_size = 4; + constexpr int block_dims = 128; + const int pack_num = linear_elem_num / pack_size; + int grid_dims = 1; + GetNumBlocks(pack_num, &grid_dims); + size_t shared_mem_size = group_num * 5 * sizeof(int); + + switch (dtype) { + case paddle::DataType::BFLOAT16: + RunDispatchQKV<__maca_bfloat16, pack_size> + <<>>( + reinterpret_cast( + prefill_out.data()), + reinterpret_cast( + decode_out.data()), + reinterpret_cast(hybrid_meta.data()), + group_num, + hidden_dims, + linear_elem_num, + reinterpret_cast<__maca_bfloat16*>( + merged_out.data()), + static_cast<__maca_bfloat16*>(nullptr)); + break; + case paddle::DataType::FLOAT16: + RunDispatchQKV<__half, pack_size> + <<>>( + reinterpret_cast( + prefill_out.data()), + reinterpret_cast( + decode_out.data()), + reinterpret_cast(hybrid_meta.data()), + group_num, + hidden_dims, + linear_elem_num, + reinterpret_cast<__half*>(merged_out.data()), + static_cast<__half*>(nullptr)); + break; + default: + PD_THROW("Only support qkv dtype of BF16 and F16"); + } +} + +std::vector> SplitQKVInferShape( + const std::vector& qkv_shape, + const std::vector& hybrid_meta_shape, + const std::vector& prefill_qkv_shape, + const std::vector& decode_qkv_shape) { + return {qkv_shape, hybrid_meta_shape, prefill_qkv_shape, decode_qkv_shape}; +} + +std::vector SplitQKVInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& hybrid_meta_dtype, + const paddle::DataType& prefill_qkv_dtype, + const paddle::DataType& decode_qkv_dtype) { + return {qkv_dtype, hybrid_meta_dtype, prefill_qkv_dtype, decode_qkv_dtype}; +} + +std::vector> MergeQKVInferShape( + const std::vector& prefill_out_shape, + const std::vector& decode_out_shape, + const std::vector& hybrid_meta_shape, + const std::vector& merged_out_shape) { + return { + prefill_out_shape, decode_out_shape, hybrid_meta_shape, merged_out_shape}; +} + +std::vector MergeQKVInferDtype( + const paddle::DataType& prefill_out_dtype, + const paddle::DataType& decode_out_dtype, + const paddle::DataType& hybrid_meta_dtype, + const paddle::DataType& merged_out_dtype) { + return { + prefill_out_dtype, decode_out_dtype, hybrid_meta_dtype, merged_out_dtype}; +} + +PD_BUILD_OP(split_qkv) + .Inputs({ + "qkv", + "hybrid_meta", + "prefill_qkv", + "decode_qkv", + }) + .SetKernelFn(PD_KERNEL(SplitQKV)) + .SetInferShapeFn(PD_INFER_SHAPE(SplitQKVInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SplitQKVInferDtype)); + +PD_BUILD_OP(merge_qkv) + .Inputs({ + "prefill_out", + "decode_out", + "hybrid_meta", + "merged_out", + }) + .SetKernelFn(PD_KERNEL(MergeQKV)) + .SetInferShapeFn(PD_INFER_SHAPE(MergeQKVInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MergeQKVInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 765de06030e..60a3a99e9c2 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -645,9 +645,9 @@ def find_end_files(directory, end_str): "metax_ops/moe_ffn.cu", "metax_ops/moe_reduce.cu", "metax_ops/fused_moe.cu", - "metax_ops/apply_rope_qkv.cu", "metax_ops/cache_kv_with_rope.cu", "metax_ops/cpp_extensions.cc", + "metax_ops/split_merge_qkv.cu", ] sources += find_end_files("gpu_ops/speculate_decoding", ".cu") diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py index
d39f6e6155f..d2e58e1cd32 100644 --- a/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py @@ -31,7 +31,9 @@ flash_attn_kvcache_func, flash_attn_unpadded_func, ) -from fastdeploy.model_executor.ops.gpu import apply_rope_qkv, cache_kv_with_rope +from fastdeploy.model_executor.ops.gpu import cache_kv_with_rope +from fastdeploy.model_executor.ops.gpu import merge_qkv as merge_qkv_cu +from fastdeploy.model_executor.ops.gpu import split_qkv as split_qkv_cu @dataclass @@ -51,10 +53,10 @@ class FlashAttentionMetadata(AttentionMetadata): decoder_batch_ids: paddle.Tensor = None decoder_tile_ids_per_batch: paddle.Tensor = None decoder_num_blocks: paddle.Tensor = None - rotary_cos_prefill: paddle.Tensor = None - rotary_sin_prefill: paddle.Tensor = None - rotary_cos_decode: paddle.Tensor = None - rotary_sin_decode: paddle.Tensor = None + cu_seqlens_q_decode: paddle.Tensor = None + batch_ids_per_token_decode: paddle.Tensor = None + seq_lens_decode: paddle.Tensor = None + block_table_decode: paddle.Tensor = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -64,8 +66,6 @@ class FlashAttentionMetadata(AttentionMetadata): encoder_block_shape_q: int = -1 decoder_block_shape_q: int = -1 _fuse_kernel_compute_dtype: str = "bf16" - seq_lens_dec: paddle.Tensor = None - block_table_dec: paddle.Tensor = None # pd_disaggregation kv_signal_metadata: Optional[paddle.Tensor] = None @@ -128,20 +128,22 @@ def __init__( self.rank, self.device_id = init_rank_and_device_id(fd_config) self.enable_mm = fd_config.model_config.enable_mm + self.model_type = fd_config.model_config.model_type + self.is_neox_style = False + if "paddleocr" in fd_config.model_config.model_type: + self.is_neox_style = True + max_num_seqs = fd_config.scheduler_config.max_num_seqs - self.attention_metadata.rotary_cos_decode = paddle.empty( - shape=[max_num_seqs, 1, 1, self.head_dim], - dtype=self.dtype, - ) - self.attention_metadata.rotary_sin_decode = paddle.empty( - shape=[max_num_seqs, 1, 1, self.head_dim], - dtype=self.dtype, - ) - self.attention_metadata.seq_lens_dec = paddle.empty( - shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype="int32" - ) - self.attention_metadata.block_table_dec = paddle.empty( - shape=[fd_config.scheduler_config.max_num_seqs, self.head_dim], dtype="int32" + self.attention_metadata.decoder_batch_ids = paddle.empty(shape=[max_num_seqs], dtype="int32") + self.attention_metadata.cu_seqlens_q_decode = paddle.empty(shape=[max_num_seqs + 1], dtype="int32") + self.attention_metadata.batch_ids_per_token_decode = paddle.empty(shape=[max_num_seqs], dtype="int32") + self.attention_metadata.seq_lens_decode = paddle.empty(shape=[max_num_seqs, 1], dtype="int32") + self.attention_metadata.block_table_decode = paddle.empty( + shape=[ + max_num_seqs, + self.max_seq_len // self.block_size + fd_config.cache_config.enc_dec_block_num, + ], + dtype="int32", ) def init_attention_metadata(self, forward_meta: ForwardMeta): @@ -152,177 +154,91 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): prefill_non_zeros_ids = forward_meta.seq_lens_this_time > 1 decode_non_zeros_ids = forward_meta.seq_lens_this_time == 1 - self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0] - self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0] + self.prefill_info_dict["batch_ids"] = paddle.where(prefill_non_zeros_ids)[0].astype("int32") + 
self.decode_info_dict["batch_ids"] = paddle.where(decode_non_zeros_ids)[0].astype("int32") self.prefill_len = len(self.prefill_info_dict["batch_ids"]) self.decode_len = len(self.decode_info_dict["batch_ids"]) + self.has_prefill = self.prefill_len > 0 + self.has_decode = self.decode_len > 0 - # only prefill - if self.decode_len == 0: - cu_seq_ids = list(range(self.prefill_len + 1)) - self.prefill_info_dict["cu_seqlens_q"] = forward_meta.cu_seqlens_q[cu_seq_ids].astype("int32") - # only decode - elif self.prefill_len == 0: - pass - # both prefill and decode - else: - prefill_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[prefill_non_zeros_ids]) - decode_num_tokens = paddle.sum(forward_meta.seq_lens_this_time[decode_non_zeros_ids]) + if self.has_prefill: + batch_ids_prefill = self.prefill_info_dict["batch_ids"] - self.prefill_info_dict["cu_seqlens_q"] = paddle.zeros( - [self.prefill_len + 1], dtype=forward_meta.cu_seqlens_q.dtype + seq_lens_this_time_prefill = forward_meta.seq_lens_this_time[batch_ids_prefill, 0] + self.prefill_info_dict["cu_seqlens_q"] = paddle.concat( + [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens_this_time_prefill, axis=0).astype("int32")], + axis=0, ) - self.prefill_info_dict["cu_seqlens_q"][1:] = forward_meta.seq_lens_encoder[ - self.prefill_info_dict["batch_ids"], 0 - ] - self.prefill_info_dict["cu_seqlens_q"] = paddle.cumsum(self.prefill_info_dict["cu_seqlens_q"]).astype( - "int32" + self.prefill_info_dict["seq_lens_prefill"] = paddle.zeros(self.prefill_len, dtype="int32") + + local_ids = paddle.arange(self.prefill_len, dtype="int32") + self.prefill_info_dict["batch_ids_per_token"] = paddle.repeat_interleave( + local_ids, repeats=seq_lens_this_time_prefill, axis=0 ) - self.prefill_qkv = paddle.zeros([prefill_num_tokens, self.total_hidden_dim], dtype=self.dtype) - self.decode_qkv = paddle.zeros([decode_num_tokens, self.total_hidden_dim], dtype=self.dtype) - self.merged_output = paddle.zeros( - [prefill_num_tokens + decode_num_tokens, self.num_heads, self.head_dim], dtype=self.dtype + if self.has_decode: + batch_ids_decode = self.decode_info_dict["batch_ids"] + + seq_lens_this_time_decode = forward_meta.seq_lens_this_time[batch_ids_decode, 0] + cu_seqlens_q_decode = paddle.concat( + [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens_this_time_decode, axis=0).astype("int32")], + axis=0, ) - prefill_start, decode_start, start = 0, 0, 0 - non_zeros_ids = forward_meta.seq_lens_this_time != 0 - non_zeros_seq_lens = forward_meta.seq_lens_this_time[non_zeros_ids] - end = non_zeros_seq_lens[0] - if end > 1: - last_stage = "prefill" - prefill_end = end - decode_end = 0 - else: - last_stage = "decode" - prefill_end = 0 - decode_end = end - - self.prefill_info_dict["id_group"] = [] - self.prefill_info_dict["reverse_id_group"] = [] - self.decode_info_dict["id_group"] = [] - self.decode_info_dict["reverse_id_group"] = [] - self.record_stages = [] - for seq_len in non_zeros_seq_lens[1:]: - if seq_len > 1: - if last_stage == "decode": - self.record_stages.append((last_stage, len(self.decode_info_dict["id_group"]))) - self.decode_info_dict["id_group"].append((decode_start, decode_end)) - self.decode_info_dict["reverse_id_group"].append((start, end)) - decode_start = decode_end - start = end - last_stage = "prefill" - prefill_end += seq_len - end += seq_len - else: - if last_stage == "prefill": - self.record_stages.append((last_stage, len(self.prefill_info_dict["id_group"]))) - self.prefill_info_dict["id_group"].append((prefill_start, prefill_end)) - 
self.prefill_info_dict["reverse_id_group"].append((start, end)) - prefill_start = prefill_end - start = end - last_stage = "decode" - decode_end += seq_len - end += seq_len - - if prefill_start < prefill_end: - self.record_stages.append(("prefill", len(self.prefill_info_dict["id_group"]))) - self.prefill_info_dict["id_group"].append((prefill_start, prefill_end)) - self.prefill_info_dict["reverse_id_group"].append((start, end)) - if decode_start < decode_end: - self.record_stages.append(("decode", len(self.decode_info_dict["id_group"]))) - self.decode_info_dict["id_group"].append((decode_start, decode_end)) - self.decode_info_dict["reverse_id_group"].append((start, end)) - - self.batch_ids_prefill = paddle.to_tensor(self.prefill_info_dict["batch_ids"]) - self.batch_ids_decode = paddle.to_tensor(self.decode_info_dict["batch_ids"]) - self.attention_metadata.seq_lens_dec.copy_(forward_meta.seq_lens_decoder[self.batch_ids_decode, 0]) - self.attention_metadata.block_table_dec.copy_(forward_meta.block_tables[self.batch_ids_decode, :]) - - # update prefilling rope - self.update_rotary_embs_prefill(forward_meta) - # update decoding rope - self.update_rotary_embs_decoder(forward_meta) - - def update_rotary_embs_prefill(self, forward_meta: ForwardMeta): - if self.batch_ids_prefill.shape[0] == 0 or forward_meta.rotary_embs is None: - return - - batch_ids = self.batch_ids_prefill - seq_lens_this_time = forward_meta.seq_lens_this_time[batch_ids] - cached_kv_lens = forward_meta.seq_lens_decoder[batch_ids, 0] - - self.block_table_prefill = forward_meta.block_tables[batch_ids, :] - # mapping token idx to batch idx - self.batch_ids_q = paddle.repeat_interleave( - paddle.arange(0, batch_ids.shape[0], dtype="int32"), repeats=seq_lens_this_time, axis=0 - ) + local_ids = paddle.arange(self.decode_len, dtype="int32") + batch_ids_per_token_decode = paddle.repeat_interleave(local_ids, repeats=seq_lens_this_time_decode, axis=0) - all_indices = [] - for i in range(len(batch_ids)): - start_pos = cached_kv_lens[i] - seq_len_i = seq_lens_this_time[i] - if seq_len_i > 0: - indices_i = paddle.arange(start_pos, start_pos + seq_len_i, dtype="int64") - all_indices.append(indices_i) - if not all_indices: - return - - all_indices = paddle.concat(all_indices) # [token_num] - if self.enable_mm: - gather_nd_indices = paddle.stack( - [ # [token_num, 2] - paddle.repeat_interleave(batch_ids, repeats=seq_lens_this_time, axis=0), - all_indices, - ], - axis=1, + self.attention_metadata.decoder_batch_ids[: self.decode_len].copy_(batch_ids_decode) # global batch id + self.attention_metadata.cu_seqlens_q_decode[: self.decode_len + 1].copy_(cu_seqlens_q_decode) + self.attention_metadata.batch_ids_per_token_decode[: self.decode_len].copy_(batch_ids_per_token_decode) + self.attention_metadata.seq_lens_decode[: self.decode_len].copy_( + forward_meta.seq_lens_decoder[batch_ids_decode, 0] ) - gathered_embs = paddle.gather_nd( - forward_meta.rotary_embs.squeeze([2]).transpose( - [0, 2, 1, 3, 4] - ), # [B, 2, 1, S, 1, D // 2] -> [B, S, 2, 1, D // 2] - gather_nd_indices, - ) # [token_num, 2, 1, D // 2] - rot_cos = gathered_embs[:, 0, :, :] # [token_num, 1, D // 2] - rot_sin = gathered_embs[:, 1, :, :] - else: - gathered_embs = paddle.gather( - forward_meta.rotary_embs.squeeze([1]), all_indices, axis=1 # [2, 1, S, 1, D // 2] -> [2, S, 1, D // 2] - ) # [2, token_num, 1, D // 2] - rot_cos = gathered_embs[0, :, :, :] # [token_num, 1, D // 2] - rot_sin = gathered_embs[1, :, :, :] - - self.attention_metadata.rotary_cos_prefill = 
paddle.repeat_interleave( - rot_cos, repeats=2, axis=-1 - ) # [token_num, 1, D] - self.attention_metadata.rotary_sin_prefill = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1) - - def update_rotary_embs_decoder(self, forward_meta: ForwardMeta): - if self.batch_ids_decode.shape[0] == 0: - return - - bs = self.batch_ids_decode.shape[0] - if self.enable_mm: - index = paddle.concat( - [self.batch_ids_decode.view([-1, 1]), self.attention_metadata.seq_lens_dec.to("int64").view([-1, 1])], - axis=1, + self.attention_metadata.block_table_decode[: self.decode_len].copy_( + forward_meta.block_tables[batch_ids_decode, :] + ) + + if self.has_prefill and self.has_decode: + non_zeros_mask = forward_meta.seq_lens_this_time != 0 + seq_lens_non_zeros = forward_meta.seq_lens_this_time[non_zeros_mask].astype("int32") + + global_sequence_offsets = paddle.zeros(seq_lens_non_zeros.shape[0] + 1, dtype="int32") + global_sequence_offsets[1:] = paddle.cumsum(seq_lens_non_zeros) + + is_prefill_array = seq_lens_non_zeros > 1 + + group_boundary = paddle.where(is_prefill_array[1:] != is_prefill_array[:-1])[0].astype("int32") + 1 + group_starts = paddle.concat((paddle.zeros([1], dtype="int32"), group_boundary)) + group_ends = paddle.concat( + (group_boundary, paddle.full([1], fill_value=seq_lens_non_zeros.shape[0], dtype="int32")) + ) + + compact_meta = [] + prefill_ptr = 0 + decode_ptr = 0 + + for start, end in zip(group_starts, group_ends): + is_prefill = is_prefill_array[start] + g_start = global_sequence_offsets[start] + g_end = global_sequence_offsets[end] + num_tokens = g_end - g_start + + if is_prefill: + # [0, prefill_start, prefill_end, global_start, global_end] + compact_meta.append([0, prefill_ptr, prefill_ptr + num_tokens, g_start, g_end]) + prefill_ptr += num_tokens + else: + # [1, decode_start, decode_end, global_start, global_end] + compact_meta.append([1, decode_ptr, decode_ptr + num_tokens, g_start, g_end]) + decode_ptr += num_tokens + + self.hybrid_stage_meta = paddle.to_tensor(compact_meta, dtype="int32") + self.prefill_qkv = paddle.zeros([prefill_ptr, self.total_hidden_dim], dtype=self.dtype) + self.decode_qkv = paddle.zeros([decode_ptr, self.total_hidden_dim], dtype=self.dtype) + self.merged_output = paddle.zeros( + [prefill_ptr + decode_ptr, self.num_heads, self.head_dim], dtype=self.dtype ) - rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1]) - rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1]) - else: - rot_cos = paddle.gather( - forward_meta.rotary_embs[0, 0, :, 0, :], self.attention_metadata.seq_lens_dec - ).view([bs, 1, 1, -1]) - rot_sin = paddle.gather( - forward_meta.rotary_embs[1, 0, :, 0, :], self.attention_metadata.seq_lens_dec - ).view([bs, 1, 1, -1]) - self.attention_metadata.rotary_cos_decode[:bs].copy_( - paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype) - ) - self.attention_metadata.rotary_sin_decode[:bs].copy_( - paddle.repeat_interleave(rot_sin, repeats=2, axis=-1).astype(self.dtype) - ) def get_attntion_meta(self) -> AttentionMetadata: """get_attntion_meta""" @@ -348,118 +264,57 @@ def get_kv_cache_shape( return key_cache_shape, value_cache_shape - def apply_rope_native(self, qk, cos, sin): - rotate_half = paddle.reshape( - paddle.stack([-qk[..., 1::2], qk[..., 0::2]], axis=-1), - paddle.shape(qk), - ) - out = paddle.add(paddle.multiply(qk, cos), paddle.multiply(rotate_half, sin)) - return paddle.cast(out, qk.dtype) - def split_pd_qkv(self, qkv): - - for ids, 
reverse_ids in zip(self.prefill_info_dict["id_group"], self.prefill_info_dict["reverse_id_group"]): - self.prefill_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :] - - for ids, reverse_ids in zip(self.decode_info_dict["id_group"], self.decode_info_dict["reverse_id_group"]): - self.decode_qkv[ids[0] : ids[1], :] = qkv[reverse_ids[0] : reverse_ids[1], :] - - return self.prefill_qkv, self.decode_qkv + split_qkv_cu(qkv, self.hybrid_stage_meta, self.prefill_qkv, self.decode_qkv) def merge_pd_output(self, prefill_out, decode_out): - for stage, idx in self.record_stages: - if stage == "prefill": - ids = self.prefill_info_dict["id_group"][idx] - reverse_ids = self.prefill_info_dict["reverse_id_group"][idx] - self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = prefill_out[ids[0] : ids[1], :, :] - else: - ids = self.decode_info_dict["id_group"][idx] - reverse_ids = self.decode_info_dict["reverse_id_group"][idx] - self.merged_output[reverse_ids[0] : reverse_ids[1], :, :] = decode_out[ids[0] : ids[1], :, :] - return self.merged_output - - def update_kv_cache( - self, k, v, k_cache_id, v_cache_id, layer_id, forward_meta: ForwardMeta, specific_batch_ids=None - ): - tensor_start = 0 - for batch_idx in range(forward_meta.block_tables.shape[0]): - if specific_batch_ids is not None and batch_idx not in specific_batch_ids: - continue - seq_len = forward_meta.seq_lens_this_time[batch_idx] - if seq_len == 0: - continue - tensor_end = tensor_start + seq_len - slice_trans_k = k[tensor_start:tensor_end, :, :] - slice_trans_v = v[tensor_start:tensor_end, :, :] - - cur_block_tables = forward_meta.block_tables[batch_idx] - cur_used_block_tables = cur_block_tables[cur_block_tables != -1] - - # encoder prefil - if seq_len > 1: - cache_start = 0 - cur_used_num_blocks = cur_used_block_tables.shape[0] - - for i, block_id in enumerate(cur_used_block_tables): - - # last block: seq_len - cache_start <= block_size - if i == cur_used_num_blocks - 1: - cache_end = seq_len - cache_start - assert cache_end <= self.block_size - - forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[ - cache_start:seq_len, :, : - ] - forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[ - cache_start:seq_len, :, : - ] - if layer_id == self.num_layers - 1: - self.record_block_table_metadata[batch_idx] = { - "block_id": block_id.item(), - "cache_end": cache_end, - } - # non last block: seq_lens_this_time > block_size - else: - if bool(self.num_layers_draft_model) and ( - seq_len < self.block_size and i < cur_used_num_blocks - 1 - ): - cache_end = seq_len - cache_start - assert cache_end <= self.block_size - - forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[ - cache_start:seq_len, :, : - ] - forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[ - cache_start:seq_len, :, : - ] - if layer_id == self.num_layers - 1: - self.record_block_table_metadata[batch_idx] = { - "block_id": block_id.item(), - "cache_end": cache_end, - } - break - - assert seq_len > self.block_size - cache_end = cache_start + self.block_size - forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :] - forward_meta.caches[v_cache_id][block_id] = slice_trans_v[cache_start:cache_end, :, :] - cache_start += self.block_size - tensor_start = tensor_end - - def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta): - q, k, v = cache_kv_with_rope( - prefill_qkv, - forward_meta.caches[k_cache_id], 
- forward_meta.caches[v_cache_id], - self.block_table_prefill, - self.attention_metadata.rotary_cos_prefill, - self.attention_metadata.rotary_sin_prefill, + merge_qkv_cu(prefill_out, decode_out, self.hybrid_stage_meta, self.merged_output) + + def apply_rope_prefill(self, qkv, rotary_embs, caches_k, caches_v, block_tables): + return cache_kv_with_rope( + qkv, + rotary_embs, + self.prefill_info_dict["batch_ids_per_token"], + self.prefill_info_dict["batch_ids"], self.prefill_info_dict["cu_seqlens_q"], - self.batch_ids_q, + self.prefill_info_dict["seq_lens_prefill"], + caches_k, + caches_v, + block_tables, self.num_heads, self.kv_num_heads, self.head_dim, self.block_size, + out_dims=3, + neox_style=self.is_neox_style, # is neox style + ) + + def apply_rope_decode(self, qkv, rotary_embs): + return cache_kv_with_rope( + qkv, + rotary_embs, + self.attention_metadata.batch_ids_per_token_decode, + self.attention_metadata.decoder_batch_ids, + self.attention_metadata.cu_seqlens_q_decode, + self.attention_metadata.seq_lens_decode, + None, + None, + None, + self.num_heads, + self.kv_num_heads, + self.head_dim, + -1, + out_dims=4, + neox_style=self.is_neox_style, # is neox style + ) + + def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta): + q, k, v = self.apply_rope_prefill( + prefill_qkv, + forward_meta.rotary_embs, + forward_meta.caches[k_cache_id], + forward_meta.caches[v_cache_id], + forward_meta.block_tables, ) prefill_out = flash_attn_unpadded_func( @@ -477,21 +332,14 @@ def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward return prefill_out def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta): - q, k, v = apply_rope_qkv( - decode_qkv, - self.attention_metadata.rotary_cos_decode, - self.attention_metadata.rotary_sin_decode, - self.num_heads, - self.kv_num_heads, - self.head_dim, - ) + q, k, v = self.apply_rope_decode(decode_qkv, forward_meta.rotary_embs) decode_out = flash_attn_kvcache_func( q, forward_meta.caches[k_cache_id], forward_meta.caches[v_cache_id], - self.attention_metadata.seq_lens_dec, - self.attention_metadata.block_table_dec, + self.attention_metadata.seq_lens_decode, + self.attention_metadata.block_table_decode, k, v, rotary_cos=None, @@ -509,17 +357,18 @@ def forward_native_backend(self, q, k, v, qkv, layer, forward_meta: ForwardMeta) k_cache_id = layer_id * 2 v_cache_id = k_cache_id + 1 - if self.decode_len == 0: + if self.has_prefill and not self.has_decode: out = self.forward_prefill(qkv, layer_id, k_cache_id, v_cache_id, forward_meta) - elif self.prefill_len == 0: + elif self.has_decode and not self.has_prefill: out = self.forward_decode(qkv, k_cache_id, v_cache_id, forward_meta) else: - prefill_qkv, decode_qkv = self.split_pd_qkv(qkv) - prefill_output = self.forward_prefill(prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta) - decode_output = self.forward_decode(decode_qkv, k_cache_id, v_cache_id, forward_meta) - out = self.merge_pd_output(prefill_output, decode_output) + self.split_pd_qkv(qkv) + prefill_output = self.forward_prefill(self.prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta) + decode_output = self.forward_decode(self.decode_qkv, k_cache_id, v_cache_id, forward_meta) + self.merge_pd_output(prefill_output, decode_output) + out = self.merged_output if qkv.dim() == 2: out = out.view([-1, self.num_heads * self.head_dim]) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 
dee8dc37271..51610593176 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -59,7 +59,9 @@ set_stop_value_multi_ends, speculate_limit_thinking_content_length_v1, speculate_limit_thinking_content_length_v2, + speculate_step_system_cache, step_paddle, + step_system_cache, update_inputs, update_inputs_v1, ) diff --git a/requirements_metaxgpu.txt b/requirements_metaxgpu.txt index 96f1c458472..aeaa3c47e78 100644 --- a/requirements_metaxgpu.txt +++ b/requirements_metaxgpu.txt @@ -10,7 +10,8 @@ tqdm pynvml uvicorn==0.29.0 fastapi -paddleformers @ https://paddle-qa.bj.bcebos.com/ernie/paddleformers-0.4.0.post20251222-py3-none-any.whl +# If the paddleformers version is > 0.3.2, the Metax triton will be replaced by the newest triton. +paddleformers==0.3.2 redis etcd3 httpx