 class LlamaAttention(keras.layers.Layer):
-    """Grouped query attention for Llama models"""
+    """A cached grouped query attention layer for Llama models."""

     def __init__(
         self,
         num_query_heads,
         num_key_value_heads,
+        rope_max_wavelength=10000,
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
-        rope_max_wavelength=10000,
-        max_sequence_length=512,
+        dropout=0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.num_query_heads = num_query_heads
         self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout

         self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength

-        self.kernel_initializer = keras.initializers.get(kernel_initializer)
-        self.max_sequence_length = max_sequence_length
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )

         self.rope_scaling_factor = rope_scaling_factor
-        self.rope_max_wavelength = rope_max_wavelength

     def build(self, inputs_shape):
-        self.hidden_dim = inputs_shape[-1]
-        self.attn_head_size = self.hidden_dim // self.num_query_heads
-
         # Einsum variables:
         # b = batch size
         # q = query length
@@ -54,27 +53,40 @@ def build(self, inputs_shape):
         # u = num query heads
         # v = num key/value heads
         # h = head dim
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
+
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self.num_query_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(None, self.num_query_heads, head_dim),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
         )
         self._query_dense.build(inputs_shape)
+
         self._key_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
-            output_shape=(None, self.num_key_value_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="key",
         )
         self._key_dense.build(inputs_shape)

         self._value_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
-            output_shape=(None, self.num_key_value_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="value",
         )
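A quick shape check of the projection equations used above, sketched with NumPy under assumed sizes (batch 2, sequence 5, hidden_dim 16, 4 query heads, 2 key/value heads); none of these numbers come from the diff:

```python
import numpy as np

batch, seq_len, hidden_dim = 2, 5, 16
num_query_heads, num_key_value_heads = 4, 2
head_dim = hidden_dim // num_query_heads  # 4

x = np.random.rand(batch, seq_len, hidden_dim)
w_q = np.random.rand(hidden_dim, num_query_heads, head_dim)
w_kv = np.random.rand(hidden_dim, num_key_value_heads, head_dim)

# Query projection "bqm,muh->bquh": (b, q, m) -> (b, q, u, h).
q = np.einsum("bqm,muh->bquh", x, w_q)
assert q.shape == (batch, seq_len, num_query_heads, head_dim)

# Key/value projection "bkm,mvh->bkvh": (b, k, m) -> (b, k, v, h).
k = np.einsum("bkm,mvh->bkvh", x, w_kv)
assert k.shape == (batch, seq_len, num_key_value_heads, head_dim)
```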
@@ -86,21 +98,28 @@ def build(self, inputs_shape):
             name="attention_softmax",
         )

+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
         self._output_dense = keras.layers.EinsumDense(
-            equation="bqm,mh->bqh",
-            output_shape=(None, self.hidden_dim),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(inputs_shape)
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))

-        self._rotary_embedding_layer = RotaryEmbedding(
+        self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
             scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
-        self._rotary_embedding_layer.build(inputs_shape)
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"

         self.built = True
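The two equation strings registered above are what `_compute_attention` feeds to `ops.einsum` further down in this diff. A NumPy sketch of the score and combine steps with assumed shapes, the softmax written out by hand in place of the layer's masked `keras.layers.Softmax`:

```python
import numpy as np

batch, q_len, k_len, num_heads, head_dim = 2, 5, 5, 4, 8
query = np.random.rand(batch, q_len, num_heads, head_dim)
key = np.random.rand(batch, k_len, num_heads, head_dim)
value = np.random.rand(batch, k_len, num_heads, head_dim)

# _dot_product_equation "bquh,bkuh->buqk": scores are (b, heads, q_len, k_len),
# scaled by sqrt(head_dim) as in the layer's _norm_factor.
scores = np.einsum("bquh,bkuh->buqk", query, key) / np.sqrt(head_dim)

# Softmax over the key axis.
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# _combine_equation "buqk,bkuh->bquh": back to (b, q_len, heads, head_dim).
out = np.einsum("buqk,bkuh->bquh", weights, value)
assert out.shape == (batch, q_len, num_heads, head_dim)
```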
@@ -110,6 +129,7 @@ def call(
         attention_mask=None,
         cache=None,
         cache_update_index=None,
+        training=None,
     ):
         query = self._query_dense(hidden_states)
@@ -136,75 +156,61 @@ def call(
             key = self._key_dense(hidden_states)
             value = self._value_dense(hidden_states)

-        query = self._rotary_embedding_layer(query)
-        key = self._rotary_embedding_layer(key)
+        query = self.rotary_embedding_layer(query)
+        key = self.rotary_embedding_layer(key)

-        key = ops.tile(key, [1, 1, self.num_key_value_groups, 1])
-        value = ops.tile(value, [1, 1, self.num_key_value_groups, 1])
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

-        attention_output, attention_scores = self._compute_attention(
+        attention_output = self._compute_attention(
             query, key, value, attention_mask
         )

-        attention_output_shape = ops.shape(attention_output)
-
-        attention_output = ops.reshape(
-            attention_output,
-            [
-                attention_output_shape[0],
-                attention_output_shape[1],
-                self.hidden_dim,
-            ],
+        attention_output = self._dropout_layer(
+            attention_output, training=training
         )

         attention_output = self._output_dense(attention_output)

         if cache is not None:
-            return (attention_output, cache)
+            return attention_output, cache
         return attention_output

     def _masked_softmax(self, attention_scores, attention_mask=None):
         if attention_mask is not None:
-            mask_expansion_axis = -3
-            for _ in range(
-                len(attention_scores.shape) - len(attention_mask.shape)
-            ):
-                attention_mask = ops.expand_dims(
-                    attention_mask, axis=mask_expansion_axis
-                )
-            return self._softmax(attention_scores, attention_mask)
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum("aecd,abcd->acbe", key, query)
-
-        norm_factor = ops.sqrt(
-            ops.convert_to_tensor(self.attn_head_size, self.compute_dtype)
-        )
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)

-        attention_scores /= norm_factor
+        attention_scores = attention_scores / self._norm_factor
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
         attention_scores = ops.cast(attention_scores, self.compute_dtype)
         attention_output = ops.einsum(
-            "acbe,aecd->abcd", attention_scores, value
+            self._combine_equation, attention_scores, value
         )

-        return attention_output, attention_scores
+        return attention_output

     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "num_query_heads": self.num_query_heads,
-                "hidden_dim": self.hidden_dim,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),
-                "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
-                "num_key_value_heads": self.num_key_value_heads,
-                "max_sequence_length": self.max_sequence_length,
+                "dropout": self.dropout,
             }
         )
         return config
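One behavioral detail in the `call` hunk worth spelling out: switching from `ops.tile` to `ops.repeat` on axis 2 repeats each key/value head in place ([k0, k0, k1, k1]) instead of tiling whole blocks ([k0, k1, k0, k1]), so each group of consecutive query heads lines up with its shared key/value head. A hedged NumPy sketch with assumed shapes:

```python
import numpy as np

batch, seq_len, num_key_value_heads, head_dim = 1, 3, 2, 4
num_key_value_groups = 2  # e.g. 4 query heads // 2 key/value heads

key = np.random.rand(batch, seq_len, num_key_value_heads, head_dim)

# New code: repeat -> [k0, k0, k1, k1] along the head axis.
repeated = np.repeat(key, repeats=num_key_value_groups, axis=2)
assert repeated.shape == (batch, seq_len, 4, head_dim)
assert np.array_equal(repeated[:, :, 0], repeated[:, :, 1])

# Old code: tile -> [k0, k1, k0, k1].
tiled = np.tile(key, [1, 1, num_key_value_groups, 1])
assert np.array_equal(tiled[:, :, 0], tiled[:, :, 2])
```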