Skip to content

Commit f412d66

Browse files
[transformation][CPU][GPU] support 3D rope (#31949)
### Details: - *support 3D rope fusing in Qwen2.5-vl-3b and Qwen2-vl-7b* ### Tickets: - *CVS-171588, CVS-171442* --------- Signed-off-by: HU Yuan2 <yuan2.hu@intel.com> Co-authored-by: zaixing.wang <zaixing.wang@intel.com>
1 parent 90b9b38 commit f412d66

File tree

13 files changed

+410
-86
lines changed

13 files changed

+410
-86
lines changed

src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class TRANSFORMATIONS_API RoPE : public Op {
3232
// each head. change input order to [batch, head_cnt, 4608] to support 2d rope
3333
bool is_qwen = false; // Qwen is special which overrides other setting
3434
bool use_rope_cache = false; // use precomputed RoPE cache for trigonometric values (cosine and sine)
35+
bool support_3d_rope = false; // use same logic as RoPEFusionGPTNEOX(4), used by gpu plugin
3536
size_t head_cnt = 0;
3637
size_t head_size = 0;
3738
int gather_position_arg_id =

src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class TRANSFORMATIONS_API RoPEShareCosSin;
2828
class ov::pass::RoPEFusionGPTNEOX : public ov::pass::MatcherPass {
2929
public:
3030
OPENVINO_MATCHER_PASS_RTTI("RoPEFusionGPTNEOX");
31-
RoPEFusionGPTNEOX();
31+
RoPEFusionGPTNEOX(int rank);
3232
};
3333

3434
class ov::pass::RoPEFusionFlux : public ov::pass::MatcherPass {

src/common/transformations/src/ov_ops/rotary_positional_embeddings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ bool RoPE::visit_attributes(ov::AttributeVisitor& visitor) {
9595
visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
9696
visitor.on_attribute("is_chatglm", m_config.is_chatglm);
9797
visitor.on_attribute("support_2d_rope", m_config.support_2d_rope);
98+
visitor.on_attribute("support_3d_rope", m_config.support_3d_rope);
9899
visitor.on_attribute("is_qwen", m_config.is_qwen);
99100
visitor.on_attribute("use_rope_cache", m_config.use_rope_cache);
100101
visitor.on_attribute("head_cnt", m_config.head_cnt);

src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ bool ov::pass::RoPEFusion::run_on_model(const std::shared_ptr<ov::Model>& model)
5858
auto symbolic_ctx_manager = symbolic_optimizations.get_manager();
5959

6060
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionFlux>();
61-
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>();
61+
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>(4);
62+
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTNEOX>(3);
6263
symbolic_ctx_manager->register_pass<ov::pass::RoPEFusionGPTJ>();
6364
// optional heads & tails are fused in separate matcher pass,
6465
// after RoPENode has been created.
@@ -164,7 +165,7 @@ ov::pass::RoPEFusionFlux::RoPEFusionFlux() {
164165
this->register_matcher(m, callback);
165166
}
166167

167-
ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
168+
ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX(int rank) {
168169
using namespace ov::op::util;
169170
MATCHER_SCOPE(RoPEFusionGPTNEOX);
170171

@@ -177,19 +178,19 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
177178
// branch.
178179
// so here we use a WA, only match the path of rotate_hal(x)*sin and check the x*cos path
179180
// in the callback
180-
auto x = pattern::any_input(pattern::rank_equals(4));
181-
auto x_or_cos1 = pattern::any_input(pattern::rank_equals(4));
182-
auto x_or_cos2 = pattern::any_input(pattern::rank_equals(4));
183-
auto t_sin = pattern::any_input(pattern::rank_equals(4));
181+
auto x = pattern::any_input(pattern::rank_equals(rank));
182+
auto x_or_cos1 = pattern::any_input(pattern::rank_equals(rank));
183+
auto x_or_cos2 = pattern::any_input(pattern::rank_equals(rank));
184+
auto t_sin = pattern::any_input(pattern::rank_equals(rank));
184185

185-
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, 3, {"half_ndims", "?"}});
186+
auto varsplit = pattern::wrap_type<v1::VariadicSplit>({x, rank - 1, {"half_ndims", "?"}});
186187
varsplit->set_output_size(2);
187188

188189
auto int32_max = std::numeric_limits<std::int32_t>::max();
189190

190-
auto x2 = NewGenSlice(x, "half_ndims", int32_max, 1, 3);
191+
auto x2 = NewGenSlice(x, "half_ndims", int32_max, 1, rank - 1);
191192
auto x2neg = pattern::wrap_type<v1::Multiply>({x2 | varsplit->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});
192-
auto x1 = NewGenSlice(x, 0, "half_ndims", 1, 3);
193+
auto x1 = NewGenSlice(x, 0, "half_ndims", 1, rank - 1);
193194
auto x_rotate_half = pattern::wrap_type<v0::Concat>({x2neg, x1 | varsplit->output(0)}, {{"axis", -1}});
194195

195196
auto mul_cos = pattern::wrap_type<v1::Multiply>({x_or_cos1, x_or_cos2}, {{"auto_broadcast", "numpy"}});
@@ -220,6 +221,9 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
220221

221222
op::internal::RoPE::Config config;
222223
OutputVector new_args;
224+
if (rank == 3) {
225+
config.support_3d_rope = true;
226+
}
223227
config.rotary_ndims = 2ul * static_cast<size_t>(half_ndims.i());
224228

225229
new_args.push_back(pattern_map.at(x));

src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp

Lines changed: 128 additions & 9 deletions
Large diffs are not rendered by default.

src/plugins/intel_cpu/src/nodes/rope.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
129129
t_src = t_src.slice(3, m_config.slice_start, m_config.slice_stop);
130130
can_inplace = false;
131131
}
132+
if (t_src.m_rank == 3) {
133+
t_src = t_src.reshape({1, t_src.size(0), t_src.size(1), t_src.size(2)});
134+
t_dst = t_dst.reshape({1, t_dst.size(0), t_dst.size(1), t_dst.size(2)});
135+
}
136+
132137
if (m_config.input_trans0213) {
133138
t_src = t_src.permute({0, 2, 1, 3});
134139
can_inplace = false;
@@ -139,9 +144,13 @@ struct RoPE::RoPEExecutorRotateHalf : public RoPE::Executor {
139144

140145
if (t_cos.m_rank == 2) {
141146
t_cos = t_cos.reshape({1, 1, t_cos.size(0), t_cos.size(1)});
147+
} else if (t_cos.m_rank == 3) {
148+
t_cos = t_cos.reshape({1, t_cos.size(0), t_cos.size(1), t_cos.size(2)});
142149
}
143150
if (t_sin.m_rank == 2) {
144151
t_sin = t_sin.reshape({1, 1, t_sin.size(0), t_sin.size(1)});
152+
} else if (t_sin.m_rank == 3) {
153+
t_sin = t_sin.reshape({1, t_sin.size(0), t_sin.size(1), t_sin.size(2)});
145154
}
146155

147156
auto batch_size = t_src.size(0);

src/plugins/intel_cpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,14 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
7070
::testing::Values(ov::test::utils::DEVICE_CPU)),
7171
RoPETestChatGLM2DRoPEStridedSlice::getTestCaseName);
7272

73+
const std::vector<std::string> vit_param = {"VariadicSplit", "Slice", "StridedSlice"};
74+
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwenVL,
75+
RoPETestQwenVL,
76+
::testing::Combine(
77+
::testing::Values(ov::element::f32),
78+
::testing::Values(ov::test::utils::DEVICE_CPU),
79+
::testing::ValuesIn(vit_param)),
80+
RoPETestQwenVL::getTestCaseName);
81+
7382
} // namespace test
7483
} // namespace ov

