Skip to content

Commit 8178e3f

Browse files
Authored by RuohengMac and cmcamdy
[XPU] add speculate_step_system_cache (#5397)
* [XPU] add speculate_step_system_cache * [XPU] add speculate_step_system_cache --------- Co-authored-by: cmcamdy <1027740945@qq.com>
1 parent e1c4a12 commit 8178e3f

File tree

8 files changed

+684
-82
lines changed

8 files changed

+684
-82
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "speculate_step_helper.h"

#include <memory>
16+
17+
void SpeculateStepPaddleBase(
18+
const paddle::Tensor &stop_flags,
19+
const paddle::Tensor &seq_lens_this_time,
20+
const paddle::Tensor &ori_seq_lens_encoder,
21+
const paddle::optional<paddle::Tensor> &ori_seq_lens_decoder,
22+
const paddle::Tensor &seq_lens_encoder,
23+
const paddle::Tensor &seq_lens_decoder,
24+
const paddle::Tensor &block_tables, // [bsz, block_num_per_seq]
25+
const paddle::Tensor &encoder_block_lens,
26+
const paddle::Tensor &is_block_step,
27+
const paddle::Tensor &step_block_list,
28+
const paddle::Tensor &step_lens,
29+
const paddle::Tensor &recover_block_list,
30+
const paddle::Tensor &recover_lens,
31+
const paddle::Tensor &need_block_list,
32+
const paddle::Tensor &need_block_len,
33+
const paddle::Tensor &used_list_len,
34+
const paddle::Tensor &free_list,
35+
const paddle::Tensor &free_list_len,
36+
const paddle::Tensor &input_ids,
37+
const paddle::Tensor &pre_ids,
38+
const paddle::Tensor &step_idx,
39+
const paddle::Tensor &next_tokens,
40+
const paddle::Tensor &first_token_ids,
41+
const paddle::Tensor &accept_num,
42+
const int block_size,
43+
const int encoder_decoder_block_num,
44+
const int max_draft_tokens) {
45+
namespace api = baidu::xpu::api;
46+
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
47+
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
48+
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
49+
api::Context *ctx = xpu_ctx->x_context();
50+
if (seq_lens_this_time.is_cpu()) {
51+
ctx = new api::Context(api::kCPU);
52+
}
53+
const int bsz = seq_lens_this_time.shape()[0];
54+
PADDLE_ENFORCE_LE(
55+
bsz,
56+
640,
57+
phi::errors::InvalidArgument(
58+
"Only support bsz <= 640, but received bsz is %d", bsz));
59+
const int block_num_per_seq = block_tables.shape()[1];
60+
const int length = input_ids.shape()[1];
61+
const int pre_id_length = pre_ids.shape()[1];
62+
const int max_decoder_block_num = pre_id_length / block_size;
63+
int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
64+
ctx,
65+
const_cast<bool *>(stop_flags.data<bool>()),
66+
const_cast<int *>(seq_lens_this_time.data<int>()),
67+
const_cast<int *>(seq_lens_decoder.data<int>()),
68+
const_cast<int *>(block_tables.data<int>()),
69+
const_cast<int *>(encoder_block_lens.data<int>()),
70+
const_cast<bool *>(is_block_step.data<bool>()),
71+
const_cast<int *>(step_block_list.data<int>()),
72+
const_cast<int *>(step_lens.data<int>()),
73+
const_cast<int *>(recover_block_list.data<int>()),
74+
const_cast<int *>(recover_lens.data<int>()),
75+
const_cast<int *>(need_block_list.data<int>()),
76+
const_cast<int *>(need_block_len.data<int>()),
77+
const_cast<int *>(used_list_len.data<int>()),
78+
const_cast<int *>(free_list.data<int>()),
79+
const_cast<int *>(free_list_len.data<int>()),
80+
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
81+
const_cast<int *>(accept_num.data<int>()),
82+
bsz,
83+
block_size,
84+
block_num_per_seq,
85+
max_decoder_block_num,
86+
max_draft_tokens);
87+
PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed.");
88+
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
89+
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
90+
if (recover_lens_cpu_data > 0) {
91+
r = baidu::xpu::api::plugin::speculate_recover_block(
92+
ctx,
93+
const_cast<int *>(recover_block_list.data<int>()),
94+
const_cast<int *>(recover_lens.data<int>()),
95+
const_cast<bool *>(stop_flags.data<bool>()),
96+
const_cast<int *>(seq_lens_this_time.data<int>()),
97+
ori_seq_lens_encoder.data<int>(),
98+
ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data<int>()
99+
: nullptr,
100+
const_cast<int *>(seq_lens_encoder.data<int>()),
101+
const_cast<int *>(seq_lens_decoder.data<int>()),
102+
const_cast<int *>(block_tables.data<int>()),
103+
const_cast<int *>(free_list.data<int>()),
104+
const_cast<int *>(free_list_len.data<int>()),
105+
const_cast<int64_t *>(input_ids.data<int64_t>()),
106+
pre_ids.data<int64_t>(),
107+
step_idx.data<int64_t>(),
108+
encoder_block_lens.data<int>(),
109+
used_list_len.data<int>(),
110+
next_tokens.data<int64_t>(),
111+
first_token_ids.data<int64_t>(),
112+
bsz,
113+
block_num_per_seq,
114+
length,
115+
pre_id_length);
116+
PD_CHECK(r == 0, "speculate_recover_block failed.");
117+
}
118+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <paddle/phi/backends/xpu/xpu_context.h>
18+
#include "paddle/extension.h"
19+
#include "paddle/phi/core/enforce.h"
20+
#include "xpu/plugin.h"
21+
22+
/// @brief Shared speculative-decoding step: frees/dispatches KV-cache blocks
///        for the batch, then recovers stepped-out sequences when blocks
///        become available. Called by both speculate_step_paddle (which
///        passes an empty optional for @p ori_seq_lens_decoder) and the
///        system-cache variant.
/// @param ori_seq_lens_decoder Optional; forwarded as nullptr to the recover
///        kernel when absent.
/// @param block_tables [bsz, block_num_per_seq] block-id table, mutated in
///        place by the XPU kernels.
/// @note  Most tensor parameters are written in place by the plugin kernels
///        despite the const references (op-input convention).
/// @note  encoder_decoder_block_num is currently unused by the
///        implementation — kept for signature parity with the registered ops.
void SpeculateStepPaddleBase(
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &ori_seq_lens_encoder,
    const paddle::optional<paddle::Tensor> &ori_seq_lens_decoder,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
    const paddle::Tensor &encoder_block_lens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &step_block_list,
    const paddle::Tensor &step_lens,
    const paddle::Tensor &recover_block_list,
    const paddle::Tensor &recover_lens,
    const paddle::Tensor &need_block_list,
    const paddle::Tensor &need_block_len,
    const paddle::Tensor &used_list_len,
    const paddle::Tensor &free_list,
    const paddle::Tensor &free_list_len,
    const paddle::Tensor &input_ids,
    const paddle::Tensor &pre_ids,
    const paddle::Tensor &step_idx,
    const paddle::Tensor &next_tokens,
    const paddle::Tensor &first_token_ids,
    const paddle::Tensor &accept_num,
    const int block_size,
    const int encoder_decoder_block_num,
    const int max_draft_tokens);

custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc

Lines changed: 28 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include <paddle/phi/backends/xpu/xpu_context.h>
16-
#include "paddle/extension.h"
17-
#include "paddle/phi/core/enforce.h"
18-
#include "xpu/plugin.h"
15+
#include "speculate_step_helper.h"
1916

2017
#ifndef PD_BUILD_STATIC_OP
2118
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -48,77 +45,33 @@ void SpeculateStepPaddle(
4845
const int block_size,
4946
const int encoder_decoder_block_num,
5047
const int max_draft_tokens) {
51-
namespace api = baidu::xpu::api;
52-
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
53-
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
54-
auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
55-
api::Context *ctx = xpu_ctx->x_context();
56-
if (seq_lens_this_time.is_cpu()) {
57-
ctx = new api::Context(api::kCPU);
58-
}
59-
const int bsz = seq_lens_this_time.shape()[0];
60-
PADDLE_ENFORCE_LE(
61-
bsz,
62-
640,
63-
phi::errors::InvalidArgument(
64-
"Only support bsz <= 640, but received bsz is %d", bsz));
65-
const int block_num_per_seq = block_tables.shape()[1];
66-
const int length = input_ids.shape()[1];
67-
const int pre_id_length = pre_ids.shape()[1];
68-
const int max_decoder_block_num = pre_id_length / block_size;
69-
int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
70-
ctx,
71-
const_cast<bool *>(stop_flags.data<bool>()),
72-
const_cast<int *>(seq_lens_this_time.data<int>()),
73-
const_cast<int *>(seq_lens_decoder.data<int>()),
74-
const_cast<int *>(block_tables.data<int>()),
75-
const_cast<int *>(encoder_block_lens.data<int>()),
76-
const_cast<bool *>(is_block_step.data<bool>()),
77-
const_cast<int *>(step_block_list.data<int>()),
78-
const_cast<int *>(step_lens.data<int>()),
79-
const_cast<int *>(recover_block_list.data<int>()),
80-
const_cast<int *>(recover_lens.data<int>()),
81-
const_cast<int *>(need_block_list.data<int>()),
82-
const_cast<int *>(need_block_len.data<int>()),
83-
const_cast<int *>(used_list_len.data<int>()),
84-
const_cast<int *>(free_list.data<int>()),
85-
const_cast<int *>(free_list_len.data<int>()),
86-
const_cast<int64_t *>(first_token_ids.data<int64_t>()),
87-
const_cast<int *>(accept_num.data<int>()),
88-
bsz,
89-
block_size,
90-
block_num_per_seq,
91-
max_decoder_block_num,
92-
max_draft_tokens);
93-
PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed.");
94-
auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
95-
int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
96-
if (recover_lens_cpu_data > 0) {
97-
r = baidu::xpu::api::plugin::speculate_recover_block(
98-
ctx,
99-
const_cast<int *>(recover_block_list.data<int>()),
100-
const_cast<int *>(recover_lens.data<int>()),
101-
const_cast<bool *>(stop_flags.data<bool>()),
102-
const_cast<int *>(seq_lens_this_time.data<int>()),
103-
ori_seq_lens_encoder.data<int>(),
104-
const_cast<int *>(seq_lens_encoder.data<int>()),
105-
seq_lens_decoder.data<int>(),
106-
const_cast<int *>(block_tables.data<int>()),
107-
const_cast<int *>(free_list.data<int>()),
108-
const_cast<int *>(free_list_len.data<int>()),
109-
const_cast<int64_t *>(input_ids.data<int64_t>()),
110-
pre_ids.data<int64_t>(),
111-
step_idx.data<int64_t>(),
112-
encoder_block_lens.data<int>(),
113-
used_list_len.data<int>(),
114-
next_tokens.data<int64_t>(),
115-
first_token_ids.data<int64_t>(),
116-
bsz,
117-
block_num_per_seq,
118-
length,
119-
pre_id_length);
120-
PD_CHECK(r == 0, "speculate_recover_block failed.");
121-
}
48+
SpeculateStepPaddleBase(stop_flags,
49+
seq_lens_this_time,
50+
ori_seq_lens_encoder,
51+
paddle::optional<paddle::Tensor>(),
52+
seq_lens_encoder,
53+
seq_lens_decoder,
54+
block_tables,
55+
encoder_block_lens,
56+
is_block_step,
57+
step_block_list,
58+
step_lens,
59+
recover_block_list,
60+
recover_lens,
61+
need_block_list,
62+
need_block_len,
63+
used_list_len,
64+
free_list,
65+
free_list_len,
66+
input_ids,
67+
pre_ids,
68+
step_idx,
69+
next_tokens,
70+
first_token_ids,
71+
accept_num,
72+
block_size,
73+
encoder_decoder_block_num,
74+
max_draft_tokens);
12275
}
12376

12477
PD_BUILD_STATIC_OP(speculate_step_paddle)

0 commit comments

Comments
 (0)