[CPU] Fuse SDPA and Concat as early as possible #28189
Conversation
LGTM!
```cpp
CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
// TODO: SDPAFuseTransposeReshape may cause regressions in icx.
// CPU_REGISTER_PASS_X64(manager, ov::intel_cpu::SDPAFuseTransposeReshape);
```
Any details on that?
I recall that pass was added for the Whisper model, so how do we guarantee we won't introduce a regression for Whisper?
Currently no LLM models hit SDPAFuseTransposeReshape, including Whisper, so there is no regression; this was checked against the WW52 model set. The pattern only applies to a customized model that has not been upstreamed yet. Per @maxnick's comment, there is a plan to cover the pattern with Snippets. So should we leave this case to Snippets, or cover it with this transformation?
This customized pattern is expected to become the default: huggingface/optimum-intel#1078. The export side has actually been waiting on runtime optimizations, and SDPAFuseTransposeReshape is part of them: https://jira.devtools.intel.com/browse/CVS-153616.
Snippets are responsible for SDPA patterns without states (e.g. regular transformers or diffusers); SDPA with states (LLMs, Whisper) should be processed via the custom SDPA node as of now.
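For context, a minimal sketch of the "SDPA with states" subgraph shape that StatefulSDPAFusion targets (ReadValue → Concat → SDPA → Assign), assuming the public opset13 API; the shapes, variable names, and layout are illustrative, not the actual Whisper graph:

```cpp
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset13.hpp>

std::shared_ptr<ov::Model> make_stateful_sdpa_subgraph() {
    using namespace ov;
    using namespace ov::opset13;

    // Query and the current step's key/value projections: [batch, heads, seq, head_size].
    auto q     = std::make_shared<Parameter>(element::f32, PartialShape{-1, 8, -1, 64});
    auto k_cur = std::make_shared<Parameter>(element::f32, PartialShape{-1, 8, -1, 64});
    auto v_cur = std::make_shared<Parameter>(element::f32, PartialShape{-1, 8, -1, 64});

    // KV cache kept across infer requests via Variable + ReadValue/Assign pairs.
    auto var_k = std::make_shared<op::util::Variable>(
        op::util::VariableInfo{PartialShape{-1, 8, -1, 64}, element::f32, "past_k"});
    auto var_v = std::make_shared<op::util::Variable>(
        op::util::VariableInfo{PartialShape{-1, 8, -1, 64}, element::f32, "past_v"});
    auto past_k = std::make_shared<ReadValue>(var_k);
    auto past_v = std::make_shared<ReadValue>(var_v);

    // New tokens are appended to the cache along the sequence axis (2); this
    // Concat is exactly what the fusion folds into the SDPA node.
    auto k = std::make_shared<Concat>(OutputVector{past_k, k_cur}, 2);
    auto v = std::make_shared<Concat>(OutputVector{past_v, v_cur}, 2);

    auto sdpa = std::make_shared<ScaledDotProductAttention>(q, k, v, /*causal=*/true);

    auto assign_k = std::make_shared<Assign>(k, var_k);
    auto assign_v = std::make_shared<Assign>(v, var_v);

    return std::make_shared<Model>(ResultVector{std::make_shared<Result>(sdpa)},
                                   SinkVector{assign_k, assign_v},
                                   ParameterVector{q, k_cur, v_cur});
}
```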
SDPAFuseTransposeReshape only matches a plain SDPA, i.e. when fusing to SDPA+Concat has failed (on success the node type becomes ScaledDotProductAttentionWithKVCache):

Lines 54 to 55 in 548786a:

```cpp
auto sdpa_node =
    wrap_type<op::v13::ScaledDotProductAttention>({q_transpose_node, k_transpose_node, v_transpose_node});
```
I suppose huggingface/optimum-intel#1078 will change the current Whisper model into a stateful model, so we may need to change SDPAFuseTransposeReshape to support ScaledDotProductAttentionWithKVCache, or add that functionality into StatefulSDPAFusion. If my understanding is correct, I will create a ticket to track this.
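A hypothetical sketch of that widening, assuming wrap_type's multi-type overload; note that ScaledDotProductAttentionWithKVCache takes extra state inputs, so a real pattern would need its own input list rather than the three transposes shown here:

```cpp
// Hypothetical: match both the stock SDPA op and the CPU plugin's
// ScaledDotProductAttentionWithKVCache with one pattern node.
auto sdpa_node = ov::pass::pattern::wrap_type<ov::op::v13::ScaledDotProductAttention,
                                              ov::intel_cpu::ScaledDotProductAttentionWithKVCache>(
    {q_transpose_node, k_transpose_node, v_transpose_node});
```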
@xipingyan Could you please clarify how SDPAFuseTransposeReshape worked for the Whisper model? Has StatefulSDPAFusion failed for this model before?
Hi @dmitry-gorokhov, @luo-cheng2021,
Actually, the Whisper model has two kinds of SDPA ops: the first is mapped to ScaledDotProductAttentionWithKVCache, the second to ScaledDotProductAttention. In my case, SDPAFuseTransposeReshape only works for the second one, so @luo-cheng2021's suggestion to change SDPAFuseTransposeReshape to support ScaledDotProductAttentionWithKVCache will not work.
@dmitry-gorokhov, I just aligned with @luo-cheng2021: we need to double-check whether current master already supports Snippets with dynamic shapes.
If yes, the SDPAFuseTransposeReshape pattern can be removed.
If no, we could merge the Reshape+Transpose before SDPA into the init subgraph of ReadValue as a temporary solution; Snippets with dynamic shapes remain the final target.
Discussed offline.
Agreed to update the applicability conditions for SDPAFuseTransposeReshape: it should check that the Reshape op comes after ReadValue. That will limit the pass's impact to the Whisper model only and avoid a negative perf impact on SD and others.
Once Snippets support such patterns with dynamic shapes, SDPAFuseTransposeReshape can be removed entirely.
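A minimal sketch of the agreed applicability check, assuming it lives in the matcher callback of SDPAFuseTransposeReshape; the helper name is illustrative:

```cpp
// Hypothetical helper: only apply the fusion when the reshaped input comes
// straight from a ReadValue (the KV-cache state), i.e. the Whisper-like case.
auto reshape_follows_read_value = [](const std::shared_ptr<ov::Node>& reshape) -> bool {
    if (!ov::is_type<ov::op::v1::Reshape>(reshape))
        return false;
    const auto producer = reshape->get_input_node_shared_ptr(0);
    return ov::is_type<ov::op::util::ReadValueBase>(producer);
};
```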
Done. No models in WW52 hit SDPAFuseTransposeReshape; only the customized Whisper model hits it.
### Details:
- Move StatefulSDPAFusion before CommonOptimizations
- ...

### Tickets:
- [158738](https://jira.devtools.intel.com/browse/CVS-158738)
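A rough sketch of the reordering described above, assuming the usual ov::pass::Manager flow in the CPU plugin's transformation_pipeline.cpp; the surrounding passes are illustrative:

```cpp
ov::pass::Manager manager("Plugin:CPU");

// Run the stateful SDPA+Concat fusion before CommonOptimizations so the
// KV-cache Concat pattern is fused before other common passes can rewrite it.
CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations);
```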