Skip to content

Commit 520d04e

Browse files
Krzysztof Rymski and copybara-github
authored and committed
Implementation of tiled attention with bf16 and circular buffers, which reduces memory requirements by 4x for longer contexts on Gemma models
PiperOrigin-RevId: 864904207
1 parent 463a368 commit 520d04e

15 files changed

+3056
-38
lines changed

BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,12 @@ cc_library(
652652
name = "gemma_lib",
653653
srcs = [
654654
"gemma/gemma.cc",
655+
"gemma/tiled_attention.cc",
655656
"gemma/vit.cc",
656657
],
657658
hdrs = [
658659
"gemma/gemma.h",
660+
"gemma/tiled_attention.h",
659661
"gemma/vit.h",
660662
],
661663
exec_properties = {

CMakeLists.txt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ set(SOURCES
9393
gemma/model_store.h
9494
gemma/tensor_info.cc
9595
gemma/tensor_info.h
96+
gemma/tiled_attention.cc
97+
gemma/tiled_attention.h
9698
gemma/tokenizer.cc
9799
gemma/tokenizer.h
98100
gemma/vit.cc
@@ -171,20 +173,20 @@ install(TARGETS libgemma DESTINATION lib)
171173
if(BUILD_GEMMA_DLL)
172174
add_library(gemma_shared SHARED ${SOURCES})
173175
set_property(TARGET gemma_shared PROPERTY CXX_STANDARD 17)
174-
set_target_properties(gemma_shared PROPERTIES
176+
set_target_properties(gemma_shared PROPERTIES
175177
PREFIX ""
176178
OUTPUT_NAME "gemma"
177179
)
178180
set_property(TARGET gemma_shared PROPERTY POSITION_INDEPENDENT_CODE ON)
179181
target_include_directories(gemma_shared PUBLIC ./)
180-
target_link_libraries(gemma_shared PRIVATE
182+
target_link_libraries(gemma_shared PRIVATE
181183
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy>
182184
$<LINK_LIBRARY:WHOLE_ARCHIVE,hwy_contrib>
183185
$<LINK_LIBRARY:WHOLE_ARCHIVE,sentencepiece-static>
184186
)
185187
target_include_directories(gemma_shared PUBLIC ${sentencepiece_SOURCE_DIR})
186-
target_compile_definitions(gemma_shared
187-
PRIVATE
188+
target_compile_definitions(gemma_shared
189+
PRIVATE
188190
GEMMA_EXPORTS
189191
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
190192
)

gemma/activations.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ struct AttentionActivations {
153153
// Accumulation of attention outputs over heads
154154
MatStorageT<BF16> att_sums;
155155

156+
MatStorageT<float> k_tile_vec;
157+
MatStorageT<float> v_tile_vec;
158+
std::vector<MatStorageT<float>> sub_task_att_out;
159+
std::vector<AlignedFloatVector>
160+
sub_task_exp_denominator_sums;
161+
std::vector<AlignedFloatVector>
162+
sub_task_max_logits;
163+
156164
// Rope
157165
MatStorageT<float> inv_timescale;
158166
MatStorageT<float> inv_timescale_global;
@@ -244,6 +252,16 @@ struct AttentionActivationsPtrs {
244252
// Accumulation of attention outputs over heads, size batch_size x
245253
// model_dim.
246254
MatPtrT<BF16> att_sums;
255+
// Stores intermediate results of computing QKV,
256+
// [qbatch * kv_heads , k_tile_size * qkv_dim]
257+
MatPtrT<float> k_tile_vec;
258+
MatPtrT<float> v_tile_vec;
259+
// Used by TiledFlashAttention to store intermediate results.
260+
std::vector<MatStorageT<float>>* sub_task_att_out;
261+
std::vector<AlignedFloatVector>*
262+
sub_task_exp_denominator_sums;
263+
std::vector<AlignedFloatVector>*
264+
sub_task_max_logits;
247265
// Inverse timescales for RoPE computation.
248266
MatPtrT<float> inv_timescale;
249267
// Inverse timescales for global RoPE computation.

gemma/configs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ static inline bool EnumValid(LayerAttentionType type) {
8383
enum class AttentionImpl {
8484
kOld,
8585
kFlash,
86+
kFlashTransposedQs,
87+
kFlashTransposedQsBF16,
8688
kSentinel,
8789
};
8890

@@ -108,6 +110,8 @@ static inline int AttentionImplToFlags(AttentionImpl impl,
108110
case AttentionImpl::kOld:
109111
return kAttentionUseOld;
110112
case AttentionImpl::kFlash:
113+
case AttentionImpl::kFlashTransposedQs:
114+
case AttentionImpl::kFlashTransposedQsBF16:
111115
default:
112116
return 0;
113117
}

0 commit comments

Comments (0)