src/plugins/intel_gpu/src/graph/impls/ocl_v2/rope_opt.cl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,6 @@ KERNEL(rope_opt)
344344
input_idx += SLICED_FROM_START;
345345
#endif
346346
#endif
347-
348347
uint cos_sin_p = p;
349348
#ifdef ENABLE_GATHER
350349
uint gather_b = b < INPUT3_BATCH_NUM ? b : 0;
@@ -378,7 +377,7 @@ uint cos_sin_p = p;
378377
uint cos_sin_h = 0;
379378
cos_sin_p = cos_sin_p < INPUT1_BATCH_NUM ? cos_sin_p : 0;
380379

381-
#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
380+
#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
382381
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, 0, 0, 0);
383382

384383
uint cos_idx = cos_sin_idx;
@@ -387,8 +386,20 @@ uint cos_sin_p = p;
387386
uint cos_idx = INPUT1_GET_INDEX(cos_sin_p, 0, 0, 0);
388387
uint sin_idx = INPUT2_GET_INDEX(cos_sin_p, 0, 0, 0);
389388
#endif
389+
#elif INPUT1_DIMS == 3 && INPUT2_DIMS == 3
390+
uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
391+
uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0;
392+
#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS
393+
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, 0, 0);
394+
395+
uint cos_idx = cos_sin_idx;
396+
uint sin_idx = cos_sin_idx;
390397
#else
391-
# error "rope_opt.cl - 4 or 2 of INPUT1_DIMS/INPUT2_DIMS is supported only"
398+
uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, 0, 0);
399+
uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, 0, 0);
400+
#endif
401+
#else
402+
# error "rope_opt.cl - 2, 3 or 4 of INPUT1_DIMS/INPUT2_DIMS is supported only"
392403
#endif
393404

