Skip to content

Patch fusion linear for bert and vit #786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 5, 2024
Merged

Conversation

jiqing-feng
Copy link
Collaborator

@jiqing-feng jiqing-feng commented Jun 27, 2024

Hi @echarlaix . I added 2 more supported patched models: Bert and Vit. I only use the linear fusion module to optimize these 2 models; it will bring at least a 10% speed-up on SPR with little code addition.

Would you please take a review on these changes? Thx!

@@ -60,7 +64,7 @@
logger = logging.getLogger(__name__)


_IPEX_SUPPORT_MODEL_TYPES = ("llama",)
_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add tests for both architectures ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Bert and Vit models are already in the tests, and they will automatically go into the patching path. Please give me more instructions, like what kind of tests you need. Thanks!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be nice to ensure compatibility for all these architectures by pushing a patched tiny random models on the hub and verify generated outputs are the same as the original model (in order to make sure we keep compatibility with already exported model)

Comment on lines -142 to -145
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(
f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports IndirectAccessKVCacheAttention, LinearAdd, RotaryEmbedding"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why moving ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -27,9 +27,20 @@

_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.3.0"

if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(f"Only ipex version > {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports patching")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reply to #786 (comment), I moved the logic at the top of this file.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we want to restrict support to IPEX >= v2.3.0 for all models (even the ones that are not patched) ?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng
Copy link
Collaborator Author

jiqing-feng commented Jul 4, 2024

Hi @echarlaix , I will get some patched models and upload them for testing.

Could you take a look at the failed tests? You will find that we could pass the test when running individually but fail when running together.

That's mainly because of the file system state, do you have any suggestions for it? I searched online and found that @pytest.fixture may works, but have no clue on how to use it. If you think it is acceptable, you can give me some instructions and I will fix it. If you have better idea, please let me know. Thx!

@jiqing-feng
Copy link
Collaborator Author

jiqing-feng commented Jul 4, 2024

Hi @echarlaix. I have changed the Ipex version check to make sure it will not have an impact on no patching path and also avoid duplicate codes. I also uploaded 2 models: patched_tiny_random_bert_for_question_answering and patched_tiny_random_vit_for_image_classification for test patching models. Could you please take a review? Thx!

BTW, if you need more pached models for all tasks, please let me know. Thx!

Copy link
Collaborator

@echarlaix echarlaix left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks a lot @jiqing-feng

@unittest.skipIf(is_ipex_version("<", "2.3.0"), reason="Only ipex version > 2.3.0 supports ipex model patching")
def test_patched_model(self):
ipex_model = IPEXModelForQuestionAnswering.from_pretrained(
"Jiqing/patched_tiny_random_bert_for_question_answering"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks a lot for adding this test !

@jiqing-feng
Copy link
Collaborator Author

Hi @echarlaix . For the failed tests, the real issue comes from enable_tpp. It will change some environment variables and trigger the traced model check here, so the tests will fail when running together.

I have fixed them by change another way to check if the traced model has been patched or not. Please take a review, thx!

@echarlaix echarlaix merged commit 69b5ea5 into huggingface:main Jul 5, 2024
11 of 16 checks passed
@jiqing-feng jiqing-feng deleted the fusion branch July 9, 2024 00:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants