Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bench): Add features and fix some bugs for pipeline flashattention. #31

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

KuangjuX
Copy link
Collaborator

@KuangjuX KuangjuX commented Jan 6, 2025

No description provided.

@KuangjuX KuangjuX marked this pull request as draft January 6, 2025 08:30
@KuangjuX
Copy link
Collaborator Author

KuangjuX commented Jan 6, 2025

The pseudocode for FlashAttention-2 can be represented in the following form:

Iterate (k, v) in K, V :
  qk = dot(q, k) (1)
  mij = max(max(qk), lsei) (2)
  p = exp(qk − mij ) (3)
  lij = sum(p) (4)
  // renormalize o
  acc_o_scale = exp(mi − mij ) (5)
  acc_o = acc_o_scale ∗ acc_o (6)
  acc_o = acc_o + dot(p, v) (7)
  // update statistics
  mi = mij (8)
  li_new = exp (lsei − mij ) + lij (9)
  lsei = mij + log(li_new) (10)
// o_scale is the denominator of the softmax function
o_scale = exp(mi − lsei) (11)
acc_o = acc_o ∗ o_scale (12)

In the implementation of FractalTensor, some implementation details seem to differ, and I am not sure if this is compliant with the specifications. Below, I will point them out one by one.

Can you help me check these issues? @lcy-seso


// Compute `lse_i = m_ij + log(l_i_new)`.
for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In FractalTensor, the update of LSE is performed outside the loop(https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L278), whereas in the pseudocode, the update of LSE is done inside the loop.

// float scale = 1 / lse_new(ax0);
float o_scale = exp(m_new(ax0) - lse_new(ax0));
// TODO(KuangjuX): Move this code into loop?
// lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the update of LSE inside the loop.

// TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i -
// lse_i)`.

// float scale = 1 / lse_new(ax0);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the pseudocode, o_scale = exp(mi − lsei) is used to rescale the final result, while in FractalTensor, the operation is performed using 1 / lse_new(ax0)(https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L277).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to give quick answers to your two questions and can provide detailed explanations for the reasons. First of all, I have carefully derived it, and the pseudocode is correct. I believe you can follow it. The normalization factor is outside the loop intentionally to reduce computational complexity.

@KuangjuX KuangjuX changed the title feat(bench): Add feature and fix some bugs for pipeline flashattention. feat(bench): Add features and fix some bugs for pipeline flashattention. Jan 6, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants