MHLO operation regions need to use scalar arguments #22
Comments
Not all reductions are scalars, though. The zero-rank case is just the degenerate case; take, for example (from the test suite):
Yes, I did notice that (and actually didn't know it existed). Specifically, such an operation cannot be lowered to Linalg directly (today). So maybe all that is needed is an MHLO -> MHLO transform, run before lowering to Linalg, that converts the zero-rank tensor case to scalars and converts …
Is this actually a core issue about how we want to model `mhlo.reduce` going forward, or is it more about the mechanics of lowering to Linalg? It feels like lowering `mhlo.add` to `arith.addf` for the payload can be handled by being somewhat more sophisticated in our use of the dialect conversion infrastructure (setting up legality properly, doing the conversion in two phases, or something similar).
I don't think you need the dialect conversion framework to handle this. I think the biggest issue is which operations are supported in the MHLO reduce region. I could easily see non-elementwise operations in the reduction region preventing lowering to Linalg.
Not a stakeholder in MHLO per se, but for me …
So if `mhlo.reduce` does support tensor operations in the payload, further MHLO -> MHLO transformations would be needed to get it into a state where it can be lowered to Linalg (for example).
MHLO operations that have regions use a zero-rank tensor to represent what are really scalar values. For example
There are a couple of issues here.

1. `mhlo.reduce` here contains an `mhlo.add`. The way one would lower `mhlo.add` to, say, the `linalg` dialect is very different depending on whether the operation is inside an `mhlo` op's region or at the top level. This seems to conflate two different uses of the `mhlo.add` operation. It would be much easier to handle if `mhlo.add` were only used at the top level and a different operation were used within the regions of `mhlo` operations.
2. The region of the `mhlo` operation in this case is a sequence of computations on values that are really scalars. Using zero-rank tensors introduces additional complexity when translating to the `Linalg` dialect, since the region arguments need a type conversion from zero-rank tensors to scalars. Having scalars before the conversion would remove much of that complexity.
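To make the type mismatch concrete, here is a hypothetical sketch (not the example from the original issue, which did not survive): a row sum written as an `mhlo.reduce` whose region works on `tensor<f32>`, next to roughly the `linalg.generic` form a lowering has to produce, whose payload works on `f32`. The function names and shapes are made up for illustration, and the op spellings (`tensor.empty`, the `dimensions` attribute form) follow upstream MLIR/MHLO syntax from around this time and may differ between versions. The Linalg half is an approximation, not the exact output of any particular pass.

```mlir
// Hypothetical input: sum each row of a 4x8 tensor.
func.func @row_sum_mhlo(%input: tensor<4x8xf32>, %init: tensor<f32>) -> tensor<4xf32> {
  %0 = "mhlo.reduce"(%input, %init) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
    // Region arguments and result are zero-rank tensors, although each
    // only ever holds a single scalar value.
    %sum = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// Approximate Linalg equivalent: the payload block takes f32 scalars.
func.func @row_sum_linalg(%input: tensor<4x8xf32>, %init: tensor<f32>) -> tensor<4xf32> {
  // The zero-rank init value has to be unpacked into a scalar first.
  %init_scalar = tensor.extract %init[] : tensor<f32>
  %empty = tensor.empty() : tensor<4xf32>
  %acc = linalg.fill ins(%init_scalar : f32) outs(%empty : tensor<4xf32>) -> tensor<4xf32>
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%input : tensor<4x8xf32>) outs(%acc : tensor<4xf32>) {
  ^bb0(%in: f32, %out: f32):
    // mhlo.add on tensor<f32> becomes arith.addf on f32 here.
    %sum = arith.addf %in, %out : f32
    linalg.yield %sum : f32
  } -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
```

The two places where zero-rank tensors must be unpacked are the `tensor.extract` of the init value and the payload block, whose arguments change from `tensor<f32>` to `f32`. That second change is what forces the region's `mhlo.add` to be rewritten differently from a top-level `mhlo.add`.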