1919#include < algorithm>
2020#include < cmath>
2121#include < cstdlib>
22+ #include < iostream>
2223#include < limits>
2324#include < vector>
2425
@@ -464,9 +465,28 @@ static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
464465 return result;
465466}
466467
468+ // Returns vector with 8 lanes. Shouldn't be on architectures with less than 8
469+ // lanes per vector.
470+ template <class DF , typename T = hn::TFromD<DF>,
471+ class DF8 = hn::CappedTag<T, 8 >, class VF8 = hn::Vec<DF8>,
472+ class VF = hn::Vec<DF>, typename F>
473+ static HWY_INLINE VF8 Reduce8 (DF df, VF x_0, VF x_1, VF x_2, VF x_3, VF x_4,
474+ VF x_5, VF x_6, VF x_7, F reducer) {
475+ auto res0123 = Reduce4 (df, x_0, x_1, x_2, x_3, reducer);
476+ auto res4567 = Reduce4 (df, x_4, x_5, x_6, x_7, reducer);
477+
478+ using DF4 = hn::CappedTag<T, 4 >;
479+ const DF4 df4;
480+ const DF8 df8;
481+ HWY_ALIGN T buf[8 ];
482+ hn::Store (res0123, df4, buf);
483+ hn::Store (res4567, df4, buf + 4 );
484+ return hn::Load (df8, buf);
485+ }
486+
467487// Handles up to 4 Q rows by NF*2 timesteps of flash attention.
468488template <int kNumQueries , class DF , class VF = hn::Vec<DF>>
469- static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap (
489+ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap4 (
470490 DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
471491 VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
472492 float * HWY_RESTRICT old_max, float * HWY_RESTRICT old_d,
@@ -502,31 +522,29 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
502522 old_max_vf = hn::LoadU (df4, old_max);
503523 new_max = hn::Max (new_max, old_max_vf);
504524 auto changed_max = hn::Gt (new_max, hn::Set (df4, kNegInf ));
505- // TODO figure out what was wrong with broadcasts and change to that.
506525 hn::StoreU (new_max, df4, old_max);
507526 if constexpr (kNumQueries >= 1 ) {
508527 const VF new_max_0 = hn::Set (df, old_max[0 ]);
509- x_0_p0 = hn::Exp (df, hn::Sub (x_0_p0, new_max_0));
510- x_0_p1 = hn::Exp (df, hn::Sub (x_0_p1, new_max_0));
528+ x_0_p0 = hn::CallExp (df, hn::Sub (x_0_p0, new_max_0));
529+ x_0_p1 = hn::CallExp (df, hn::Sub (x_0_p1, new_max_0));
511530 }
512531 if constexpr (kNumQueries >= 2 ) {
513532 const VF new_max_0 = hn::Set (df, old_max[1 ]);
514- x_1_p0 = hn::Exp (df, hn::Sub (x_1_p0, new_max_0));
515- x_1_p1 = hn::Exp (df, hn::Sub (x_1_p1, new_max_0));
533+ x_1_p0 = hn::CallExp (df, hn::Sub (x_1_p0, new_max_0));
534+ x_1_p1 = hn::CallExp (df, hn::Sub (x_1_p1, new_max_0));
516535 }
517536 if constexpr (kNumQueries >= 3 ) {
518537 const VF new_max_0 = hn::Set (df, old_max[2 ]);
519- x_2_p0 = hn::Exp (df, hn::Sub (x_2_p0, new_max_0));
520- x_2_p1 = hn::Exp (df, hn::Sub (x_2_p1, new_max_0));
538+ x_2_p0 = hn::CallExp (df, hn::Sub (x_2_p0, new_max_0));
539+ x_2_p1 = hn::CallExp (df, hn::Sub (x_2_p1, new_max_0));
521540 }
522541 if constexpr (kNumQueries >= 4 ) {
523542 const VF new_max_0 = hn::Set (df, old_max[3 ]);
524- x_3_p0 = hn::Exp (df, hn::Sub (x_3_p0, new_max_0));
525- x_3_p1 = hn::Exp (df, hn::Sub (x_3_p1, new_max_0));
543+ x_3_p0 = hn::CallExp (df, hn::Sub (x_3_p0, new_max_0));
544+ x_3_p1 = hn::CallExp (df, hn::Sub (x_3_p1, new_max_0));
526545 }
527546 VF4 old_d_vf = hn::Set (df4, 0 .0f );
528547 old_d_vf = hn::LoadU (df4, old_d);
529- VF4 scale = hn::Mul (old_d_vf, hn::Exp (df4, hn::Sub (old_max_vf, new_max)));
530548
531549 VF4 x_sum = hn::Zero (df4);
532550 if constexpr (kNumQueries == 1 ) {
@@ -539,6 +557,7 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
539557 x_sum = Reduce4 (df, x_0_sum, x_1_sum, x_2_sum, x_3_sum,
540558 [](auto a, auto b) HWY_ATTR { return hn::Add (a, b); });
541559 }
560+ VF4 scale = hn::Mul (old_d_vf, hn::Exp (df4, hn::Sub (old_max_vf, new_max)));
542561 old_d_vf = hn::Add (scale, x_sum);
543562 auto non_zero_mask = hn::Gt (old_d_vf, hn::Set (df4, 0 .0f ));
544563 const VF zero = hn::Zero (df);
@@ -550,43 +569,225 @@ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
550569 hn::BlendedStore (old_d_vf, changed_max, df4, old_d);
551570 scale = hn::Mul (scale, one_over_d);
552571 hn::BlendedStore (scale, changed_max, df4, scales);
553- if (hn::ExtractLane (old_d_vf, 0 ) > 0 .0f && scales[0 ] != 1 .0f ) {
554- const VF one_over_d_0 = hn::Set (df, tmp_one_over_d[0 ]);
555- x_0_p0 = hn::Mul (x_0_p0, one_over_d_0);
556- x_0_p1 = hn::Mul (x_0_p1, one_over_d_0);
572+ // same as lambda
573+ auto mul_or_zero = [&](VF& x_p0, VF& x_p1, int i) HWY_ATTR {
574+ if (HWY_LIKELY (old_d[i] > 0 .0f && scales[i] != 1 .0f )) {
575+ const VF one_over_d_i = hn::Set (df, tmp_one_over_d[i]);
576+ x_p0 = hn::Mul (x_p0, one_over_d_i);
577+ x_p1 = hn::Mul (x_p1, one_over_d_i);
578+ } else {
579+ x_p0 = zero;
580+ x_p1 = zero;
581+ }
582+ };
583+ mul_or_zero (x_0_p0, x_0_p1, 0 );
584+ if constexpr (kNumQueries >= 2 ) {
585+ mul_or_zero (x_1_p0, x_1_p1, 1 );
586+ }
587+ if constexpr (kNumQueries >= 3 ) {
588+ mul_or_zero (x_2_p0, x_2_p1, 2 );
589+ }
590+ if constexpr (kNumQueries >= 4 ) {
591+ mul_or_zero (x_3_p0, x_3_p1, 3 );
592+ }
593+ }
594+
595+ template <class DF , class VF = hn::Vec<DF>>
596+ HWY_NOINLINE VF CallExp (DF df, VF x_p0) {
597+ return hn::Exp (df, x_p0);
598+ }
599+ template <int kNumQueries , class DF , class VF = hn::Vec<DF>>
600+ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap8 (
601+ DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
602+ VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
603+ VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
604+ VF& x_7_p0, VF& x_7_p1, float * HWY_RESTRICT old_max,
605+ float * HWY_RESTRICT old_d, float * HWY_RESTRICT scales) {
606+ using DF8 = hn::CappedTag<float , 8 >;
607+ const DF8 df8;
608+ using VF8 = hn::Vec<DF8>;
609+ static_assert (kNumQueries >= 1 && kNumQueries <= 8 );
610+ VF8 new_max = hn::Set (df8, kNegInf );
611+ VF max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7 = hn::Zero (df);
612+ max_0 = hn::Max (x_0_p0, x_0_p1);
613+ if constexpr (kNumQueries >= 2 ) {
614+ max_1 = hn::Max (x_1_p0, x_1_p1);
615+ }
616+ if constexpr (kNumQueries >= 3 ) {
617+ max_2 = hn::Max (x_2_p0, x_2_p1);
618+ }
619+ if constexpr (kNumQueries >= 4 ) {
620+ max_3 = hn::Max (x_3_p0, x_3_p1);
621+ }
622+ if constexpr (kNumQueries >= 5 ) {
623+ max_4 = hn::Max (x_4_p0, x_4_p1);
624+ }
625+ if constexpr (kNumQueries >= 6 ) {
626+ max_5 = hn::Max (x_5_p0, x_5_p1);
627+ }
628+ if constexpr (kNumQueries >= 7 ) {
629+ max_6 = hn::Max (x_6_p0, x_6_p1);
630+ }
631+ if constexpr (kNumQueries >= 8 ) {
632+ max_7 = hn::Max (x_7_p0, x_7_p1);
633+ }
634+
635+ if constexpr (kNumQueries == 1 ) {
636+ new_max = hn::InsertLane (new_max, 0 , hn::ReduceMax (df, max_0));
557637 } else {
558- x_0_p0 = zero;
559- x_0_p1 = zero;
638+ new_max =
639+ Reduce8 (df, max_0, max_1, max_2, max_3, max_4, max_5, max_6, max_7,
640+ [](auto a, auto b) HWY_ATTR { return hn::Max (a, b); });
641+ }
642+ if (att_cap > 0 .0f ) {
643+ VF8 cap = hn::Set (df8, att_cap);
644+ VF8 one_over_cap = hn::Set (df8, one_over_att_cap);
645+ new_max = hn::Mul (cap, hn::Tanh (df8, hn::Mul (new_max, one_over_cap)));
646+ }
647+ VF8 old_max_vf = hn::Set (df8, kNegInf );
648+ old_max_vf = hn::LoadU (df8, old_max);
649+ new_max = hn::Max (new_max, old_max_vf);
650+ auto changed_max = hn::Gt (new_max, hn::Set (df8, kNegInf ));
651+ hn::StoreU (new_max, df8, old_max);
652+ if constexpr (kNumQueries >= 1 ) {
653+ const VF new_max_0 = hn::Set (df, old_max[0 ]);
654+ x_0_p0 = hn::CallExp (df, hn::Sub (x_0_p0, new_max_0));
655+ x_0_p1 = hn::CallExp (df, hn::Sub (x_0_p1, new_max_0));
560656 }
561657 if constexpr (kNumQueries >= 2 ) {
562- if (hn::ExtractLane (old_d_vf, 1 ) > 0 .0f && scales[1 ] != 1 .0f ) {
563- const VF one_over_d_1 = hn::Set (df, tmp_one_over_d[1 ]);
564- x_1_p0 = hn::Mul (x_1_p0, one_over_d_1);
565- x_1_p1 = hn::Mul (x_1_p1, one_over_d_1);
566- } else {
567- x_1_p0 = zero;
568- x_1_p1 = zero;
569- }
658+ const VF new_max_0 = hn::Set (df, old_max[1 ]);
659+ x_1_p0 = hn::CallExp (df, hn::Sub (x_1_p0, new_max_0));
660+ x_1_p1 = hn::CallExp (df, hn::Sub (x_1_p1, new_max_0));
570661 }
571662 if constexpr (kNumQueries >= 3 ) {
572- if (hn::ExtractLane (old_d_vf, 2 ) > 0 .0f && scales[2 ] != 1 .0f ) {
573- const VF one_over_d_2 = hn::Set (df, tmp_one_over_d[2 ]);
574- x_2_p0 = hn::Mul (x_2_p0, one_over_d_2);
575- x_2_p1 = hn::Mul (x_2_p1, one_over_d_2);
576- } else {
577- x_2_p0 = zero;
578- x_2_p1 = zero;
579- }
663+ const VF new_max_0 = hn::Set (df, old_max[2 ]);
664+ x_2_p0 = hn::CallExp (df, hn::Sub (x_2_p0, new_max_0));
665+ x_2_p1 = hn::CallExp (df, hn::Sub (x_2_p1, new_max_0));
580666 }
581667 if constexpr (kNumQueries >= 4 ) {
582- if (hn::ExtractLane (old_d_vf, 3 ) > 0 .0f && scales[3 ] != 1 .0f ) {
583- const VF one_over_d_3 = hn::Set (df, tmp_one_over_d[3 ]);
584- x_3_p0 = hn::Mul (x_3_p0, one_over_d_3);
585- x_3_p1 = hn::Mul (x_3_p1, one_over_d_3);
668+ const VF new_max_0 = hn::Set (df, old_max[3 ]);
669+ x_3_p0 = hn::CallExp (df, hn::Sub (x_3_p0, new_max_0));
670+ x_3_p1 = hn::CallExp (df, hn::Sub (x_3_p1, new_max_0));
671+ }
672+ if constexpr (kNumQueries >= 5 ) {
673+ const VF new_max_0 = hn::Set (df, old_max[4 ]);
674+ x_4_p0 = hn::CallExp (df, hn::Sub (x_4_p0, new_max_0));
675+ x_4_p1 = hn::CallExp (df, hn::Sub (x_4_p1, new_max_0));
676+ }
677+ if constexpr (kNumQueries >= 6 ) {
678+ const VF new_max_0 = hn::Set (df, old_max[5 ]);
679+ x_5_p0 = hn::CallExp (df, hn::Sub (x_5_p0, new_max_0));
680+ x_5_p1 = hn::CallExp (df, hn::Sub (x_5_p1, new_max_0));
681+ }
682+ if constexpr (kNumQueries >= 7 ) {
683+ const VF new_max_0 = hn::Set (df, old_max[6 ]);
684+ x_6_p0 = hn::CallExp (df, hn::Sub (x_6_p0, new_max_0));
685+ x_6_p1 = hn::CallExp (df, hn::Sub (x_6_p1, new_max_0));
686+ }
687+ if constexpr (kNumQueries >= 8 ) {
688+ const VF new_max_0 = hn::Set (df, old_max[7 ]);
689+ x_7_p0 = hn::CallExp (df, hn::Sub (x_7_p0, new_max_0));
690+ x_7_p1 = hn::CallExp (df, hn::Sub (x_7_p1, new_max_0));
691+ }
692+ VF8 old_d_vf = hn::Set (df8, 0 .0f );
693+ old_d_vf = hn::LoadU (df8, old_d);
694+
695+ VF8 x_sum = hn::Zero (df8);
696+ if constexpr (kNumQueries == 1 ) {
697+ x_sum = hn::Set (df8, hn::ReduceSum (df, x_0_p0) + hn::ReduceSum (df, x_0_p1));
698+ } else {
699+ VF x_0_sum = hn::Add (x_0_p0, x_0_p1);
700+ VF x_1_sum = hn::Add (x_1_p0, x_1_p1);
701+ VF x_2_sum = hn::Add (x_2_p0, x_2_p1);
702+ VF x_3_sum = hn::Add (x_3_p0, x_3_p1);
703+ VF x_4_sum = hn::Add (x_4_p0, x_4_p1);
704+ VF x_5_sum = hn::Add (x_5_p0, x_5_p1);
705+ VF x_6_sum = hn::Add (x_6_p0, x_6_p1);
706+ VF x_7_sum = hn::Add (x_7_p0, x_7_p1);
707+ x_sum = Reduce8 (df, x_0_sum, x_1_sum, x_2_sum, x_3_sum, x_4_sum, x_5_sum,
708+ x_6_sum, x_7_sum,
709+ [](auto a, auto b) HWY_ATTR { return hn::Add (a, b); });
710+ }
711+ VF8 scale = hn::Mul (old_d_vf, hn::Exp (df8, hn::Sub (old_max_vf, new_max)));
712+ old_d_vf = hn::Add (scale, x_sum);
713+ auto non_zero_mask = hn::Gt (old_d_vf, hn::Set (df8, 0 .0f ));
714+ const VF zero = hn::Zero (df);
715+ const VF8 zero8 = hn::Zero (df8);
716+ const VF8 one_over_d =
717+ hn::MaskedDivOr (zero8, non_zero_mask, hn::Set (df8, 1 .0f ), old_d_vf);
718+ HWY_ALIGN float tmp_one_over_d[8 ];
719+ hn::Store (one_over_d, df8, tmp_one_over_d);
720+ hn::BlendedStore (old_d_vf, changed_max, df8, old_d);
721+ scale = hn::Mul (scale, one_over_d);
722+ hn::BlendedStore (scale, changed_max, df8, scales);
723+ auto mul_or_zero = [&](VF& x_p0, VF& x_p1, int i) HWY_ATTR {
724+ if (HWY_LIKELY (old_d[i] > 0 .0f && scales[i] != 1 .0f )) {
725+ const VF one_over_d_i = hn::Set (df, tmp_one_over_d[i]);
726+ x_p0 = hn::Mul (x_p0, one_over_d_i);
727+ x_p1 = hn::Mul (x_p1, one_over_d_i);
586728 } else {
587- x_3_p0 = zero;
588- x_3_p1 = zero;
729+ x_p0 = zero;
730+ x_p1 = zero;
589731 }
732+ };
733+ mul_or_zero (x_0_p0, x_0_p1, 0 );
734+ if constexpr (kNumQueries >= 2 ) {
735+ mul_or_zero (x_1_p0, x_1_p1, 1 );
736+ }
737+ if constexpr (kNumQueries >= 3 ) {
738+ mul_or_zero (x_2_p0, x_2_p1, 2 );
739+ }
740+ if constexpr (kNumQueries >= 4 ) {
741+ mul_or_zero (x_3_p0, x_3_p1, 3 );
742+ }
743+ if constexpr (kNumQueries >= 5 ) {
744+ mul_or_zero (x_4_p0, x_4_p1, 4 );
745+ }
746+ if constexpr (kNumQueries >= 6 ) {
747+ mul_or_zero (x_5_p0, x_5_p1, 5 );
748+ }
749+ if constexpr (kNumQueries >= 7 ) {
750+ mul_or_zero (x_6_p0, x_6_p1, 6 );
751+ }
752+ if constexpr (kNumQueries >= 8 ) {
753+ mul_or_zero (x_7_p0, x_7_p1, 7 );
754+ }
755+ }
756+ template <int kNumQueries , class DF , class VF = hn::Vec<DF>>
757+ static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap (
758+ DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
759+ VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
760+ VF& x_4_p0, VF& x_4_p1, VF& x_5_p0, VF& x_5_p1, VF& x_6_p0, VF& x_6_p1,
761+ VF& x_7_p0, VF& x_7_p1, float * HWY_RESTRICT old_max,
762+ float * HWY_RESTRICT old_d, float * HWY_RESTRICT scales, size_t q_group_idx,
763+ size_t kNumQueriesPerGroup ) {
764+ constexpr int kFirstHalfAmountOfQueries = std::min (kNumQueries , 4 );
765+ [[maybe_unused]] constexpr int kSecondHalfAmountOfQueries =
766+ kNumQueries - kFirstHalfAmountOfQueries ;
767+ if constexpr (kNumQueries <= 4 ) {
768+ FlashAttentionTileStepAndApplySoftCap4<kFirstHalfAmountOfQueries >(
769+ df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
770+ x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup ,
771+ old_d + (q_group_idx)*kNumQueriesPerGroup , scales);
772+ } else {
773+ #if HWY_MAX_BYTES <= 16
774+ FlashAttentionTileStepAndApplySoftCap4<4 >(
775+ df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
776+ x_2_p1, x_3_p0, x_3_p1, old_max + (q_group_idx)*kNumQueriesPerGroup ,
777+ old_d + (q_group_idx)*kNumQueriesPerGroup , scales);
778+ FlashAttentionTileStepAndApplySoftCap4<kSecondHalfAmountOfQueries >(
779+ df, att_cap, one_over_att_cap, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0,
780+ x_6_p1, x_7_p0, x_7_p1,
781+ old_max + (q_group_idx + 1 ) * kNumQueriesPerGroup ,
782+ old_d + (q_group_idx + 1 ) * kNumQueriesPerGroup ,
783+ scales + kNumQueriesPerGroup );
784+ #else
785+ FlashAttentionTileStepAndApplySoftCap8<kNumQueries >(
786+ df, att_cap, one_over_att_cap, x_0_p0, x_0_p1, x_1_p0, x_1_p1, x_2_p0,
787+ x_2_p1, x_3_p0, x_3_p1, x_4_p0, x_4_p1, x_5_p0, x_5_p1, x_6_p0, x_6_p1,
788+ x_7_p0, x_7_p1, old_max + (q_group_idx)*kNumQueriesPerGroup ,
789+ old_d + (q_group_idx)*kNumQueriesPerGroup , scales);
790+ #endif
590791 }
591792}
592793
0 commit comments