feat(bench): Add features and fix some bugs for pipeline flashattention. #31
base: master
Changes from all commits: 3341d2c, 7000e7a, f69a4ec, 5b072ee, 338051a, 50900bc, 81d70b3
@@ -23,6 +23,8 @@ template <typename Element_, const int kM, const int kN, const int kK,
 struct FATraits : public Base {
     using Element = Element_;
 
+    static_assert(kTP == kP, "The current implementation requires kTP == kP.");
+
     // Declare global to shared memory copy layout.
     using GmemLayoutQ = Layout<Shape<Int<kTM>, Int<kTK>>, Stride<Int<kK>, _1>>;
     using GmemLayoutK = Layout<Shape<Int<kTN>, Int<kTK>>, Stride<Int<kK>, _1>>;
@@ -93,7 +95,9 @@ template <typename Element, typename KeTraits, const int kM, const int kN,
 __global__ void __launch_bounds__(Nthreads)
     fa_kernel(const Element* dQ, const Element* dK, const Element* dV,
               Element* dO) {
-    constexpr float softmax_scale = 1.250000e-01f;
+    // constexpr float softmax_scale = 1.250000e-01f;
+    // TODO(KuangjuX): Use a fixed value for easy comparison.
+    constexpr float softmax_scale = 1.0f;
     const bool load_q_once = (kTK == kK);
 
     extern __shared__ __align__(sizeof(double)) unsigned char buf_[];
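A note on the constant being replaced (an assumption, not stated in the diff): `1.250000e-01f` is 1/8 = 1/sqrt(64), i.e. the usual 1/sqrt(head_dim) softmax scale for a 64-wide head, so pinning it to `1.0f` is only meaningful for the debugging comparison the TODO mentions. A one-line check:

```cpp
// Hypothetical aside: 1.25e-1 equals 1/sqrt(64), the standard
// 1/sqrt(head_dim) attention scale when head_dim == 64.
#include <cmath>
#include <cstdio>

int main() {
    printf("1/sqrt(64) = %f (PR's original constant: %f)\n",
           1.0 / std::sqrt(64.0), 1.25e-1);
    return 0;
}
```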
@@ -137,6 +141,24 @@ __global__ void __launch_bounds__(Nthreads)
     auto acc0 = get_acc<kTM, kTN>(mma);
     auto acco = get_acc<kTM, kTP>(mma);
 
+#ifdef DEBUG
+    if (thread0()) {
+        printf("acc0 size<0>: %d, size<1>: %d, size<2>: %d\n",
+               (int)size<0>(acc0), (int)size<1>(acc0), (int)size<2>(acc0));
+        printf("acco size<0>: %d, size<1>: %d, size<2>: %d\n",
+               (int)size<0>(acco), (int)size<1>(acco), (int)size<2>(acco));
+    }
+#endif
+
+    /**
+     * In TileFusion, we use
+     * ```cpp
+     * using RegVec = RegTile<InType, tl::RowMajor<kAccMs, 2>>;
+     * ```
+     * We need to store the reduce results for both the top row and the bottom
+     * row simultaneously.
+     */
+
     auto m_new = make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{});
     auto lse_new = make_fragment_like(m_new);
 
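For context on why `m_new` is sized `2 * size<1>(acc0)`: in the warp-level MMA accumulator layout each thread holds fragments belonging to two different rows of the tile, so a per-thread row reduction has to keep a top-row and a bottom-row result for every accumulator row-tile, which is what the `RegTile<InType, tl::RowMajor<kAccMs, 2>>` comment above refers to. Below is a standalone host-side sketch of that bookkeeping; the array shapes and names are illustrative, not the kernel's.

```cpp
// Sketch: each thread owns a top row and a bottom row per row-tile, so the
// row-wise reduction result needs 2 * kAccM entries, mirroring
// `make_tensor<float>(Shape<Int<2 * size<1>(acc0)>>{})` in the kernel.
#include <algorithm>
#include <array>
#include <cstdio>
#include <limits>

int main() {
    constexpr int kAccM = 2;  // row-tiles held by this thread (illustrative)
    constexpr int kAccN = 4;  // elements per row owned by this thread
    // Per-thread registers: for each row-tile, a top row and a bottom row.
    float regs[kAccM][2][kAccN] = {
        {{1.f, 3.f, 2.f, 0.f}, {5.f, 4.f, 4.f, 1.f}},
        {{2.f, 2.f, 9.f, 3.f}, {0.f, 7.f, 1.f, 6.f}},
    };
    std::array<float, 2 * kAccM> row_max;
    row_max.fill(-std::numeric_limits<float>::infinity());
    for (int m = 0; m < kAccM; ++m)
        for (int r = 0; r < 2; ++r)  // r == 0: top row, r == 1: bottom row
            for (int n = 0; n < kAccN; ++n)
                row_max[2 * m + r] = std::max(row_max[2 * m + r], regs[m][r][n]);
    for (int i = 0; i < 2 * kAccM; ++i)
        printf("row %d max = %f\n", i, row_max[i]);
    return 0;
}
```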
@@ -165,6 +187,8 @@ __global__ void __launch_bounds__(Nthreads)
     int split_n = kN / kTN;
     for (int n = 0; n < split_n; ++n) {
         clear(acc0);
+
+        // When `load_q_once` is true, the following loop does not execute (slice_k == 0).
         int slice_k = kK / kTK - 1;
         for (int k = 0; k < slice_k; ++k) {
             // Barrier to ensure all data are loaded into shared memory.
@@ -178,6 +202,8 @@ __global__ void __launch_bounds__(Nthreads)
         cp_async_wait_flash<0>();
         __syncthreads();
         g2s_copy_v.prologue();
+        // When `load_q_once` is true, `g2s_copy_qk.prologue()` is executed only
+        // once, and `s2r_pipeline_qk.epilogue()` is executed once as well.
         s2r_pipeline_qk.epilogue();
 
         // scores = dot(q, k)
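A small standalone check of the tiling relationship behind `load_q_once` (illustrative values only; `kK` and `kTK` are compile-time parameters in the real kernel). When the shared-memory tile covers the whole K dimension (`kTK == kK`), `slice_k` is zero, the inner k-loop above never runs, and later in the diff only `prologue_K()` is needed to advance to the next KV tile; otherwise Q has to be reset and re-streamed together with K.

```cpp
// Sketch of the load_q_once condition and what it implies for the copy schedule.
#include <cstdio>

int main() {
    struct Case { int kK, kTK; };
    for (Case c : {Case{64, 64}, Case{128, 64}}) {
        bool load_q_once = (c.kTK == c.kK);
        int slice_k = c.kK / c.kTK - 1;  // iterations of the inner k-loop
        printf("kK=%d kTK=%d -> load_q_once=%d, inner k-loop runs %d time(s); %s\n",
               c.kK, c.kTK, (int)load_q_once, slice_k,
               load_q_once ? "only K is re-prefetched per KV tile (prologue_K)"
                           : "Q is reset and re-streamed with K (reset_tile_Q + prologue)");
    }
    return 0;
}
```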
@@ -197,30 +223,46 @@ __global__ void __launch_bounds__(Nthreads)
             m_new(ax0) = max(m_new(ax0), scores_max(ax0));
         }
 
-        auto acco_rowcol =
+        // Currently, `acco` stores the results from the previous iteration's
+        // computation.
+        auto previous_attn_block =
             make_tensor(acco.data(), convert_layout_scores(acco.layout()));
 
-        // Renormalizatio for the previous block.
-        for (int ax0 = 0; ax0 < size<0>(acco_rowcol); ++ax0) {
+#ifdef DEBUG
+        if (thread0()) {
+            printf("scores size<0>: %d, size<1>: %d\n", (int)size<0>(scores),
+                   (int)size<1>(scores));
+            printf("previous_attn_block size<0>: %d, size<1>: %d\n",
+                   (int)size<0>(previous_attn_block),
+                   (int)size<1>(previous_attn_block));
+        }
+#endif
+
+        // Renormalization for the previous block.
+        for (int ax0 = 0; ax0 < size<0>(previous_attn_block); ++ax0) {
+            // Compute `acc_o_scale = exp(m_i - m_ij)`
             float scale = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
             lse_new(ax0) = lse_new(ax0) * scale;
-            for (int ax1 = 0; ax1 < size<1>(acco_rowcol); ++ax1) {
-                acco_rowcol(ax0, ax1) *= scale;
+            // Compute `acc_o = acc_o_scale * acc_o`
+            for (int ax1 = 0; ax1 < size<1>(previous_attn_block); ++ax1) {
+                previous_attn_block(ax0, ax1) *= scale;
             }
         }
 
         for (int ax0 = 0; ax0 < size<0>(scores); ++ax0) {
-            float m_scaled = exp((m_old(ax0) - m_new(ax0)) * softmax_scale);
-            lse_new(ax0) = lse_new(ax0) * m_scaled;
+            // Compute `p = exp(qk - m_ij)`
+            float m_scaled = m_new(ax0) * softmax_scale;
             for (int ax1 = 0; ax1 < size<1>(scores); ++ax1) {
                 scores(ax0, ax1) =
                     exp(scores(ax0, ax1) * softmax_scale - m_scaled);
             }
         }
 
+        // Compute `l_ij = sum(p)`.
         auto scores_sum = make_fragment_like(lse_new);
         reduce_sum<4>(scores, scores_sum);
 
+        // Compute `l_i_new = exp(lse_i - m_ij) + l_ij`.
         for (int ax0 = 0; ax0 < size<0>(lse_new); ++ax0) {
             lse_new(ax0) = lse_new(ax0) + scores_sum(ax0);
         }
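To make the update sequence in the hunk above easier to follow, here is a minimal host-side sketch of the same online-softmax recurrence over KV tiles (FlashAttention-style), written against plain scalars rather than register tiles; the values and the `V == 1` simplification are purely illustrative.

```cpp
// Online softmax over score tiles: rescale the previous accumulator by
// exp(m_old - m_new), accumulate p = exp(s - m_new), and normalize once at
// the end. Mirrors the commented formulas in the diff, not the CUDA layout.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const float softmax_scale = 1.0f;  // matches the debug value used in this PR
    // Scores (q.k) for one query row, split into two KV tiles.
    std::vector<std::vector<float>> score_tiles = {{1.f, 3.f, 2.f}, {4.f, 0.f, 1.f}};

    float m = -INFINITY;  // running row max, `m_i`
    float l = 0.f;        // running row sum of exponentials, `l_i`
    float acc_o = 0.f;    // running output accumulator (V set to all-ones for brevity)

    for (const auto& tile : score_tiles) {
        float m_old = m;
        for (float s : tile) m = std::max(m, s);          // m_ij = max(m_i, rowmax(scores))
        float acc_o_scale = std::exp((m_old - m) * softmax_scale);
        acc_o *= acc_o_scale;                             // acc_o = acc_o_scale * acc_o
        l *= acc_o_scale;                                 // rescale the running sum as well
        float l_ij = 0.f;
        for (float s : tile) {
            float p = std::exp((s - m) * softmax_scale);  // p = exp(qk - m_ij)
            l_ij += p;                                    // l_ij = sum(p)
            acc_o += p * 1.0f;                            // acc_o += dot(p, v), with v = 1
        }
        l += l_ij;                                        // l_i_new = rescaled l_i + l_ij
    }
    // Final normalization (done after the n-loop in the kernel as well).
    printf("normalized output = %f (should be 1.0 since V == 1)\n", acc_o / l);
    return 0;
}
```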
@@ -274,10 +316,39 @@ __global__ void __launch_bounds__(Nthreads)
             */
             if (load_q_once) {
                 g2s_copy_qk.prologue_K();
+            } else {
+                /**
+                 * In this case, we need to reset the pointer of Q to the
+                 * starting position and simultaneously preload the Q and K.
+                 */
+                g2s_copy_qk.reset_tile_Q(kK);
+                g2s_copy_qk.prologue();
             }
         }
 
+        // Compute `acc_o = acc_o + dot(p, v)`
         s2r_pipeline_v.epilogue(rP_Aregs);
+
+        // Compute `lse_i = m_ij + log(l_i_new)`.
+        for (int ax0 = 0; ax0 < size<0>(m_new); ++ax0) {
+            m_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));
+        }
     }
 
+    // Normalize the attention block.
+    auto attn_block =
+        make_tensor(acco.data(), convert_layout_scores(acco.layout()));
+    for (int ax0 = 0; ax0 < size<0>(attn_block); ++ax0) {
+        // TODO(KuangjuX): fix the following code? -> `o_scale = exp(m_i -
+        // lse_i)`.
+
+        // float scale = 1 / lse_new(ax0);

[Review comment] In the pseudocode, […]

[Reply] I would like to give quick answers to your two questions and can provide detailed explanations of the reasons. First of all, I have carefully derived it, and the pseudocode is correct; I believe you can follow it. The normalization factor is outside the loop intentionally to reduce computational complexity.

+        float o_scale = exp(m_new(ax0) - lse_new(ax0));
+        // TODO(KuangjuX): Move this code into the loop?
+        // lse_new(ax0) = m_new(ax0) * softmax_scale + log(lse_new(ax0));

[Reply] I moved the update of LSE inside the loop.

+        for (int ax1 = 0; ax1 < size<1>(attn_block); ++ax1) {
+            attn_block(ax0, ax1) *= o_scale;
+        }
+    }
+
     // Store O from registers to shared memory and then to global memory.

[Review comment] In FractalTensor, the update of LSE is performed outside the loop (https://github.com/microsoft/FractalTensor/blob/artifact/artifacts/FractalTensor/benchmarks/multi-head_attention/fractaltensor/kernel.h#L278), whereas in the pseudocode the update of LSE is done inside the loop.
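On the inside-vs-outside-the-loop question discussed above: both variants produce the same normalization, because once `lse_i = m_ij + log(l_i_new)` the final scale `exp(m_i - lse_i)` collapses to `1 / l_i_new`. A small numeric check of that identity (illustrative values; a consistently folded-in softmax_scale is assumed):

```cpp
// Check: exp(m - (m + log(l))) == 1 / l, i.e. finalizing LSE inside the loop
// and then scaling by exp(m - lse) is equivalent to dividing by the running
// sum l once outside the loop.
#include <cmath>
#include <cstdio>

int main() {
    const float m = 2.5f;   // running row max after the last KV tile
    const float l = 7.25f;  // running row sum of exponentials
    const float lse = m + std::log(l);          // lse_i = m_ij + log(l_i_new)
    const float o_scale_inside = std::exp(m - lse);
    const float o_scale_outside = 1.0f / l;
    printf("exp(m - lse) = %.7f, 1/l = %.7f\n", o_scale_inside, o_scale_outside);
    return 0;
}
```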