[Encoding] Implement matmul_k encoding propagation across reshapes. #20367
Conversation
LGTM!
@MaheshRavishankar implemented most of the logic, and I updated a few implementations based on my own review. It'd be weird if I landed it with only Mahesh's approval.
This propagation seems to be losing information. It works for the padding encoding case when we don't collapse the K dimension, but this would break with current data tiling encodings.
The reason it breaks is that we are overwriting the original indexing maps of the encoding, so the access pattern of the original matmul is lost. It will already cause things to break in the data tiling encodings because we don't support materialization for multiple M/N/K dimensions.
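To make the concern concrete, here is a rough IR sketch (hypothetical shapes; the matmul_k attribute syntax is an assumption based on this PR, and the IR is illustrative rather than verifier-checked):

```mlir
// A collapse_shape folds two source dims (d1, d2) into the single K dim
// (d1) of the encoded tensor.
#enc = #iree_encoding.matmul_k<k_dims = [1]>
%collapsed = tensor.collapse_shape %src [[0], [1, 2]]
    : tensor<128x64x32xf32> into tensor<128x2048xf32>
%0 = iree_encoding.set_encoding %collapsed
    : tensor<128x2048xf32> -> tensor<128x2048xf32, #enc>

// Naively hoisting set_encoding above the reshape forces the encoding to
// describe two K dims on the expanded tensor, so the original single-K
// access pattern of the matmul is lost, and materialization does not
// support multiple K dims:
#enc_bad = #iree_encoding.matmul_k<k_dims = [1, 2]>
%1 = iree_encoding.set_encoding %src
    : tensor<128x64x32xf32> -> tensor<128x64x32xf32, #enc_bad>
```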
https://github.com/hanhanW/iree/tree/matmul-k-encoding has the prototype for the matmul_k encoding. It compiles llama like my other branch. @Max191 can you review if hanhanW@05a5efd looks okay to you? EDIT: it is a POC; I'll update the code when I send the patches out for review.
I added a comment on your commit: hanhanW@05a5efd#diff-88c539c7f0005642c0bd605e95e311aaad906cc3c8cb3478561490167afe8f04R47 I think the propagation still has some slight issues in that it loses the original k_dims set; I gave an example of a potential problem case. I think for the padding encodings it should be possible to do this type of propagation, but I don't think the materialization patterns can handle it yet. For now, I think it is better to avoid the propagation when the K dims are reshaped.
I think that is reasonable, because the pad resolver drops the encodings when there are multiple reduction dimensions. I should update the matmul_k encoding attribute description. Note that how propagation works should not be driven by resolvers; it should be driven by the encoding semantics, which could lead to different propagation results, IMO.
I definitely agree. Propagation shouldn't care what happens in materialization, and it should be based only on the definition of the encoding. The definition of the encoding should reflect how it can be materialized, though. In this case, having multiple K dims should be illegal in the encoding definition because we can't materialize it. I think we are on the same page, and we just need more concrete definitions of our encodings (i.e., not allowed to have multiple K dimensions, etc.). This would ultimately be defined by the implementations of the propagation interface functions, but for now we can just add it to the encoding docs, since the interface functions don't exist yet.
Cool, I'll slice out the patches from my prototype and send them out for review. Thanks for taking a look!
Note: it only works for the matmul_k encoding at the moment; we should invest in an interface-based approach.
Signed-off-by: hanhanW <hanhan0912@gmail.com>
Force-pushed from af07735 to 9d9fbe1
Signed-off-by: hanhanW <hanhan0912@gmail.com>
Force-pushed from 9d9fbe1 to 685ddcb
Overall looks good! Just some nits
@@ -0,0 +1,33 @@
// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-propagate-encodings))" --split-input-file %s | FileCheck %s
Can we add a test that shows hoisting does not happen when k dims are collapsed?
What do you mean? This pass does not know anything about hoisting?
Sorry, I mean propagation. Getting things mixed up today :p
I think we have that test case, see the second test: the set_encoding sits between two collapse_shape ops.
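For reference, a sketch of what such tests could look like (illustrative IR and CHECK lines, not copied from the patch; the attribute syntax is an assumption):

```mlir
#encoding = #iree_encoding.matmul_k<k_dims = [1]>
// Propagation expected: the collapse only folds the two M dims, leaving the
// K dim intact, so set_encoding should move above the reshape.
util.func public @propagate_when_k_intact(%src: tensor<2x4096x640xf16>) -> tensor<8192x640xf16, #encoding> {
  %collapsed = tensor.collapse_shape %src [[0, 1], [2]]
      : tensor<2x4096x640xf16> into tensor<8192x640xf16>
  %0 = iree_encoding.set_encoding %collapsed
      : tensor<8192x640xf16> -> tensor<8192x640xf16, #encoding>
  util.return %0 : tensor<8192x640xf16, #encoding>
}
// CHECK-LABEL: @propagate_when_k_intact
//       CHECK:   iree_encoding.set_encoding
//       CHECK:   tensor.collapse_shape

// -----

#encoding = #iree_encoding.matmul_k<k_dims = [1]>
// Propagation expected to be blocked: the collapse folds two source dims
// into the K dim, so set_encoding must stay below the reshape.
util.func public @no_propagation_when_k_collapsed(%src: tensor<128x64x32xf16>) -> tensor<128x2048xf16, #encoding> {
  %collapsed = tensor.collapse_shape %src [[0], [1, 2]]
      : tensor<128x64x32xf16> into tensor<128x2048xf16>
  %0 = iree_encoding.set_encoding %collapsed
      : tensor<128x2048xf16> -> tensor<128x2048xf16, #encoding>
  util.return %0 : tensor<128x2048xf16, #encoding>
}
// CHECK-LABEL: @no_propagation_when_k_collapsed
//       CHECK:   tensor.collapse_shape
//       CHECK:   iree_encoding.set_encoding
```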
Could you also update the description to mention that we only propagate matmul_k encodings?
Signed-off-by: hanhanW <hanhan0912@gmail.com>
A test that shows propagation being blocked by collapsing k_dims would be good to have, and then LGTM!
This revision ports the SDXL encoding propagation effort to the main branch. Ideally, we should implement it using interfaces and data-flow analysis.
This is a first step; we will incrementally enhance the encoding propagation pass.
Co-authored-by: MaheshRavishankar mahesh.ravishankar@gmail.com
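As a rough illustration of the rewrite this pass performs (hypothetical IR; the exact k_dims remapping across the reassociation is an assumption based on the review discussion above):

```mlir
// Before: set_encoding sits below a collapse_shape that only folds M dims.
#enc = #iree_encoding.matmul_k<k_dims = [1]>
%c = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<4x256x512xf32> into tensor<1024x512xf32>
%0 = iree_encoding.set_encoding %c
    : tensor<1024x512xf32> -> tensor<1024x512xf32, #enc>

// After: the encoding is propagated above the reshape, with the K dim index
// remapped to the expanded rank (d1 of the collapsed tensor becomes d2).
#enc_expanded = #iree_encoding.matmul_k<k_dims = [2]>
%1 = iree_encoding.set_encoding %src
    : tensor<4x256x512xf32> -> tensor<4x256x512xf32, #enc_expanded>
%2 = tensor.collapse_shape %1 [[0, 1], [2]]
    : tensor<4x256x512xf32, #enc_expanded> into tensor<1024x512xf32, #enc>
```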