Skip to content

Commit c76ca2d

Browse files
theraysmith and copybara-github
authored and committed
Rewrote flash attention to use BF16, transposed k and v, rewrote the task distribution, increased parallelism on decode, and used double the registers for the core of flash attention.
PiperOrigin-RevId: 868146247
1 parent 7e5310b commit c76ca2d

27 files changed

+6645
-1190
lines changed

BUILD.bazel

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ cc_library(
547547
deps = [
548548
":basics",
549549
":configs",
550+
":flash_structs",
550551
":gemma_args",
551552
":kv_cache",
552553
":mat",
@@ -594,6 +595,11 @@ cc_test(
594595

595596
INTERNAL_DEPS = []
596597

598+
cc_library(
599+
name = "flash_structs",
600+
hdrs = ["gemma/flash_structs.h"],
601+
)
602+
597603
cc_library(
598604
name = "attention",
599605
srcs = [
@@ -603,7 +609,6 @@ cc_library(
603609
hdrs = [
604610
"gemma/attention.h",
605611
"gemma/flash_attention.h",
606-
"gemma/flash_structs.h",
607612
],
608613
textual_hdrs = [
609614
"gemma/gemma-inl.h",
@@ -612,6 +617,7 @@ cc_library(
612617
":activations",
613618
":basics",
614619
":configs",
620+
":flash_structs",
615621
":kv_cache",
616622
":mat",
617623
":matmul",
@@ -822,6 +828,38 @@ cc_test(
822828
],
823829
)
824830

831+
cc_test(
832+
name = "wheat_from_chaff_test",
833+
srcs = ["evals/wheat_from_chaff_test.cc"],
834+
data = [
835+
"evals/testdata/google/big_bang_theory.txt",
836+
"evals/testdata/google/black_hole.txt",
837+
"evals/testdata/google/general_relativity.txt",
838+
"evals/testdata/google/qed.txt",
839+
"evals/testdata/holiday_story.txt",
840+
"evals/testdata/quark_1.txt",
841+
"evals/testdata/quark_2.txt",
842+
"evals/testdata/special_relativity.txt",
843+
"evals/testdata/standard_model.txt",
844+
],
845+
linkstatic = True,
846+
# Requires model files
847+
tags = [
848+
"local",
849+
"manual",
850+
"no_tap",
851+
],
852+
deps = [
853+
":benchmark_helper",
854+
":configs",
855+
":gemma_lib",
856+
"@googletest//:gtest_main", # buildcleaner: keep
857+
"//io",
858+
"@highway//:abort_header_only",
859+
"@highway//:hwy_test_util",
860+
],
861+
)
862+
825863
cc_binary(
826864
name = "gemma",
827865
srcs = ["gemma/run.cc"],

evals/benchmark_helper.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,11 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
150150

151151
QueryResult GemmaEnv::QueryModel(const std::string& input) {
152152
const std::vector<int> prompt = WrapAndTokenize(input);
153-
return QueryModel(prompt);
153+
auto result = QueryModel(prompt);
154+
fprintf(stderr, "prompt size: %zu, response size: %zu, total tokens: %zu\n",
155+
prompt.size(), result.tokens_generated - prompt.size(),
156+
result.tokens_generated);
157+
return result;
154158
}
155159

156160
QueryResultAndMetrics GemmaEnv::BatchQueryModelWithMetrics(

evals/benchmark_helper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class GemmaEnv {
6262
static_cast<size_t>(max_generated_tokens);
6363
}
6464

65+
void PrintProfileResults() { ctx_.profiler.PrintResults(); }
66+
6567
std::vector<int> Tokenize(const std::string& input) const {
6668
std::vector<int> tokens;
6769
HWY_ASSERT(gemma_.Tokenizer().Encode(input, &tokens));

evals/gemma_batch_bench.cc

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ GemmaEnv* s_env = nullptr;
3737
class GemmaBatchBench : public ::testing::Test {
3838
protected:
3939
std::vector<std::string> BatchGemmaReply(
40-
const std::vector<std::string>& inputs) {
40+
const std::vector<std::string>& inputs, AttentionImpl attention_impl) {
41+
s_env->MutableConfig().attention_impl = attention_impl;
4142
s_env->MutableConfig().temperature = 0.0f; // deterministic
4243
s_env->MutableConfig().verbosity = 2;
4344
std::vector<std::string> replies;
@@ -128,16 +129,19 @@ std::vector<std::string> GenerateInputs() {
128129
TEST_F(GemmaBatchBench, RandomQuestionsBatched) {
129130
s_env->SetMaxGeneratedTokens(12);
130131
const std::vector<std::string> inputs = GenerateInputs();
131-
132-
// Run multiple times so that auto-tuning is closer to complete.
133-
for (size_t rep = 0; rep < 4; ++rep) {
134-
std::vector<std::string> responses = BatchGemmaReply(inputs);
135-
for (size_t i = 0; i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size());
136-
++i) {
137-
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
138-
responses[i].c_str());
132+
const AttentionImpl modes[] = {AttentionImpl::kOld, AttentionImpl::kFlash};
133+
for (const AttentionImpl mode : modes) {
134+
// Run multiple times so that auto-tuning is closer to complete.
135+
fprintf(stderr, "Testing mode %s\n", GetAttentionImplName(mode).c_str());
136+
for (size_t rep = 0; rep < 4; ++rep) {
137+
std::vector<std::string> responses = BatchGemmaReply(inputs, mode);
138+
for (size_t i = 0;
139+
i < HWY_MIN(hwy::Unpredictable1() * 3, responses.size()); ++i) {
140+
fprintf(stderr, "Rep %zu batch answer %zu '%s'\n\n", rep, i,
141+
responses[i].c_str());
142+
}
143+
PROFILER_PRINT_RESULTS();
139144
}
140-
PROFILER_PRINT_RESULTS();
141145
}
142146
}
143147

evals/testdata/holiday_story.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Albert and Marcia were on holiday. Their parents had brought them to the beach.
2+
Albert was generally unimpressed with beaches, as he would rather explore a dark forest and see the variety of mosses and fungi that grow in the damp conditions.
3+
On the other hand, Marcia loved to build enormous sand castles.
4+
Albert enjoyed collecting limpet shells to decorate the outer walls of the turrets, which he secretly thought made them look like daleks.
5+
Whilst digging sand for building, Marcia always liked to dig deep, to see if she could get to water coming through the sand from the sea.
6+
When the castle was nearly complete, and Marcia needed more sand, she hit a large piece of rusty metal.
7+
Curious as to what it was, Marcia kept digging to try to expose all of it, but it was very big and hard to get at as it was so deep in the sand.
8+
Excited by the prospect of finding something unusual in the sand, Albert joined in to help dig out the entire object.
9+
Almost an hour later, they had exposed most of a ship’s anchor.
10+
During the excavation a crowd of onlookers had formed around them, who then proceeded to take selfies in front of the unusual piece of beach litter.

evals/testdata/quark_1.txt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
Text from https://en.wikipedia.org/wiki/Quark is licensed under Creative Commons Attribution-ShareAlike 4.0 License; (https://en.wikipedia.org/wiki/Wikipedia:Text_of_the_Creative_Commons_Attribution-ShareAlike_4.0_International_License)
2+
3+
Quark
4+
From Wikipedia, the free encyclopedia
5+
(Redirected from Quarks)
6+
This article is about the elementary particle and its antiparticle. For other uses, see Quark (disambiguation).
7+
8+
A quark (/ˈkwɔːrk, ˈkwɑːrk/ ⓘ) is a type of elementary particle and a fundamental constituent of matter. Quarks combine to form composite particles called hadrons, the most stable of which are protons and neutrons, the components of atomic nuclei.[1] All commonly observable matter is composed of up quarks, down quarks and electrons. Owing to a phenomenon known as color confinement, quarks are never found in isolation; they can be found only within hadrons, which include baryons (such as protons and neutrons) and mesons, or in quark–gluon plasmas.[2][3][nb 1] For this reason, much of what is known about quarks has been drawn from observations of hadrons.
9+
10+
Quarks have various intrinsic properties, including electric charge, mass, color charge, and spin. They are the only elementary particles in the Standard Model of particle physics to experience all four fundamental interactions, also known as fundamental forces (electromagnetism, gravitation, strong interaction, and weak interaction), as well as the only known particles whose electric charges are not integer multiples of the elementary charge.
11+
12+
There are six types, known as flavors, of quarks: up, down, charm, strange, top, and bottom.[4] Up and down quarks have the lowest masses of all quarks. The heavier quarks rapidly change into up and down quarks through a process of particle decay: the transformation from a higher mass state to a lower mass state. Because of this, up and down quarks are generally stable and the most common in the universe, whereas strange, charm, bottom, and top quarks can only be produced in high energy collisions (such as those involving cosmic rays and in particle accelerators). For every quark flavor there is a corresponding type of antiparticle, known as an antiquark, that differs from the quark only in that some of its properties (such as the electric charge) have equal magnitude but opposite sign.
13+
14+
The quark model was independently proposed by physicists Murray Gell-Mann and George Zweig in 1964.[5] Quarks were introduced as parts of an ordering scheme for hadrons, and there was little evidence for their physical existence until deep inelastic scattering experiments at the Stanford Linear Accelerator Center in 1968.[6][7] Accelerator program experiments have provided evidence for all six flavors. The top quark, first observed at Fermilab in 1995, was the last to be discovered.[5]
15+
16+
Classification
17+
See also: Standard Model
18+
A four-by-four table of particles. Columns are three generations of matter (fermions) and one of forces (bosons). In the first three columns, two rows contain quarks and two leptons. The top two rows' columns contain up (u) and down (d) quarks, charm (c) and strange (s) quarks, top (t) and bottom (b) quarks, and photon (γ) and gluon (g), respectively. The bottom two rows' columns contain electron neutrino (ν sub e) and electron (e), muon neutrino (ν sub μ) and muon (μ), and tau neutrino (ν sub τ) and tau (τ), and Z sup 0 and W sup ± weak force. Mass, charge, and spin are listed for each particle.
19+
Six of the particles in the Standard Model are quarks (shown in purple). Each of the first three columns forms a generation of matter.
20+
The Standard Model is the theoretical framework describing all the known elementary particles. This model contains six flavors of quarks (q), named up (u), down (d), strange (s), charm (c), bottom (b), and top (t).[4] Antiparticles of quarks are called antiquarks, and are denoted by a bar over the symbol for the corresponding quark, such as ū for an up antiquark. As with antimatter in general, antiquarks have the same mass, mean lifetime, and spin as their respective quarks, but the electric charge and other charges have the opposite sign.[8]
21+

0 commit comments

Comments
 (0)