-
Notifications
You must be signed in to change notification settings - Fork 130
Patch fusion linear for bert and vit #786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -60,7 +64,7 @@ | |||
logger = logging.getLogger(__name__) | |||
|
|||
|
|||
_IPEX_SUPPORT_MODEL_TYPES = ("llama",) | |||
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add tests for both architectures ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Bert and Vit models are already in the tests, and they will automatically go into the patching path. Please give me more instructions, like what kind of tests you need. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be nice to ensure compatibility for all these architectures by pushing a patched tiny random models on the hub and verify generated outputs are the same as the original model (in order to make sure we keep compatibility with already exported model)
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): | ||
raise ImportError( | ||
f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why moving ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #786 (comment)
@@ -27,9 +27,20 @@ | |||
|
|||
_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0" | |||
|
|||
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): | |||
raise ImportError(f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports patching") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reply to #786 (comment), I moved the logic at the top of this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure we want to restrict support to IPEX >= v2.3.0 for all models (even the ones that are not patched) ?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hi @echarlaix , I will get some patched models and upload them for testing. Could you take a look at the failed tests? You will find that we could pass the test when running individually but fail when running together. That's mainly because of the file system state, do you have any suggestions for it? I searched online and found that |
Hi @echarlaix. I have changed the Ipex version check to make sure it will not have an impact on no patching path and also avoid duplicate codes. I also uploaded 2 models: patched_tiny_random_bert_for_question_answering and patched_tiny_random_vit_for_image_classification for test patching models. Could you please take a review? Thx! BTW, if you need more pached models for all tasks, please let me know. Thx! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks a lot @jiqing-feng
@unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching") | ||
def test_patched_model(self): | ||
ipex_model = IPEXModelForQuestionAnswering.from_pretrained( | ||
"Jiqing/patched_tiny_random_bert_for_question_answering" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks a lot for adding this test !
Hi @echarlaix . For the failed tests, the real issue comes from enable_tpp. It will change some environment variables and trigger the traced model check here, so the tests will fail when running together. I have fixed them by change another way to check if the traced model has been patched or not. Please take a review, thx! |
Hi @echarlaix . I added 2 more supported patched models: Bert and Vit. I only use the linear fusion module to optimize these 2 models; it will bring at least a 10% speed-up on SPR with little code addition.
Would you please take a review on these changes? Thx!