394405
uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0);
@@ -415,6 +426,7 @@ uint cos_sin_p = p;
415426
*(OUTPUT_VEC_TYPE*)(output + output_idx + r) = out1;
416427
*(OUTPUT_VEC_TYPE*)(output + output_idx + HALF_ROTARY_NDIMS + r) = out2;
417428
#endif
429+
418430
}
419431
#endif
420432

src/plugins/intel_gpu/src/graph/impls/ocl_v2/rope_opt.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,10 @@ class RopeGenerator : public KernelGenerator {
152152
auto b = extract_channel(ChannelName::BATCH, out_l);
153153
auto f = extract_channel(ChannelName::FEATURE, out_l);
154154
auto y = extract_channel(ChannelName::Y, out_l);
155-
156155
wgs.global = {b, f, y * cfg.rotary_ndims / 2ul / vec_size};
156+
if (cfg.support_3d_rope) {
157+
wgs.global = {b, f, cfg.rotary_ndims / 2ul / vec_size};
158+
}
157159
}
158160

159161
wgs.local = ov::intel_gpu::get_optimal_lws(wgs.global, params.get_device_info(), in_l.format, out_l.format, dims_by_gws);

src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,33 @@ INSTANTIATE_TEST_SUITE_P(smoke_RoPETestFlux,
1414
::testing::Values(ov::test::utils::DEVICE_GPU)),
1515
RoPETestFlux::getTestCaseName);
1616

17+
class GPURoPETestQwenVL : public RoPETestQwenVL {
18+
protected:
19+
void SetUp() override {
20+
RoPETestQwenVL::SetUp();
21+
const auto& [element_type, _targetDevice, split_op_type] = this->GetParam();
22+
if (element_type == ov::element::f16) {
23+
abs_threshold = 0.015f;
24+
}
25+
}
26+
};
27+
28+
TEST_P(GPURoPETestQwenVL, CompareWithRefs) {
29+
SKIP_IF_CURRENT_TEST_IS_DISABLED();
30+
run();
31+
auto function = compiledModel.get_runtime_model();
32+
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
33+
};
34+
35+
const std::vector<std::string> vit_param = {"VariadicSplit", "Slice", "StridedSlice"};
36+
INSTANTIATE_TEST_SUITE_P(smoke_RoPEQwenVL,
37+
GPURoPETestQwenVL,
38+
::testing::Combine(
39+
::testing::Values(ov::element::f16, ov::element::f32),
40+
::testing::Values(ov::test::utils::DEVICE_GPU),
41+
::testing::ValuesIn(vit_param)),
42+
GPURoPETestQwenVL::getTestCaseName);
43+
1744
INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
1845
RoPETestChatGLMStridedSlice,
1946
::testing::Combine(

0 commit comments

Comments
 (0)