Commit a59a26f

Fixes for the LLaMA backbone + add dropout (#1499)
* Fixes for the LLaMA backbone + add dropout
* Address review comments: rename CachedLlamaAttention -> LlamaAttention and make the parameter state public in the attention layer
* Remove self._hidden_dim and self._head_dim
1 parent 5136876 · commit a59a26f
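
For orientation, a minimal usage sketch of the renamed LlamaAttention layer after this commit. The import path mirrors the file edited below and the constructor/call arguments come from the diff; the head counts, tensor sizes, and dropout rate are illustrative values (not from the PR), and running the snippet assumes a keras_nlp checkout at this commit with Keras 3.

import numpy as np
from keras_nlp.models.llama.llama_attention import LlamaAttention

layer = LlamaAttention(
    num_query_heads=8,
    num_key_value_heads=2,  # grouped-query attention: 4 query heads share each key/value head
    rope_max_wavelength=10000,
    rope_scaling_factor=1.0,
    kernel_initializer="glorot_uniform",
    dropout=0.1,  # new argument added by this commit; max_sequence_length is no longer accepted
)

# hidden_states: [batch, seq_len, hidden_dim]; head_dim = hidden_dim // num_query_heads
hidden_states = np.random.uniform(size=(2, 16, 64)).astype("float32")
outputs = layer(hidden_states, training=False)  # -> shape (2, 16, 64)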

File tree

4 files changed: +182, -122 lines changed

keras_nlp/models/llama/llama_attention.py

Lines changed: 63 additions & 57 deletions
@@ -18,34 +18,33 @@


 class LlamaAttention(keras.layers.Layer):
-    """Grouped query attention for Llama models"""
+    """A cached grounded query attention layer with sliding window."""

     def __init__(
         self,
         num_query_heads,
         num_key_value_heads,
+        rope_max_wavelength=10000,
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
-        rope_max_wavelength=10000,
-        max_sequence_length=512,
+        dropout=0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.num_query_heads = num_query_heads
         self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout

         self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength

-        self.kernel_initializer = keras.initializers.get(kernel_initializer)
-        self.max_sequence_length = max_sequence_length
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )

         self.rope_scaling_factor = rope_scaling_factor
-        self.rope_max_wavelength = rope_max_wavelength

     def build(self, inputs_shape):
-        self.hidden_dim = inputs_shape[-1]
-        self.attn_head_size = self.hidden_dim // self.num_query_heads
-
         # Einsum variables:
         # b = batch size
         # q = query length
@@ -54,27 +53,40 @@ def build(self, inputs_shape):
         # u = num query heads
         # v = num key/value heads
         # h = head dim
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
+
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self.num_query_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(None, self.num_query_heads, head_dim),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
         )
         self._query_dense.build(inputs_shape)
+
         self._key_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
-            output_shape=(None, self.num_key_value_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="key",
         )
         self._key_dense.build(inputs_shape)

         self._value_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
-            output_shape=(None, self.num_key_value_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="value",
         )
@@ -86,21 +98,28 @@ def build(self, inputs_shape):
             name="attention_softmax",
         )

+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
         self._output_dense = keras.layers.EinsumDense(
-            equation="bqm,mh->bqh",
-            output_shape=(None, self.hidden_dim),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(inputs_shape)
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))

-        self._rotary_embedding_layer = RotaryEmbedding(
+        self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
             scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
-        self._rotary_embedding_layer.build(inputs_shape)
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"

         self.built = True

@@ -110,6 +129,7 @@ def call(
         attention_mask=None,
         cache=None,
         cache_update_index=None,
+        training=None,
     ):
         query = self._query_dense(hidden_states)

@@ -136,75 +156,61 @@
             key = self._key_dense(hidden_states)
             value = self._value_dense(hidden_states)

-        query = self._rotary_embedding_layer(query)
-        key = self._rotary_embedding_layer(key)
+        query = self.rotary_embedding_layer(query)
+        key = self.rotary_embedding_layer(key)

-        key = ops.tile(key, [1, 1, self.num_key_value_groups, 1])
-        value = ops.tile(value, [1, 1, self.num_key_value_groups, 1])
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

-        attention_output, attention_scores = self._compute_attention(
+        attention_output = self._compute_attention(
             query, key, value, attention_mask
         )

-        attention_output_shape = ops.shape(attention_output)
-
-        attention_output = ops.reshape(
-            attention_output,
-            [
-                attention_output_shape[0],
-                attention_output_shape[1],
-                self.hidden_dim,
-            ],
+        attention_output = self._dropout_layer(
+            attention_output, training=training
         )

         attention_output = self._output_dense(attention_output)

         if cache is not None:
-            return (attention_output, cache)
+            return attention_output, cache
         return attention_output

     def _masked_softmax(self, attention_scores, attention_mask=None):
         if attention_mask is not None:
-            mask_expansion_axis = -3
-            for _ in range(
-                len(attention_scores.shape) - len(attention_mask.shape)
-            ):
-                attention_mask = ops.expand_dims(
-                    attention_mask, axis=mask_expansion_axis
-                )
-            return self._softmax(attention_scores, attention_mask)
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        attention_scores = ops.einsum("aecd,abcd->acbe", key, query)
-
-        norm_factor = ops.sqrt(
-            ops.convert_to_tensor(self.attn_head_size, self.compute_dtype)
-        )
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)

-        attention_scores /= norm_factor
+        attention_scores = attention_scores / self._norm_factor
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
         attention_scores = ops.cast(attention_scores, self.compute_dtype)
         attention_output = ops.einsum(
-            "acbe,aecd->abcd", attention_scores, value
+            self._combine_equation, attention_scores, value
         )

-        return attention_output, attention_scores
+        return attention_output

     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "num_query_heads": self.num_query_heads,
-                "hidden_dim": self.hidden_dim,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),
-                "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
-                "num_key_value_heads": self.num_key_value_heads,
-                "max_sequence_length": self.max_sequence_length,
+                "dropout": self.dropout,
             }
         )
         return config
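
The heart of the rewrite is the grouped-query expansion via ops.repeat plus the two einsum equations now stored as _dot_product_equation and _combine_equation. Below is a standalone sketch of that attention path using keras.ops (Keras 3) with toy shapes; the dimensions are illustrative, and the attention mask, rotary embedding, and dropout steps are omitted.

import numpy as np
from keras import ops

num_query_heads, num_key_value_heads, head_dim = 8, 2, 16
num_key_value_groups = num_query_heads // num_key_value_heads  # 4 query heads per KV head

shape_q = (2, 4, num_query_heads, head_dim)       # [batch, seq_len, num_query_heads, head_dim]
shape_kv = (2, 4, num_key_value_heads, head_dim)  # [batch, seq_len, num_key_value_heads, head_dim]
query = ops.convert_to_tensor(np.random.uniform(size=shape_q).astype("float32"))
key = ops.convert_to_tensor(np.random.uniform(size=shape_kv).astype("float32"))
value = ops.convert_to_tensor(np.random.uniform(size=shape_kv).astype("float32"))

# Expand key/value heads to match the query heads, as the diff now does with ops.repeat.
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=num_key_value_groups, axis=2)

# _dot_product_equation: scores laid out as [batch, num_heads, query_len, key_len].
scores = ops.einsum("bquh,bkuh->buqk", query, key)
scores = scores / ops.sqrt(ops.cast(head_dim, "float32"))
probs = ops.softmax(scores, axis=-1)

# _combine_equation: back to [batch, query_len, num_heads, head_dim].
attention_output = ops.einsum("buqk,bkuh->bquh", probs, value)
print(attention_output.shape)  # (2, 4, 8, 16)

With scores laid out as [batch, num_heads, query_len, key_len], the new _masked_softmax can broadcast a [batch, query_len, key_len] mask across the head axis with attention_mask[:, None, :, :], replacing the old expand_dims loop.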
