-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(bench): Add features and fix some bugs for pipeline flashattention. #31
base: master
Are you sure you want to change the base?
Conversation
The pseudocode for FlashAttention-2 can be represented in the following form:
In the implementation of FractalTensor, some implementation details seem to differ, and I am not sure if this is compliant with the specifications. Below, I will point them out one by one. Can you help me check these issues? @lcy-seso |
|
||
// Compute `lse_i = m_ij + log(l_i_new)`. | ||
for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) { | ||
m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In FractalTensor, the update of LSE is performed outside the loop(https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L278), whereas in the pseudocode, the update of LSE is done inside the loop.
// float scale = 1 / lse_new(ax0); | ||
float o_scale = exp(m_new(ax0) - lse_new(ax0)); | ||
// TODO(KuangjuX): Move this code into loop? | ||
// lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved the update of LSE inside the loop.
// TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i - | ||
// lse_i)`. | ||
|
||
// float scale = 1 / lse_new(ax0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the pseudocode, o_scale = exp(mi − lsei)
is used to rescale the final result, while in FractalTensor, the operation is performed using 1 / lse_new(ax0)
(https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L277).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would like to give quick answers to your two questions and can provide detailed explanations for the reasons. First of all, I have carefully derived it, and the pseudocode is correct. I believe you can follow it. The normalization factor is outside the loop intentionally to reduce computational complexity.
No description provided.