
[Encoding] Implement matmul_k encoding propagation across reshapes. #20367


Merged: 4 commits into iree-org:main on Apr 14, 2025

Conversation

@hanhanW (Contributor) commented Mar 25, 2025

This revision ports the SDXL propagation effort to the main branch. Ideally, we should implement it using interfaces and data-flow analysis.

This is the first step of the propagation; we will incrementally enhance the encoding propagation pass.
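
As an illustration, here is a minimal sketch of the rewrite (the `matmul_k` attribute syntax and the shapes are approximated for this example, not taken verbatim from the PR): the pass hoists a `set_encoding` above a producing `tensor.collapse_shape` and remaps `k_dims` to the corresponding source dimension.

```mlir
// Before: the encoding is set on the collapsed tensor; K is dim 1.
%collapsed = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<2x16x32xf32> into tensor<32x32xf32>
%0 = iree_encoding.set_encoding %collapsed : tensor<32x32xf32>
    -> tensor<32x32xf32, #iree_encoding.matmul_k<k_dims = [1]>>

// After: the encoding is set before the reshape. K becomes dim 2,
// since the collapse only folded the two leading, non-K dims.
%1 = iree_encoding.set_encoding %src : tensor<2x16x32xf32>
    -> tensor<2x16x32xf32, #iree_encoding.matmul_k<k_dims = [2]>>
%2 = tensor.collapse_shape %1 [[0, 1], [2]]
    : tensor<2x16x32xf32, #iree_encoding.matmul_k<k_dims = [2]>>
    into tensor<32x32xf32, #iree_encoding.matmul_k<k_dims = [1]>>
```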

Co-authored-by: MaheshRavishankar mahesh.ravishankar@gmail.com

@pashu123 (Contributor) left a comment:

LGTM!

@hanhanW (Contributor, author) commented Mar 27, 2025

@MaheshRavishankar implemented most of the logic, and I updated a few implementations based on my own review. It would be weird if I landed it with Mahesh's approval.

So @pashu123 @Max191 @IanWood1, can one of you help review?

@Max191 (Contributor) left a comment:

This propagation seems to be losing information. It works for the padding-encoding case when we don't collapse the K dimension, but it would break with the current data-tiling encodings.

The reason it breaks is that we overwrite the original indexing maps of the encoding, so the access pattern of the original matmul is lost. It will already break the data-tiling encodings, because we don't support materialization for multiple M/N/K dimensions.
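
To make the concern concrete, here is a hypothetical sketch (the map aliases are mine, and the maps stand in for the user indexing maps a data-tiling encoding records; none of this is copied from the PR):

```mlir
// The LHS access pattern of the original matmul, as recorded in the
// data-tiling encoding:
#orig_lhs = affine_map<(m, n, k) -> (m, k)>
// Naively propagating across a reshape that expands k into (k0, k1)
// overwrites it with a map that has two K iterators:
#new_lhs = affine_map<(m, n, k0, k1) -> (m, k0, k1)>
// The single-K access pattern is lost, and current materialization
// does not support multiple M/N/K dimensions.
```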

@hanhanW (Contributor, author) commented Apr 3, 2025

https://github.com/hanhanW/iree/tree/matmul-k-encoding has the prototype for the matmul_k encoding. It compiles Llama, like my other branch does.

@Max191 can you review if hanhanW@05a5efd looks okay to you?

EDIT: it is a POC; I'll update the code when I send them out for review.

@Max191 (Contributor) commented Apr 3, 2025

> https://github.com/hanhanW/iree/tree/matmul-k-encoding has the prototype for the matmul_k encoding. It compiles Llama, like my other branch does.
>
> @Max191 can you review if hanhanW@05a5efd looks okay to you?
>
> EDIT: it is a POC; I'll update the code when I send them out for review.

I added a comment on your commit: hanhanW@05a5efd#diff-88c539c7f0005642c0bd605e95e311aaad906cc3c8cb3478561490167afe8f04R47

I think the propagation still has a slight issue: it loses the original k_dims set. I gave an example of a potential problem case. For the padding encodings, I think this type of propagation should be possible, but I don't think the materialization patterns can handle it yet. For now, I think it is better to avoid the propagation when the K dims are reshaped.
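
A hypothetical instance of the case to bail on (shapes and attribute syntax approximated, not taken from the commit under review): the collapse folds the K dimension together with another dimension, so there is no single source dimension that `k_dims` can be remapped to.

```mlir
// The collapsed K dim (dim 1) maps to source dims {1, 2}; hoisting the
// set_encoding would need k_dims = [1, 2], which materialization cannot
// handle, so the pass should leave this IR unchanged.
%collapsed = tensor.collapse_shape %src [[0], [1, 2]]
    : tensor<32x4x8xf32> into tensor<32x32xf32>
%0 = iree_encoding.set_encoding %collapsed : tensor<32x32xf32>
    -> tensor<32x32xf32, #iree_encoding.matmul_k<k_dims = [1]>>
```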

@hanhanW (Contributor, author) commented Apr 3, 2025

> For now, I think it is better to avoid the propagation when the K dims are reshaped.

I think it is reasonable, because the pad resolver drops the encodings when there are multiple reduction dimensions. I should update the matmul_k encoding attribute description. Note that how propagation works should not be driven by resolvers; it should be driven by the encoding semantics, which could lead to different propagation results, IMO.

@Max191 (Contributor) commented Apr 3, 2025

> Note that how propagation works should not be driven by resolvers; it should be driven by the encoding semantics, which could lead to different propagation results, IMO.

I definitely agree. Propagation shouldn't care what happens in materialization; it should be based only on the definition of the encoding. The definition of the encoding should reflect how it can be materialized, though. In this case, having multiple K dims should be illegal in the encoding definition, because we can't materialize it.

I think we are on the same page; we just need more concrete definitions of our encodings (e.g., multiple K dimensions are not allowed). This would ultimately be defined by the implementations of the propagation interface functions, but for now we can just add it to the encoding docs, since the interface functions don't exist yet.

@hanhanW (Contributor, author) commented Apr 3, 2025

> > Note that how propagation works should not be driven by resolvers; it should be driven by the encoding semantics, which could lead to different propagation results, IMO.
>
> I definitely agree. Propagation shouldn't care what happens in materialization; it should be based only on the definition of the encoding. The definition of the encoding should reflect how it can be materialized, though. In this case, having multiple K dims should be illegal in the encoding definition, because we can't materialize it.
>
> I think we are on the same page; we just need more concrete definitions of our encodings (e.g., multiple K dimensions are not allowed). This would ultimately be defined by the implementations of the propagation interface functions, but for now we can just add it to the encoding docs, since the interface functions don't exist yet.

Cool, I'll slice out the patches from my prototype and send them out for review. Thanks for taking a look!

Note: it only works for the matmul_k encoding atm; we should investigate the interface-based approach.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
@hanhanW force-pushed the basic-encoding-propagation branch from af07735 to 9d9fbe1 on April 10, 2025 at 23:53
Signed-off-by: hanhanW <hanhan0912@gmail.com>
@hanhanW force-pushed the basic-encoding-propagation branch from 9d9fbe1 to 685ddcb on April 10, 2025 at 23:55
@hanhanW marked this pull request as ready for review on April 10, 2025 at 23:56
@hanhanW requested a review from Max191 on April 10, 2025 at 23:56
@Max191 (Contributor) left a comment:

Overall looks good! Just some nits

@@ -0,0 +1,33 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-propagate-encodings))" --split-input-file %s | FileCheck %s
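
For context, a hypothetical test in the spirit of this new file (the function name, shapes, and encoding syntax are illustrative, not copied from the PR):

```mlir
#encoding = #iree_encoding.matmul_k<k_dims = [1]>
util.func public @propagate_up_through_collapse(%src: tensor<2x16x32xf32>) -> tensor<32x32xf32, #encoding> {
  %collapsed = tensor.collapse_shape %src [[0, 1], [2]]
      : tensor<2x16x32xf32> into tensor<32x32xf32>
  %0 = iree_encoding.set_encoding %collapsed
      : tensor<32x32xf32> -> tensor<32x32xf32, #encoding>
  util.return %0 : tensor<32x32xf32, #encoding>
}
// After the pass, the set_encoding should come before the reshape:
// CHECK-LABEL: @propagate_up_through_collapse
// CHECK:         %[[ENC:.+]] = iree_encoding.set_encoding
// CHECK:         %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ENC]]
// CHECK:         util.return %[[COLLAPSED]]
```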
@Max191 (Contributor):

Can we add a test that shows hoisting does not happen when k dims are collapsed?

@hanhanW (Contributor, author):

What do you mean? This pass does not know anything about hoisting.

@Max191 (Contributor):

Sorry, I mean propagation. Getting things mixed up today :p

@hanhanW (Contributor, author):

I think we have the test case; see the second test. The set_encoding is in between two collapse_shape ops.

@Max191 (Contributor) commented Apr 11, 2025

Could you also update the description to mention that we only propagate matmul_k encodings?

@hanhanW changed the title from "[Encoding] Implement basic encoding propagation pass across reshapes." to "[Encoding] Implement matmul_k encoding propagation pass across reshapes." on Apr 11, 2025
@hanhanW changed the title from "[Encoding] Implement matmul_k encoding propagation pass across reshapes." to "[Encoding] Implement basic encoding propagation across reshapes." on Apr 11, 2025
@hanhanW changed the title from "[Encoding] Implement basic encoding propagation across reshapes." to "[Encoding] Implement matmul_k encoding propagation across reshapes." on Apr 11, 2025
@hanhanW requested a review from Max191 on April 11, 2025 at 20:39
@Max191 (Contributor) left a comment:

A test that shows propagation being blocked by collapsing k_dims would be good to have, and then LGTM!

@hanhanW merged commit 7724306 into iree-org:main on Apr 14, 2025
39 of 42 checks passed
@hanhanW deleted the basic-encoding-propagation branch on April 14, 2025 at 17:08