Support cuDNN frontend scaled dot product attention for FP8. Part 2 (backward) #15331
Conversation
I believe that we should already support calls to cuDNN's scaled dot product attention through JAX directly, is that correct?
Do you refer to jax-ml/jax#22670? If so, JAX's SDPA API still calls the cuDNN SDPA from XLA behind the scenes. Plus, the forward-pass PR is already merged.
You're right, that seems reasonable, thanks for the clarification.
XlaBuilder builder(TestName());
std::string hlo_string_ref =
    R"(
HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}
Can you explain the purpose of all the HLO leading up to the custom call? Does it actually provide value? If not, should we just compare the upcasted custom call against the reference one?
The HLO leading to the custom-call performs a "cast-to-representable" operation, which adjusts the input to fit within the range that the FP8 data type can represent. Therefore, it's also necessary for the reference implementation to include this step in order to maintain numerical equivalence.
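To make the discussion concrete, here is a minimal sketch of what such a cast-to-representable step can look like in HLO, assuming an f8e4m3fn target (representable magnitude up to about 448); the shape, names, and constants are illustrative and not taken from the actual test:

HloModule cast_to_representable_sketch

ENTRY main {
  // Hypothetical bf16 input; the real test applies this to the Q/K/V operands.
  x = bf16[4,16,4,16] parameter(0)
  // Clamp to the largest magnitude f8e4m3fn can represent so that
  // out-of-range values saturate instead of overflowing on conversion.
  lo = bf16[] constant(-448)
  hi = bf16[] constant(448)
  lo_b = bf16[4,16,4,16] broadcast(lo), dimensions={}
  hi_b = bf16[4,16,4,16] broadcast(hi), dimensions={}
  clamped = bf16[4,16,4,16] clamp(lo_b, x, hi_b)
  // The reference module applies the same clamp, so both modules see
  // numerically equivalent inputs before the custom call or reference math.
  ROOT x_f8 = f8e4m3fn[4,16,4,16] convert(clamped)
}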
ROOT out = bf16[4,4,16,16]{3,1,2,0} convert(get-tuple-element.5.0)
} // main.106
)"; // NOLINT
EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref,
Do we need to run HLO passes here? If not, let's disable them in this call (/*run_hlo_passes=*/false).
Yes. The input HLOs are pre-optimization ones.
The purpose of the test seems to be to verify that the emitted custom call is correct, right? Presumably, we are not trying to test the end-to-end compilation pipeline, or are we testing anything particularly useful here by running it?
Let's make sure to use already-optimized HLO here and in the other cases where we try to ensure correctness of the custom call, and only use before-optimization HLO where it is necessary.
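For reference, a hedged sketch of what the suggested call could look like, assuming the RunAndCompareTwoModules overload used in this test accepts an ErrorSpec followed by a run_hlo_passes flag; the error bounds are placeholders, not the values actually used:

// Compare the already-optimized module against the reference without
// rerunning the compilation pipeline.
EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref,
                                    ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2},
                                    /*run_hlo_passes=*/false));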
GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0.";
}
XlaBuilder builder(TestName());
// generate padding mask in cuDNN directly
Are we actually doing any pattern matching here?
No pattern matching here.
// XLA pattern match does not support pattern matching padding mask
// so directly lower to custom call instead for reference
std::string hlo_string_ref = R"(
HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0})->bf16[1,1,256,128]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}
Same questions as above for the HLOs.
Gentle ping @wenscarl :)
@wenscarl thank you for thoughtfully addressing the comments! Love the new tests, they look great!
I'd prefer if we ran without HLO passes here, but I suppose it doesn't hurt the test that much, and leaving it as is improves readability, since the clipping logic can be invoked several times without being duplicated.
Support cuDNN frontend scaled dot product attention for FP8. Part 2 (backward)
Imported from GitHub PR openxla/xla#15331
As the 2nd part of #15092. NOTE: this feature relies on cudnn-frontend v1.6.1, which is not in XLA yet.
Copybara import of the project:
-- 06db3c8349ca017440a2b9c4f4a7c41e557f03af by shuw <shuw@nvidia.com>: Scaled dot product attention implementation by cudnn.
-- 937b0e26ebcf5d48fce15fed8573d7c58b47e689 by shuw <shuw@nvidia.com>: Improve after review 1
-- 398b2ba2cef82f701a0ddecb7553423d92b1f902 by shuw <shuw@nvidia.com>: clang-format
-- 08257899ea899f66799bc701d81aad6ea94af6a0 by Shu Wang <shuw@nvidia.com>: fix typo.
-- d0ae3cf52b7483c254137d8300f4c00aa963a7c6 by shuw <shuw@nvidia.com>: Refactor test
Merging this change closes #15331
PiperOrigin-RevId: 684062495
As the 2nd part of #15092.
NOTE: this feature relies on cudnn-frontend v1.6.1, which is not in XLA yet.