
Commit 4ba6dd2

Address review comments
CachedLlamaAttention -> LlamaAttention and make parameter state public in the attention layer
1 parent 33c9227 commit 4ba6dd2
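In practice, the rename and the visibility change look roughly like this from downstream code (an illustrative sketch, not part of the commit; the argument values are made up and anything omitted is assumed to have a default):

from keras_nlp.models.llama.llama_attention import LlamaAttention  # formerly CachedLlamaAttention

layer = LlamaAttention(num_query_heads=32, num_key_value_heads=8)

# Constructor state is now stored under public names, e.g.
layer.num_query_heads       # previously layer._num_query_heads
layer.num_key_value_groups  # 32 // 8 == 4, previously layer._num_key_value_groups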

File tree: 2 files changed (+30, −30)
keras_nlp/models/llama/llama_attention.py
keras_nlp/models/llama/llama_decoder.py

keras_nlp/models/llama/llama_attention.py

Lines changed: 28 additions & 28 deletions
@@ -17,7 +17,7 @@
 from keras_nlp.utils.keras_utils import clone_initializer
 
 
-class CachedLlamaAttention(keras.layers.Layer):
+class LlamaAttention(keras.layers.Layer):
     """A cached grounded query attention layer with sliding window."""
 
     def __init__(
@@ -31,18 +31,18 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self._num_query_heads = num_query_heads
-        self._num_key_value_heads = num_key_value_heads
-        self._dropout = dropout
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout
 
-        self._num_key_value_groups = num_query_heads // num_key_value_heads
-        self._rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
 
-        self._kernel_initializer = keras.initializers.get(
+        self.kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
 
-        self._rope_scaling_factor = rope_scaling_factor
+        self.rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):
         # Einsum variables:
@@ -54,12 +54,12 @@ def build(self, inputs_shape):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self._num_query_heads
+        self._head_dim = self._hidden_dim // self.num_query_heads
 
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self._num_query_heads, self._head_dim),
-            kernel_initializer=self._kernel_initializer,
+            output_shape=(None, self.num_query_heads, self._head_dim),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
         )
@@ -69,10 +69,10 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
-            kernel_initializer=self._kernel_initializer,
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="key",
         )
@@ -82,10 +82,10 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
-            kernel_initializer=self._kernel_initializer,
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="value",
         )
@@ -98,24 +98,24 @@ def build(self, inputs_shape):
         )
 
         self._dropout_layer = keras.layers.Dropout(
-            rate=self._dropout,
+            rate=self.dropout,
             dtype=self.dtype_policy,
         )
 
         self._output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
-            kernel_initializer=self._kernel_initializer,
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
         self._output_dense.build(
-            (None, None, self._num_query_heads, self._head_dim)
+            (None, None, self.num_query_heads, self._head_dim)
         )
 
         self.rotary_embedding_layer = RotaryEmbedding(
-            max_wavelength=self._rope_max_wavelength,
-            scaling_factor=self._rope_scaling_factor,
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
 
@@ -162,8 +162,8 @@ def call(
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
         # -> [batch_shape, seq_len, num_heads, head_dim]
-        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
-        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
 
         attention_output = self._compute_attention(
             query, key, value, attention_mask
@@ -206,14 +206,14 @@ def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "num_query_heads": self._num_query_heads,
-                "num_key_value_heads": self._num_key_value_heads,
-                "rope_max_wavelength": self._rope_max_wavelength,
-                "rope_scaling_factor": self._rope_scaling_factor,
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
-                    self._kernel_initializer
+                    self.kernel_initializer
                 ),
-                "dropout": self._dropout,
+                "dropout": self.dropout,
             }
         )
         return config
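For context on the ops.repeat lines touched above: the layer implements grouped-query attention, where num_key_value_heads key/value heads are shared across num_query_heads query heads, so key and value are tiled by num_key_value_groups along the head axis before attention. Below is a standalone sketch of that reshaping, using NumPy in place of keras.ops (illustrative only, not taken from the commit):

import numpy as np

batch, seq_len, head_dim = 2, 16, 64
num_query_heads, num_key_value_heads = 32, 8
num_key_value_groups = num_query_heads // num_key_value_heads  # 4

# [batch, seq_len, num_key_value_heads, head_dim]
key = np.zeros((batch, seq_len, num_key_value_heads, head_dim))

# Tile each key/value head so every group of query heads sees a copy.
key = np.repeat(key, repeats=num_key_value_groups, axis=2)

# -> [batch, seq_len, num_query_heads, head_dim]
assert key.shape == (batch, seq_len, num_query_heads, head_dim)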

keras_nlp/models/llama/llama_decoder.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from keras_nlp.layers.modeling.transformer_layer_utils import (
     merge_padding_and_attention_mask,
 )
-from keras_nlp.models.llama.llama_attention import CachedLlamaAttention
+from keras_nlp.models.llama.llama_attention import LlamaAttention
 from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm
 from keras_nlp.utils.keras_utils import clone_initializer
 
@@ -61,7 +61,7 @@ def build(self, decoder_sequence_shape):
         self.hidden_dim = decoder_sequence_shape[-1]
 
         # Self attention layer.
-        self._self_attention_layer = CachedLlamaAttention(
+        self._self_attention_layer = LlamaAttention(
             num_query_heads=self.num_query_heads,
             num_key_value_heads=self.num_key_value_heads,
             rope_max_wavelength=self.rope_max_wavelength,

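Because get_config (in the attention diff above) keeps the same keys while reading from the now-public attributes, serialization should be unaffected by the rename. A hedged round-trip sketch, assuming the constructor signature matches the config keys:

from keras_nlp.models.llama.llama_attention import LlamaAttention

layer = LlamaAttention(
    num_query_heads=32,
    num_key_value_heads=8,
    rope_max_wavelength=10000,
    rope_scaling_factor=1.0,
    dropout=0.0,
)

config = layer.get_config()
# Config keys are unchanged; they are now backed by public attributes.
assert config["num_query_heads"] == layer.num_query_heads
assert config["dropout"] == layer.dropout

restored = LlamaAttention.from_config(config)
assert restored.num_key_value_heads == 8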