Skip to content

Commit 7399fd6

Browse files
Krzysztof Rymskicopybara-github
authored andcommitted
Internal changes
PiperOrigin-RevId: 871804601
1 parent 34739fd commit 7399fd6

File tree

2 files changed

+241
-40
lines changed

2 files changed

+241
-40
lines changed

gemma/flash_attention.cc

Lines changed: 240 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <algorithm>
2020
#include <cmath>
2121
#include <cstdlib>
22+
#include <iostream>
2223
#include <limits>
2324
#include <vector>
2425

@@ -464,9 +465,28 @@ static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
464465
return result;
465466
}
466467

468+
// Returns vector with 8 lanes. Shouldn't be on architectures with less than 8
469+
// lanes per vector.
470+
template <class DF, typename T = hn::TFromD<DF>,
471+
class DF8 = hn::CappedTag<T, 8>, class VF8 = hn::Vec<DF8>,
472+
class VF = hn::Vec<DF>, typename F>
473+
static HWY_INLINE VF8 Reduce8(DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
474+
VF x_5, VF x_6, VF x_7, F reducer) {
475+
auto res0123 = Reduce4(df, x_0, x_1, x_2, x_3, reducer);
476+
auto res4567 = Reduce4(df, x_4, x_5, x_6, x_7, reducer);
477+
478+
using DF4 = hn::CappedTag<T, 4>;
479+
const DF4 df4;
480+
const DF8 df8;
481+
HWY_ALIGN T buf[8];
482+
hn::Store(res0123, df4, buf);
483+
hn::Store(res4567, df4, buf + 4);
484+
return hn::Load(df8, buf);
485+
}
486+
467487
// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
468488
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
469-
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
489+
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4(
470490
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
471491
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
472492
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
@@ -502,31 +522,29 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
502522
old_max_vf = hn::LoadU(df4, old_max);
503523
new_max = hn::Max(new_max, old_max_vf);
504524
auto changed_max = hn::Gt(new_max, hn::Set(df4, kNegInf));
505-
// TODO figure out what was wrong with broadcasts and change to that.
506525
hn::StoreU(new_max, df4, old_max);
507526
if constexpr (kNumQueries >= 1) {
508527
const VF new_max_0 = hn::Set(df, old_max[0]);
509-
x_0_p0 = hn::Exp(df, hn::Sub(x_0_p0, new_max_0));
510-
x_0_p1 = hn::Exp(df, hn::Sub(x_0_p1, new_max_0));
528+
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
529+
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
511530
}
512531
if constexpr (kNumQueries >= 2) {
513532
const VF new_max_0 = hn::Set(df, old_max[1]);
514-
x_1_p0 = hn::Exp(df, hn::Sub(x_1_p0, new_max_0));
515-
x_1_p1 = hn::Exp(df, hn::Sub(x_1_p1, new_max_0));
533+
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
534+
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
516535
}
517536
if constexpr (kNumQueries >= 3) {
518537
const VF new_max_0 = hn::Set(df, old_max[2]);
519-
x_2_p0 = hn::Exp(df, hn::Sub(x_2_p0, new_max_0));
520-
x_2_p1 = hn::Exp(df, hn::Sub(x_2_p1, new_max_0));
538+
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
539+
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
521540
}
522541
if constexpr (kNumQueries >= 4) {
523542
const VF new_max_0 = hn::Set(df, old_max[3]);
524-
x_3_p0 = hn::Exp(df, hn::Sub(x_3_p0, new_max_0));
525-
x_3_p1 = hn::Exp(df, hn::Sub(x_3_p1, new_max_0));
543+
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
544+
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
526545
}
527546
VF4 old_d_vf = hn::Set(df4, 0.0f);
528547
old_d_vf = hn::LoadU(df4, old_d);
529-
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
530548

531549
VF4 x_sum = hn::Zero(df4);
532550
if constexpr (kNumQueries == 1) {
@@ -539,6 +557,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
539557
x_sum = Reduce4(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
540558
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
541559
}
560+
VF4 scale = hn::Mul(old_d_vf, hn::Exp(df4, hn::Sub(old_max_vf, new_max)));
542561
old_d_vf = hn::Add(scale, x_sum);
543562
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df4, 0.0f));
544563
const VF zero = hn::Zero(df);
@@ -550,43 +569,225 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
550569
hn::BlendedStore(old_d_vf, changed_max, df4, old_d);
551570
scale = hn::Mul(scale, one_over_d);
552571
hn::BlendedStore(scale, changed_max, df4, scales);
553-
if (hn::ExtractLane(old_d_vf, 0) > 0.0f && scales[0] != 1.0f) {
554-
const VF one_over_d_0 = hn::Set(df, tmp_one_over_d[0]);
555-
x_0_p0 = hn::Mul(x_0_p0, one_over_d_0);
556-
x_0_p1 = hn::Mul(x_0_p1, one_over_d_0);
572+
// same as lambda
573+
auto mul_or_zero = [&](VF& x_p0, VF& x_p1, int i) HWY_ATTR {
574+
if (HWY_LIKELY(old_d[i] > 0.0f && scales[i] != 1.0f)) {
575+
const VF one_over_d_i = hn::Set(df, tmp_one_over_d[i]);
576+
x_p0 = hn::Mul(x_p0, one_over_d_i);
577+
x_p1 = hn::Mul(x_p1, one_over_d_i);
578+
} else {
579+
x_p0 = zero;
580+
x_p1 = zero;
581+
}
582+
};
583+
mul_or_zero(x_0_p0, x_0_p1, 0);
584+
if constexpr (kNumQueries >= 2) {
585+
mul_or_zero(x_1_p0, x_1_p1, 1);
586+
}
587+
if constexpr (kNumQueries >= 3) {
588+
mul_or_zero(x_2_p0, x_2_p1, 2);
589+
}
590+
if constexpr (kNumQueries >= 4) {
591+
mul_or_zero(x_3_p0, x_3_p1, 3);
592+
}
593+
}
594+
595+
template <class DF, class VF = hn::Vec<DF>>
596+
HWY_NOINLINE VF CallExp(DF df, VF x_p0) {
597+
return hn::Exp(df, x_p0);
598+
}
599+
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
600+
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8(
601+
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
602+
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
603+
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
604+
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
605+
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales) {
606+
using DF8 = hn::CappedTag<float, 8>;
607+
const DF8 df8;
608+
using VF8 = hn::Vec<DF8>;
609+
static_assert(kNumQueries >= 1 && kNumQueries <= 8);
610+
VF8 new_max = hn::Set(df8, kNegInf);
611+
VF max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7 = hn::Zero(df);
612+
max_0 = hn::Max(x_0_p0, x_0_p1);
613+
if constexpr (kNumQueries >= 2) {
614+
max_1 = hn::Max(x_1_p0, x_1_p1);
615+
}
616+
if constexpr (kNumQueries >= 3) {
617+
max_2 = hn::Max(x_2_p0, x_2_p1);
618+
}
619+
if constexpr (kNumQueries >= 4) {
620+
max_3 = hn::Max(x_3_p0, x_3_p1);
621+
}
622+
if constexpr (kNumQueries >= 5) {
623+
max_4 = hn::Max(x_4_p0, x_4_p1);
624+
}
625+
if constexpr (kNumQueries >= 6) {
626+
max_5 = hn::Max(x_5_p0, x_5_p1);
627+
}
628+
if constexpr (kNumQueries >= 7) {
629+
max_6 = hn::Max(x_6_p0, x_6_p1);
630+
}
631+
if constexpr (kNumQueries >= 8) {
632+
max_7 = hn::Max(x_7_p0, x_7_p1);
633+
}
634+
635+
if constexpr (kNumQueries == 1) {
636+
new_max = hn::InsertLane(new_max, 0, hn::ReduceMax(df, max_0));
557637
} else {
558-
x_0_p0 = zero;
559-
x_0_p1 = zero;
638+
new_max =
639+
Reduce8(df, max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7,
640+
[](auto a, auto b) HWY_ATTR { return hn::Max(a, b); });
641+
}
642+
if (att_cap > 0.0f) {
643+
VF8 cap = hn::Set(df8, att_cap);
644+
VF8 one_over_cap = hn::Set(df8, one_over_att_cap);
645+
new_max = hn::Mul(cap, hn::Tanh(df8, hn::Mul(new_max, one_over_cap)));
646+
}
647+
VF8 old_max_vf = hn::Set(df8, kNegInf);
648+
old_max_vf = hn::LoadU(df8, old_max);
649+
new_max = hn::Max(new_max, old_max_vf);
650+
auto changed_max = hn::Gt(new_max, hn::Set(df8, kNegInf));
651+
hn::StoreU(new_max, df8, old_max);
652+
if constexpr (kNumQueries >= 1) {
653+
const VF new_max_0 = hn::Set(df, old_max[0]);
654+
x_0_p0 = hn::CallExp(df, hn::Sub(x_0_p0, new_max_0));
655+
x_0_p1 = hn::CallExp(df, hn::Sub(x_0_p1, new_max_0));
560656
}
561657
if constexpr (kNumQueries >= 2) {
562-
if (hn::ExtractLane(old_d_vf, 1) > 0.0f && scales[1] != 1.0f) {
563-
const VF one_over_d_1 = hn::Set(df, tmp_one_over_d[1]);
564-
x_1_p0 = hn::Mul(x_1_p0, one_over_d_1);
565-
x_1_p1 = hn::Mul(x_1_p1, one_over_d_1);
566-
} else {
567-
x_1_p0 = zero;
568-
x_1_p1 = zero;
569-
}
658+
const VF new_max_0 = hn::Set(df, old_max[1]);
659+
x_1_p0 = hn::CallExp(df, hn::Sub(x_1_p0, new_max_0));
660+
x_1_p1 = hn::CallExp(df, hn::Sub(x_1_p1, new_max_0));
570661
}
571662
if constexpr (kNumQueries >= 3) {
572-
if (hn::ExtractLane(old_d_vf, 2) > 0.0f && scales[2] != 1.0f) {
573-
const VF one_over_d_2 = hn::Set(df, tmp_one_over_d[2]);
574-
x_2_p0 = hn::Mul(x_2_p0, one_over_d_2);
575-
x_2_p1 = hn::Mul(x_2_p1, one_over_d_2);
576-
} else {
577-
x_2_p0 = zero;
578-
x_2_p1 = zero;
579-
}
663+
const VF new_max_0 = hn::Set(df, old_max[2]);
664+
x_2_p0 = hn::CallExp(df, hn::Sub(x_2_p0, new_max_0));
665+
x_2_p1 = hn::CallExp(df, hn::Sub(x_2_p1, new_max_0));
580666
}
581667
if constexpr (kNumQueries >= 4) {
582-
if (hn::ExtractLane(old_d_vf, 3) > 0.0f && scales[3] != 1.0f) {
583-
const VF one_over_d_3 = hn::Set(df, tmp_one_over_d[3]);
584-
x_3_p0 = hn::Mul(x_3_p0, one_over_d_3);
585-
x_3_p1 = hn::Mul(x_3_p1, one_over_d_3);
668+
const VF new_max_0 = hn::Set(df, old_max[3]);
669+
x_3_p0 = hn::CallExp(df, hn::Sub(x_3_p0, new_max_0));
670+
x_3_p1 = hn::CallExp(df, hn::Sub(x_3_p1, new_max_0));
671+
}
672+
if constexpr (kNumQueries >= 5) {
673+
const VF new_max_0 = hn::Set(df, old_max[4]);
674+
x_4_p0 = hn::CallExp(df, hn::Sub(x_4_p0, new_max_0));
675+
x_4_p1 = hn::CallExp(df, hn::Sub(x_4_p1, new_max_0));
676+
}
677+
if constexpr (kNumQueries >= 6) {
678+
const VF new_max_0 = hn::Set(df, old_max[5]);
679+
x_5_p0 = hn::CallExp(df, hn::Sub(x_5_p0, new_max_0));
680+
x_5_p1 = hn::CallExp(df, hn::Sub(x_5_p1, new_max_0));
681+
}
682+
if constexpr (kNumQueries >= 7) {
683+
const VF new_max_0 = hn::Set(df, old_max[6]);
684+
x_6_p0 = hn::CallExp(df, hn::Sub(x_6_p0, new_max_0));
685+
x_6_p1 = hn::CallExp(df, hn::Sub(x_6_p1, new_max_0));
686+
}
687+
if constexpr (kNumQueries >= 8) {
688+
const VF new_max_0 = hn::Set(df, old_max[7]);
689+
x_7_p0 = hn::CallExp(df, hn::Sub(x_7_p0, new_max_0));
690+
x_7_p1 = hn::CallExp(df, hn::Sub(x_7_p1, new_max_0));
691+
}
692+
VF8 old_d_vf = hn::Set(df8, 0.0f);
693+
old_d_vf = hn::LoadU(df8, old_d);
694+
695+
VF8 x_sum = hn::Zero(df8);
696+
if constexpr (kNumQueries == 1) {
697+
x_sum = hn::Set(df8, hn::ReduceSum(df, x_0_p0) + hn::ReduceSum(df, x_0_p1));
698+
} else {
699+
VF x_0_sum = hn::Add(x_0_p0, x_0_p1);
700+
VF x_1_sum = hn::Add(x_1_p0, x_1_p1);
701+
VF x_2_sum = hn::Add(x_2_p0, x_2_p1);
702+
VF x_3_sum = hn::Add(x_3_p0, x_3_p1);
703+
VF x_4_sum = hn::Add(x_4_p0, x_4_p1);
704+
VF x_5_sum = hn::Add(x_5_p0, x_5_p1);
705+
VF x_6_sum = hn::Add(x_6_p0, x_6_p1);
706+
VF x_7_sum = hn::Add(x_7_p0, x_7_p1);
707+
x_sum = Reduce8(df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, x_4_sum, x_5_sum,
708+
x_6_sum, x_7_sum,
709+
[](auto a, auto b) HWY_ATTR { return hn::Add(a, b); });
710+
}
711+
VF8 scale = hn::Mul(old_d_vf, hn::Exp(df8, hn::Sub(old_max_vf, new_max)));
712+
old_d_vf = hn::Add(scale, x_sum);
713+
auto non_zero_mask = hn::Gt(old_d_vf, hn::Set(df8, 0.0f));
714+
const VF zero = hn::Zero(df);
715+
const VF8 zero8 = hn::Zero(df8);
716+
const VF8 one_over_d =
717+
hn::MaskedDivOr(zero8, non_zero_mask, hn::Set(df8, 1.0f), old_d_vf);
718+
HWY_ALIGN float tmp_one_over_d[8];
719+
hn::Store(one_over_d, df8, tmp_one_over_d);
720+
hn::BlendedStore(old_d_vf, changed_max, df8, old_d);
721+
scale = hn::Mul(scale, one_over_d);
722+
hn::BlendedStore(scale, changed_max, df8, scales);
723+
auto mul_or_zero = [&](VF& x_p0, VF& x_p1, int i) HWY_ATTR {
724+
if (HWY_LIKELY(old_d[i] > 0.0f && scales[i] != 1.0f)) {
725+
const VF one_over_d_i = hn::Set(df, tmp_one_over_d[i]);
726+
x_p0 = hn::Mul(x_p0, one_over_d_i);
727+
x_p1 = hn::Mul(x_p1, one_over_d_i);
586728
} else {
587-
x_3_p0 = zero;
588-
x_3_p1 = zero;
729+
x_p0 = zero;
730+
x_p1 = zero;
589731
}
732+
};
733+
mul_or_zero(x_0_p0, x_0_p1, 0);
734+
if constexpr (kNumQueries >= 2) {
735+
mul_or_zero(x_1_p0, x_1_p1, 1);
736+
}
737+
if constexpr (kNumQueries >= 3) {
738+
mul_or_zero(x_2_p0, x_2_p1, 2);
739+
}
740+
if constexpr (kNumQueries >= 4) {
741+
mul_or_zero(x_3_p0, x_3_p1, 3);
742+
}
743+
if constexpr (kNumQueries >= 5) {
744+
mul_or_zero(x_4_p0, x_4_p1, 4);
745+
}
746+
if constexpr (kNumQueries >= 6) {
747+
mul_or_zero(x_5_p0, x_5_p1, 5);
748+
}
749+
if constexpr (kNumQueries >= 7) {
750+
mul_or_zero(x_6_p0, x_6_p1, 6);
751+
}
752+
if constexpr (kNumQueries >= 8) {
753+
mul_or_zero(x_7_p0, x_7_p1, 7);
754+
}
755+
}
756+
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
757+
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
758+
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
759+
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
760+
VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
761+
VF& x_7_p0, VF& x_7_p1, float* HWY_RESTRICT old_max,
762+
float* HWY_RESTRICT old_d, float* HWY_RESTRICT scales, size_t q_group_idx,
763+
size_t kNumQueriesPerGroup) {
764+
constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4);
765+
[[maybe_unused]] constexpr int kSecondHalfAmountOfQueries =
766+
kNumQueries - kFirstHalfAmountOfQueries;
767+
if constexpr (kNumQueries <= 4) {
768+
FlashAttentionTileStepAndApplySoftCap4<kFirstHalfAmountOfQueries>(
769+
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
770+
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
771+
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
772+
} else {
773+
#if HWY_MAX_BYTES <= 16
774+
FlashAttentionTileStepAndApplySoftCap4<4>(
775+
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
776+
x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
777+
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
778+
FlashAttentionTileStepAndApplySoftCap4<kSecondHalfAmountOfQueries>(
779+
df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0,
780+
x_6_p1, x_7_p0, x_7_p1,
781+
old_max + (q_group_idx + 1) * kNumQueriesPerGroup,
782+
old_d + (q_group_idx + 1) * kNumQueriesPerGroup,
783+
scales + kNumQueriesPerGroup);
784+
#else
785+
FlashAttentionTileStepAndApplySoftCap8<kNumQueries>(
786+
df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
787+
x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1,
788+
x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup,
789+
old_d + (q_group_idx)*kNumQueriesPerGroup, scales);
790+
#endif
590791
}
591792
}
592793

gemma/flash_attention_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ void SetMat(const size_t offset, MatPtrT<float>& mat) {
6868
const float i_scale = 1.0f / kInner;
6969
const float j_scale = 1.0f / kOuter;
7070
for (size_t i = 0; i < kOuter; ++i) {
71-
float* row = mat.Row(i);
71+
float* HWY_RESTRICT row = mat.Row(i);
7272
for (size_t j = 0; j < kInner; ++j) {
7373
row[j] =
7474
static_cast<float>((i * kInner * i_scale + (j + offset) * j_scale));

0 commit comments

Comments
 (0)