From 8f06f7e432cbf24cffe19b1154bceb9cf9c8e981 Mon Sep 17 00:00:00 2001 From: maruoheng Date: Fri, 5 Dec 2025 08:56:17 +0000 Subject: [PATCH 1/2] [XPU] add speculate_step_system_cache --- .../src/ops/mtp/speculate_step_helper.cc | 117 +++++++ .../src/ops/mtp/speculate_step_helper.h | 49 +++ .../src/ops/mtp/speculate_step_paddle.cc | 105 ++---- .../ops/mtp/speculate_step_system_cache.cc | 145 ++++++++ .../xpu_ops/src/plugin/include/xpu/plugin.h | 3 +- .../mtp_kernel/speculate_recover_block.xpu | 11 +- .../mtp_wrapper/speculate_recover_block.cpp | 21 +- .../test/test_speculate_step_system_cache.py | 316 ++++++++++++++++++ 8 files changed, 685 insertions(+), 82 deletions(-) create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc create mode 100644 custom_ops/xpu_ops/test/test_speculate_step_system_cache.py diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc new file mode 100644 index 00000000000..383abd9536b --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "speculate_step_helper.h" + +void SpeculateStepPaddleBase( + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::optional &ori_seq_lens_decoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, + const paddle::Tensor &accept_num, + const int block_size, + const int encoder_decoder_block_num, + const int max_draft_tokens) { + namespace api = baidu::xpu::api; + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + auto xpu_ctx = static_cast(dev_ctx); + api::Context *ctx = xpu_ctx->x_context(); + if (seq_lens_this_time.is_cpu()) { + ctx = new api::Context(api::kCPU); + } + const int bsz = seq_lens_this_time.shape()[0]; + PADDLE_ENFORCE_LE( + bsz, + 640, + phi::errors::InvalidArgument( + "Only support bsz <= 640, but received bsz is %d", bsz)); + const int block_num_per_seq = block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_decoder_block_num = pre_id_length / block_size; + int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block( + ctx, + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(first_token_ids.data()), + const_cast(accept_num.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num, + max_draft_tokens); + PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed."); + auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); + int recover_lens_cpu_data = recover_lens_cpu.data()[0]; + if (recover_lens_cpu_data > 0) { + r = baidu::xpu::api::plugin::speculate_recover_block( + ctx, + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + ori_seq_lens_encoder.data(), + ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data() : nullptr, + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(input_ids.data()), + pre_ids.data(), + step_idx.data(), + encoder_block_lens.data(), + used_list_len.data(), + next_tokens.data(), + first_token_ids.data(), + bsz, + block_num_per_seq, + length, + pre_id_length); + PD_CHECK(r == 0, "speculate_recover_block failed."); + } +} \ No newline at end of file diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h new file mode 100644 index 00000000000..4d9d5e97a7b --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/extension.h" +#include "paddle/phi/core/enforce.h" +#include "xpu/plugin.h" + +void SpeculateStepPaddleBase( + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::optional &ori_seq_lens_decoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, + const paddle::Tensor &accept_num, + const int block_size, + const int encoder_decoder_block_num, + const int max_draft_tokens); \ No newline at end of file diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc index d8b113fb81a..542f0f1a4fa 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include "paddle/extension.h" -#include "paddle/phi/core/enforce.h" -#include "xpu/plugin.h" +#include "speculate_step_helper.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -48,77 +45,35 @@ void SpeculateStepPaddle( const int block_size, const int encoder_decoder_block_num, const int max_draft_tokens) { - namespace api = baidu::xpu::api; - phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); - auto xpu_ctx = static_cast(dev_ctx); - api::Context *ctx = xpu_ctx->x_context(); - if (seq_lens_this_time.is_cpu()) { - ctx = new api::Context(api::kCPU); - } - const int bsz = seq_lens_this_time.shape()[0]; - PADDLE_ENFORCE_LE( - bsz, - 640, - phi::errors::InvalidArgument( - "Only support bsz <= 640, but received bsz is %d", bsz)); - const int block_num_per_seq = block_tables.shape()[1]; - const int length = input_ids.shape()[1]; - const int pre_id_length = pre_ids.shape()[1]; - const int max_decoder_block_num = pre_id_length / block_size; - int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block( - ctx, - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_decoder.data()), - const_cast(block_tables.data()), - const_cast(encoder_block_lens.data()), - const_cast(is_block_step.data()), - const_cast(step_block_list.data()), - const_cast(step_lens.data()), - const_cast(recover_block_list.data()), - const_cast(recover_lens.data()), - const_cast(need_block_list.data()), - const_cast(need_block_len.data()), - const_cast(used_list_len.data()), - const_cast(free_list.data()), - const_cast(free_list_len.data()), - const_cast(first_token_ids.data()), - const_cast(accept_num.data()), - bsz, - block_size, - block_num_per_seq, - max_decoder_block_num, - max_draft_tokens); - PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed."); - auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false); - int recover_lens_cpu_data = recover_lens_cpu.data()[0]; - if (recover_lens_cpu_data > 0) { - r = baidu::xpu::api::plugin::speculate_recover_block( - ctx, - const_cast(recover_block_list.data()), - const_cast(recover_lens.data()), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - ori_seq_lens_encoder.data(), - const_cast(seq_lens_encoder.data()), - seq_lens_decoder.data(), - const_cast(block_tables.data()), - const_cast(free_list.data()), - const_cast(free_list_len.data()), - const_cast(input_ids.data()), - pre_ids.data(), - step_idx.data(), - encoder_block_lens.data(), - used_list_len.data(), - next_tokens.data(), - first_token_ids.data(), - bsz, - block_num_per_seq, - length, - pre_id_length); - PD_CHECK(r == 0, "speculate_recover_block failed."); - } + SpeculateStepPaddleBase( + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + paddle::optional(), + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens + ); } PD_BUILD_STATIC_OP(speculate_step_paddle) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc new file mode 100644 index 00000000000..89643a457e5 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc @@ -0,0 +1,145 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "speculate_step_helper.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +void SpeculateStepSystemCachePaddle( + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &ori_seq_lens_encoder, + const paddle::Tensor &ori_seq_lens_decoder, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor &encoder_block_lens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &step_block_list, + const paddle::Tensor &step_lens, + const paddle::Tensor &recover_block_list, + const paddle::Tensor &recover_lens, + const paddle::Tensor &need_block_list, + const paddle::Tensor &need_block_len, + const paddle::Tensor &used_list_len, + const paddle::Tensor &free_list, + const paddle::Tensor &free_list_len, + const paddle::Tensor &input_ids, + const paddle::Tensor &pre_ids, + const paddle::Tensor &step_idx, + const paddle::Tensor &next_tokens, + const paddle::Tensor &first_token_ids, + const paddle::Tensor &accept_num, + const int block_size, + const int encoder_decoder_block_num, + const int max_draft_tokens) { + SpeculateStepPaddleBase( + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + paddle::make_optional(ori_seq_lens_decoder), + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens + ); +} + +PD_BUILD_STATIC_OP(speculate_step_system_cache) + .Inputs({"stop_flags", + "seq_lens_this_time", + "ori_seq_lens_encoder", + "ori_seq_lens_decoder", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables", + "encoder_block_lens", + "is_block_step", + "step_block_list", + "step_lens", + "recover_block_list", + "recover_lens", + "need_block_list", + "need_block_len", + "used_list_len", + "free_list", + "free_list_len", + "input_ids", + "pre_ids", + "step_idx", + "next_tokens", + "first_token_ids", + "accept_num"}) + .Attrs({"block_size: int", + "encoder_decoder_block_num: int", + "max_draft_tokens: int"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", + "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out", + "first_token_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"block_tables", "block_tables_out"}, + {"encoder_block_lens", "encoder_block_lens_out"}, + {"is_block_step", "is_block_step_out"}, + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", "used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}, + {"first_token_ids", "first_token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateStepSystemCachePaddle)); + diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 09a426a3126..bc27a54a94a 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -207,8 +207,9 @@ DLL_EXPORT int speculate_recover_block(Context* ctx, bool* stop_flags, int* seq_lens_this_time, const int* ori_seq_lens_encoder, + const int* ori_seq_lens_decoder, int* seq_lens_encoder, - const int* seq_lens_decoder, + int* seq_lens_decoder, int* block_tables, int* free_list, int* free_list_len, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu index 46d24821dda..6eb7279d97d 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu @@ -33,8 +33,9 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] bool* stop_flags, int* seq_lens_this_time, const int* ori_seq_lens_encoder, + const int* ori_seq_lens_decoder, int* seq_lens_encoder, - const int* seq_lens_decoder, + int* seq_lens_decoder, int* block_tables, int* free_list, int* free_list_len, @@ -82,6 +83,7 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] for (int bid = cid; bid < recover_len_lm; bid += ncores) { int recover_id; int ori_seq_len_encoder; + int ori_seq_len_decoder; int step_idx_now; int encoder_block_len; int decoder_used_len; @@ -89,12 +91,19 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] GM2LM(recover_block_list + bid, &recover_id, sizeof(int)); GM2LM_ASYNC( ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int)); + if (ori_seq_lens_decoder != nullptr) { + GM2LM_ASYNC( + ori_seq_lens_decoder + recover_id, &ori_seq_len_decoder, sizeof(int)); + } GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int)); GM2LM_ASYNC( encoder_block_lens + recover_id, &encoder_block_len, sizeof(int)); GM2LM_ASYNC(used_list_len + recover_id, &decoder_used_len, sizeof(int)); GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t)); mfence(); + if (ori_seq_lens_decoder != nullptr) { + LM2GM_ASYNC(&ori_seq_len_decoder, seq_lens_decoder + recover_id, sizeof(int)); + } int seq_len = ori_seq_len_encoder + step_idx_now; mfence(); diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp index 2996325c833..5f3c8bdf6c2 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp @@ -26,8 +26,9 @@ __attribute__((global)) void speculate_recover_block( bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -57,8 +58,9 @@ static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -76,6 +78,9 @@ static int cpu_wrapper(Context *ctx, for (int bid = 0; bid < recover_len[0]; bid++) { const int recover_id = recover_block_list[bid]; const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + if (ori_seq_lens_decoder != nullptr) { + seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id]; + } const int step_idx_now = step_idx[recover_id]; const int seq_len = ori_seq_len_encoder + step_idx_now; const int encoder_block_len = encoder_block_lens[recover_id]; @@ -112,8 +117,9 @@ static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -136,6 +142,7 @@ static int xpu3_wrapper(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, + ori_seq_lens_decoder, seq_lens_encoder, seq_lens_decoder, block_tables, @@ -161,8 +168,9 @@ int speculate_recover_block(Context *ctx, bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -185,7 +193,8 @@ int speculate_recover_block(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, - seq_lens_encoder); + ori_seq_lens_decoder); + WRAPPER_DUMP_PARAM1(ctx, seq_lens_encoder); WRAPPER_DUMP_PARAM6(ctx, seq_lens_decoder, block_tables, @@ -208,6 +217,7 @@ int speculate_recover_block(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, + ori_seq_lens_decoder, seq_lens_encoder, seq_lens_decoder, block_tables, @@ -232,6 +242,7 @@ int speculate_recover_block(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, + ori_seq_lens_decoder, seq_lens_encoder, seq_lens_decoder, block_tables, diff --git a/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py new file mode 100644 index 00000000000..d691533d03f --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py @@ -0,0 +1,316 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import speculate_step_system_cache + +# 固定随机种子,保证测试可复现 +np.random.seed(2023) +paddle.seed(2023) + +def generate_test_data(): + """ + 生成测试数据的辅助函数。 + 这部分逻辑从 pytest 的 fixture 转换而来,作为一个普通函数供测试方法调用。 + """ + # max_bs = 128 + max_bs = 8 + bs = max_bs + max_seq_len = 8192 + block_size = 64 + block_bs = 8 + block_ratio = 0.75 + max_draft_tokens = 1 + encoder_decoder_block_num = 1 + + # 生成原始测试数据(完全复用原有逻辑) + stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool") + seq_lens_this_time = np.zeros([bs], "int32") + seq_lens_encoder = np.zeros([max_bs], "int32") + seq_lens_decoder = np.zeros([max_bs], "int32") + accept_num = np.random.randint(1, 3, [max_bs]).astype("int32") + for i in range(bs): + seq_lens_decoder[i] = 2 + i * 2 + seq_lens_this_time[i] = 1 + + ori_seq_lens_encoder = np.zeros([max_bs], "int32") + ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2 + ori_seq_lens_decoder = np.random.randint(1, 10, (max_bs), "int32") + step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64") + + max_block_num = block_bs * max_seq_len // block_size + free_list_len = int(max_block_num * (1 - block_ratio)) + free_list_len = np.full([1], free_list_len, "int32") + free_list = np.arange( + max_block_num - 1, max_block_num - free_list_len.item() - 1, -1, dtype="int32" # 加 .item() 转为标量 + ) + encoder_block_lens = np.zeros([max_bs], "int32") + used_list_len = np.zeros([max_bs], "int32") + block_tables = np.full([max_bs, 128], -1, "int32") + encoder_block_id = 0 + + for i in range(bs): + enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size + encoder_block_lens[i] = enc_block_num + dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num + used_list_len[i] = dec_block_num + block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32") + encoder_block_id += enc_block_num + if dec_block_num > 0: + block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[ + free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1 + ] + free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1 + free_list_len[0] -= dec_block_num + + assert free_list_len[0] >= 0, "free_list_len should not be negative" + + is_block_step = np.zeros([max_bs], "bool") + is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool") + step_block_list = np.full([max_bs], -1, "int32") + step_lens = np.full([1], 0, "int32") + + for i in range(bs): + if is_block_step[i]: + step_block_list[step_lens[0]] = i + step_lens[0] += 1 + + recover_lens = np.full([1], 0, "int32") + recover_block_list = np.full([max_bs], -1, "int32") + need_block_len = np.full([1], 0, "int32") + need_block_list = np.full([max_bs], -1, "int32") + + input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64") + next_tokens = np.random.randint(0, 1000, [max_bs], "int64") + first_token_ids = np.random.randint(0, 1000, [max_bs], "int64") + + paddle.set_device("cpu") + # 转换为 paddle tensor(保持原有逻辑) + data_cpu = { + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder), + "ori_seq_lens_decoder": paddle.to_tensor(ori_seq_lens_decoder), + "block_tables": paddle.to_tensor(block_tables), + "encoder_block_lens": paddle.to_tensor(encoder_block_lens), + "is_block_step": paddle.to_tensor(is_block_step), + "step_block_list": paddle.to_tensor(step_block_list), + "step_lens": paddle.to_tensor(step_lens), + "recover_block_list": paddle.to_tensor(recover_block_list), + "recover_lens": paddle.to_tensor(recover_lens), + "need_block_list": paddle.to_tensor(need_block_list), + "need_block_len": paddle.to_tensor(need_block_len), + "used_list_len": paddle.to_tensor(used_list_len), + "free_list_len": paddle.to_tensor(free_list_len), + "free_list": paddle.to_tensor(free_list), + "input_ids": paddle.to_tensor(input_ids), + "pre_ids": paddle.to_tensor(pre_ids), + "step_idx": paddle.to_tensor(step_idx), + "next_tokens": paddle.to_tensor(next_tokens), + "first_token_ids": paddle.to_tensor(first_token_ids), + "accept_num": paddle.to_tensor(accept_num), + "block_size": block_size, + "encoder_decoder_block_num": encoder_decoder_block_num, + "max_draft_tokens": max_draft_tokens, + } + + paddle.set_device("xpu:0") + data_xpu = { + "stop_flags": paddle.to_tensor(stop_flags), + "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time), + "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder), + "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder), + "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder), + "ori_seq_lens_decoder": paddle.to_tensor(ori_seq_lens_decoder), + "block_tables": paddle.to_tensor(block_tables), + "encoder_block_lens": paddle.to_tensor(encoder_block_lens), + "is_block_step": paddle.to_tensor(is_block_step), + "step_block_list": paddle.to_tensor(step_block_list), + "step_lens": paddle.to_tensor(step_lens), + "recover_block_list": paddle.to_tensor(recover_block_list), + "recover_lens": paddle.to_tensor(recover_lens), + "need_block_list": paddle.to_tensor(need_block_list), + "need_block_len": paddle.to_tensor(need_block_len), + "used_list_len": paddle.to_tensor(used_list_len), + "free_list_len": paddle.to_tensor(free_list_len), + "free_list": paddle.to_tensor(free_list), + "input_ids": paddle.to_tensor(input_ids), + "pre_ids": paddle.to_tensor(pre_ids), + "step_idx": paddle.to_tensor(step_idx), + "next_tokens": paddle.to_tensor(next_tokens), + "first_token_ids": paddle.to_tensor(first_token_ids), + "accept_num": paddle.to_tensor(accept_num), + "block_size": block_size, + "encoder_decoder_block_num": encoder_decoder_block_num, + "max_draft_tokens": max_draft_tokens, + } + + # 恢复默认设备,避免影响其他测试 + paddle.set_device("cpu") + + return data_cpu, data_xpu + + +def speculate_step_paddle_execution(test_data): + """测试 speculate_step_system_cache 函数的执行性和输出合理性""" + # 提取输入数据 + stop_flags = test_data["stop_flags"] # 克隆避免影响夹具数据 + seq_lens_this_time = test_data["seq_lens_this_time"] + ori_seq_lens_encoder = test_data["ori_seq_lens_encoder"] + ori_seq_lens_decoder = test_data["ori_seq_lens_decoder"] + seq_lens_encoder = test_data["seq_lens_encoder"] + seq_lens_decoder = test_data["seq_lens_decoder"] + block_tables = test_data["block_tables"] + encoder_block_lens = test_data["encoder_block_lens"] + is_block_step = test_data["is_block_step"] + step_block_list = test_data["step_block_list"] + step_lens = test_data["step_lens"] + recover_block_list = test_data["recover_block_list"] + recover_lens = test_data["recover_lens"] + need_block_list = test_data["need_block_list"] + need_block_len = test_data["need_block_len"] + used_list_len = test_data["used_list_len"] + free_list = test_data["free_list"] + free_list_len = test_data["free_list_len"] + input_ids = test_data["input_ids"] + pre_ids = test_data["pre_ids"] + step_idx = test_data["step_idx"] + next_tokens = test_data["next_tokens"] + first_token_ids = test_data["first_token_ids"] + accept_num = test_data["accept_num"] + block_size = test_data["block_size"] + encoder_decoder_block_num = test_data["encoder_decoder_block_num"] + max_draft_tokens = test_data["max_draft_tokens"] + + # 可选:打印执行前关键信息(如需调试可开启) + if os.environ.get("STEP_TEST_DEBUG", "0") == "1": + print("-" * 50 + "before step op" + "-" * 50) + # ... (省略打印内容以保持简洁) + + # 执行目标函数(核心测试步骤) + speculate_step_system_cache( + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + ori_seq_lens_decoder, + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens, + ) + + # 可选:打印执行后关键信息(如需调试可开启) + if os.environ.get("STEP_TEST_DEBUG", "0") == "1": + print("-" * 50 + "after step op" + "-" * 50) + # ... (省略打印内容以保持简洁) + + return test_data + + +class TestSpeculateStepSystemCache(unittest.TestCase): + """ + 测试类,继承自 unittest.TestCase。 + 所有以 'test_' 开头的方法都会被视为测试用例。 + """ + + def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08): + """ + 自定义的断言方法,用于比较两个 test_data 结构和数据。 + 在 unittest 中,自定义断言通常以 'assert' 开头。 + """ + # 1. 先校验两个 test_data 的字段名完全一致 + keys1 = set(test_data1.keys()) + keys2 = set(test_data2.keys()) + self.assertEqual( + keys1, + keys2, + msg=f"两个 test_data 字段不一致!\n仅在第一个中存在:{keys1 - keys2}\n仅在第二个中存在:{keys2 - keys1}", + ) + + # 2. 逐字段校验数据 + for key in keys1: + data1 = test_data1[key] + data2 = test_data2[key] + + # 区分:paddle Tensor(需转 numpy)和 普通标量/数组(直接使用) + if isinstance(data1, paddle.Tensor): + np1 = data1.detach().cpu().numpy() + else: + np1 = np.asarray(data1) + + if isinstance(data2, paddle.Tensor): + np2 = data2.detach().cpu().numpy() + else: + np2 = np.asarray(data2) + + # 3. 校验数据 + if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8): + # 布尔/整数型:必须完全相等 + np.testing.assert_array_equal(np1, np2, err_msg=f"字段 {key} 数据不一致!") + else: + # 浮点型:允许 rtol/atol 范围内的误差 + np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"字段 {key} 浮点数据不一致!") + + print("✅ 两个 test_data 结构和数据完全一致!") + + def test_speculate_step_system_cache_execution(self): + """ + 核心测试用例方法。 + 该方法会调用 generate_test_data 获取数据, + 分别在 CPU 和 XPU 上执行测试函数, + 并使用自定义的断言方法比较结果。 + """ + print("\nRunning test: test_speculate_step_system_cache_execution") + + # 1. 获取测试数据 + data_cpu, data_xpu = generate_test_data() + + # 2. 执行测试函数 + result_xpu = speculate_step_paddle_execution(data_xpu) + result_cpu = speculate_step_paddle_execution(data_cpu) + + # 3. 断言结果一致 + self.assert_test_data_equal(result_xpu, result_cpu) + + +if __name__ == "__main__": + # 使用 unittest 的主程序来运行所有测试用例 + unittest.main() From e0bec564929003897a9549c6db27ff6c12da3918 Mon Sep 17 00:00:00 2001 From: maruoheng Date: Mon, 8 Dec 2025 07:14:56 +0000 Subject: [PATCH 2/2] [XPU] add speculate_step_system_cache --- .../src/ops/mtp/speculate_step_helper.cc | 5 +- .../src/ops/mtp/speculate_step_helper.h | 2 +- .../src/ops/mtp/speculate_step_paddle.cc | 56 +++++++++---------- .../ops/mtp/speculate_step_system_cache.cc | 56 +++++++++---------- .../mtp_kernel/speculate_recover_block.xpu | 7 ++- .../mtp_wrapper/speculate_recover_block.cpp | 2 +- .../test/test_speculate_step_system_cache.py | 1 + 7 files changed, 64 insertions(+), 65 deletions(-) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc index 383abd9536b..2344531a333 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc @@ -95,7 +95,8 @@ void SpeculateStepPaddleBase( const_cast(stop_flags.data()), const_cast(seq_lens_this_time.data()), ori_seq_lens_encoder.data(), - ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data() : nullptr, + ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data() + : nullptr, const_cast(seq_lens_encoder.data()), const_cast(seq_lens_decoder.data()), const_cast(block_tables.data()), @@ -114,4 +115,4 @@ void SpeculateStepPaddleBase( pre_id_length); PD_CHECK(r == 0, "speculate_recover_block failed."); } -} \ No newline at end of file +} diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h index 4d9d5e97a7b..ea2eb2c9bb6 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h @@ -46,4 +46,4 @@ void SpeculateStepPaddleBase( const paddle::Tensor &accept_num, const int block_size, const int encoder_decoder_block_num, - const int max_draft_tokens); \ No newline at end of file + const int max_draft_tokens); diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc index 542f0f1a4fa..1088b604c91 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc @@ -45,35 +45,33 @@ void SpeculateStepPaddle( const int block_size, const int encoder_decoder_block_num, const int max_draft_tokens) { - SpeculateStepPaddleBase( - stop_flags, - seq_lens_this_time, - ori_seq_lens_encoder, - paddle::optional(), - seq_lens_encoder, - seq_lens_decoder, - block_tables, - encoder_block_lens, - is_block_step, - step_block_list, - step_lens, - recover_block_list, - recover_lens, - need_block_list, - need_block_len, - used_list_len, - free_list, - free_list_len, - input_ids, - pre_ids, - step_idx, - next_tokens, - first_token_ids, - accept_num, - block_size, - encoder_decoder_block_num, - max_draft_tokens - ); + SpeculateStepPaddleBase(stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + paddle::optional(), + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens); } PD_BUILD_STATIC_OP(speculate_step_paddle) diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc index 89643a457e5..0040600ca37 100644 --- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc @@ -47,34 +47,33 @@ void SpeculateStepSystemCachePaddle( const int encoder_decoder_block_num, const int max_draft_tokens) { SpeculateStepPaddleBase( - stop_flags, - seq_lens_this_time, - ori_seq_lens_encoder, - paddle::make_optional(ori_seq_lens_decoder), - seq_lens_encoder, - seq_lens_decoder, - block_tables, - encoder_block_lens, - is_block_step, - step_block_list, - step_lens, - recover_block_list, - recover_lens, - need_block_list, - need_block_len, - used_list_len, - free_list, - free_list_len, - input_ids, - pre_ids, - step_idx, - next_tokens, - first_token_ids, - accept_num, - block_size, - encoder_decoder_block_num, - max_draft_tokens - ); + stop_flags, + seq_lens_this_time, + ori_seq_lens_encoder, + paddle::make_optional(ori_seq_lens_decoder), + seq_lens_encoder, + seq_lens_decoder, + block_tables, + encoder_block_lens, + is_block_step, + step_block_list, + step_lens, + recover_block_list, + recover_lens, + need_block_list, + need_block_len, + used_list_len, + free_list, + free_list_len, + input_ids, + pre_ids, + step_idx, + next_tokens, + first_token_ids, + accept_num, + block_size, + encoder_decoder_block_num, + max_draft_tokens); } PD_BUILD_STATIC_OP(speculate_step_system_cache) @@ -142,4 +141,3 @@ PD_BUILD_STATIC_OP(speculate_step_system_cache) {"input_ids", "input_ids_out"}, {"first_token_ids", "first_token_ids_out"}}) .SetKernelFn(PD_KERNEL(SpeculateStepSystemCachePaddle)); - diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu index 6eb7279d97d..85439819265 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu @@ -92,8 +92,8 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] GM2LM_ASYNC( ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int)); if (ori_seq_lens_decoder != nullptr) { - GM2LM_ASYNC( - ori_seq_lens_decoder + recover_id, &ori_seq_len_decoder, sizeof(int)); + GM2LM_ASYNC( + ori_seq_lens_decoder + recover_id, &ori_seq_len_decoder, sizeof(int)); } GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int)); GM2LM_ASYNC( @@ -102,7 +102,8 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t)); mfence(); if (ori_seq_lens_decoder != nullptr) { - LM2GM_ASYNC(&ori_seq_len_decoder, seq_lens_decoder + recover_id, sizeof(int)); + LM2GM_ASYNC( + &ori_seq_len_decoder, seq_lens_decoder + recover_id, sizeof(int)); } int seq_len = ori_seq_len_encoder + step_idx_now; diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp index 5f3c8bdf6c2..0e270c0f01a 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp @@ -79,7 +79,7 @@ static int cpu_wrapper(Context *ctx, const int recover_id = recover_block_list[bid]; const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; if (ori_seq_lens_decoder != nullptr) { - seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id]; + seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id]; } const int step_idx_now = step_idx[recover_id]; const int seq_len = ori_seq_len_encoder + step_idx_now; diff --git a/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py index d691533d03f..6b52efe13f7 100644 --- a/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py +++ b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py @@ -24,6 +24,7 @@ np.random.seed(2023) paddle.seed(2023) + def generate_test_data(): """ 生成测试数据的辅助函数。