
Support cuDNN frontend scaled dot product attention for FP8. Part- 2(backward) #15331

Closed
Wants to merge 5 commits.

Conversation

wenscarl
Contributor

@wenscarl wenscarl commented Jul 25, 2024

As the 2nd part of #15092.
NOTE: this feature relies on cudnn-frontend v1.6.1 which is not in XLA yet.

@wenscarl wenscarl changed the title Support cuDNN frontend scaled dot product attention for FP8. Part- 2(backward) [draft]Support cuDNN frontend scaled dot product attention for FP8. Part- 2(backward) Jul 25, 2024
@wenscarl wenscarl marked this pull request as draft July 25, 2024 19:38
@wenscarl wenscarl marked this pull request as ready for review August 16, 2024 03:58
@wenscarl wenscarl changed the title [draft]Support cuDNN frontend scaled dot product attention for FP8. Part- 2(backward) Support cuDNN frontend scaled dot product attention for FP8. Part- 2(backward) Aug 16, 2024
@bchetioui
Member

Given that cuDNN's FlashAttention is meant to remain behind a flag (as discussed previously), I wonder whether it still makes sense to integrate this within XLA.

I believe that we should already support calls to scaled dot product attention through JAX directly, is that correct?

@wenscarl
Contributor Author

> Given that cuDNN's FlashAttention is meant to remain behind a flag (as discussed previously), I wonder whether it still makes sense to integrate this within XLA.
>
> I believe that we should already support calls to scaled dot product attention through JAX directly, is that correct?

Do you refer to jax-ml/jax#22670? If so, JAX's SDPA API still calls cuDNN SDPA from XLA behind the scenes. Also, the forward-pass PR is already merged.

@bchetioui
Member

> Do you refer to jax-ml/jax#22670? If so, jax's sdpa api still calls cudnn sdpa from XLA behind the scene. Plus, the forward pass PR is already merged.

You're right, that seems reasonable, thanks for the clarification.

xla/service/gpu/cublas_cudnn.cc (outdated; resolved)
XlaBuilder builder(TestName());
std::string hlo_string_ref =
R"(
HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}
Member

Can you explain what is the purpose of all the HLO leading to the custom call? Does it actually provide value? If not, should we just compare the upcasted custom call vs the reference one?

Contributor Author

The HLO leading to the custom-call performs a "cast-to-representable" operation, which adjusts the input to fit within the range that the FP8 data type can represent. Therefore, it's also necessary for the reference implementation to include this step in order to maintain numerical equivalence.
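For illustration, the clipping half of that cast-to-representable step can be sketched in plain Python. This is a sketch, not the actual HLO: the helper name `cast_to_representable` is hypothetical, and the 448.0 bound assumes the FP8 E4M3 ("e4m3fn") format.

```python
# Illustrative sketch only, not the XLA HLO: clamp higher-precision inputs
# to the largest finite FP8 E4M3 (e4m3fn) magnitude, 448.0, so that both the
# FP8 path and the reference path see the same representable values.
E4M3_MAX = 448.0  # largest finite value of the FP8 E4M3 (fn) format

def cast_to_representable(values, max_abs=E4M3_MAX):
    """Clamp each value into [-max_abs, max_abs] before the FP8 convert."""
    return [max(-max_abs, min(max_abs, v)) for v in values]

print(cast_to_representable([1e4, -1e4, 0.5]))  # [448.0, -448.0, 0.5]
```

Since both the FP8 module and the bf16 reference module apply the same clamp, the two stay numerically comparable; that is why the extra HLO ahead of the custom call is there.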

ROOT out = bf16[4,4,16,16]{3,1,2,0} convert(get-tuple-element.5.0)
} // main.106
)"; // NOLINT
EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref,
Member

Do we need to run HLO passes here? If not, let's disable them in this call. (/*run_hlo_passes=*/false).

Contributor Author

Yes. The input HLOs are pre-optimization ones.

Member

The purpose of the test seems to be to test that the emitted custom call is correct, right? Presumably, we are not trying to test the end-to-end compilation pipeline---or are we testing anything particularly useful here by running it?

Let's make sure to use already-optimized HLO here and in the other cases where we try to ensure correctness of the custom call, and only use before-optimizations HLO where it is necessary.

GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0.";
}
XlaBuilder builder(TestName());
// generate padding mask in cuDNN directly
Member

Are we actually doing any pattern matching here?

Contributor Author

No pattern matching here.

// XLA pattern matching does not support matching the padding mask,
// so we lower directly to the custom call instead for the reference.
std::string hlo_string_ref = R"(
HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0}, bf16[1,1,256,128]{3,2,1,0})->bf16[1,1,256,128]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true}
Member

Same questions as above for the HLOs.

xla/service/gpu/transforms/cudnn_custom_call_compiler.cc (outdated; resolved)
xla/service/gpu/transforms/cudnn_custom_call_compiler.cc (outdated; resolved)
@bchetioui
Member

Gentle ping @wenscarl :)

Member

@bchetioui bchetioui left a comment

@wenscarl thank you for thoughtfully addressing the comments! Love the new tests, they look great!

I'd prefer if we ran without HLO passes here, but I suppose it doesn't hurt the test that much and improves readability to leave it as is---since the clipping logic can be called several times without duplication.

copybara-service bot pushed a commit that referenced this pull request Oct 9, 2024
…8. Part- 2(backward)

Imported from GitHub PR #15331

As the 2nd part of #15092.
NOTE: this feature relies on cudnn-frontend v1.6.1 which is not in XLA yet.
Copybara import of the project:

--
06db3c8 by shuw <shuw@nvidia.com>:

Scaled dot product attention implementation by cudnn.

--
937b0e2 by shuw <shuw@nvidia.com>:

Improve after review 1

--
398b2ba by shuw <shuw@nvidia.com>:

clang-format

--
0825789 by Shu Wang <shuw@nvidia.com>:

fix typo.
--
d0ae3cf by shuw <shuw@nvidia.com>:

Refactor test

Merging this change closes #15331

FUTURE_COPYBARA_INTEGRATE_REVIEW=#15331 from wenscarl:sdpa_fp8_bwd d0ae3cf
PiperOrigin-RevId: 683501409
@copybara-service copybara-service bot closed this in 467563e Oct 9, 2024
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Oct 9, 2024
…8. Part- 2(backward)

Imported from GitHub PR openxla/xla#15331

As the 2nd part of #15092.
NOTE: this feature relies on cudnn-frontend v1.6.1 which is not in XLA yet.
Copybara import of the project:

--
06db3c8349ca017440a2b9c4f4a7c41e557f03af by shuw <shuw@nvidia.com>:

Scaled dot product attention implementation by cudnn.

--
937b0e26ebcf5d48fce15fed8573d7c58b47e689 by shuw <shuw@nvidia.com>:

Improve after review 1

--
398b2ba2cef82f701a0ddecb7553423d92b1f902 by shuw <shuw@nvidia.com>:

clang-format

--
08257899ea899f66799bc701d81aad6ea94af6a0 by Shu Wang <shuw@nvidia.com>:

fix typo.
--
d0ae3cf52b7483c254137d8300f4c00aa963a7c6 by shuw <shuw@nvidia.com>:

Refactor test

Merging this change closes #15331

PiperOrigin-RevId: 684062495
3 participants