
Commit ec5821e

Remove self._hidden_dim and self._head_dim
1 parent 4ba6dd2 · commit ec5821e

File tree

1 file changed: +9 -12 lines changed


keras_nlp/models/llama/llama_attention.py

Lines changed: 9 additions & 12 deletions
@@ -53,12 +53,13 @@ def build(self, inputs_shape):
         # u = num query heads
         # v = num key/value heads
         # h = head dim
-        self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self.num_query_heads
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))

         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self.num_query_heads, self._head_dim),
+            output_shape=(None, self.num_query_heads, head_dim),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
@@ -70,7 +71,7 @@ def build(self, inputs_shape):
             output_shape=(
                 None,
                 self.num_key_value_heads,
-                self._head_dim,
+                head_dim,
             ),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
@@ -83,7 +84,7 @@ def build(self, inputs_shape):
             output_shape=(
                 None,
                 self.num_key_value_heads,
-                self._head_dim,
+                head_dim,
             ),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
@@ -104,14 +105,12 @@ def build(self, inputs_shape):

         self._output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
-            output_shape=(None, self._hidden_dim),
+            output_shape=(None, hidden_dim),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(
-            (None, None, self.num_query_heads, self._head_dim)
-        )
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))

         self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
@@ -189,9 +188,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _compute_attention(self, query, key, value, attention_mask=None):
         attention_scores = ops.einsum(self._dot_product_equation, query, key)

-        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
-
-        attention_scores = attention_scores / norm_factor
+        attention_scores = attention_scores / self._norm_factor
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
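Net effect of the change: hidden_dim and head_dim become plain locals that only live inside build(), and the softmax scaling constant is computed once and cached as self._norm_factor instead of being rebuilt from self._head_dim on every _compute_attention call. A minimal, self-contained sketch of that pattern follows; the layer name ScaledScores and the score einsum equation are illustrative assumptions, not the actual LlamaAttention implementation.

import keras
from keras import ops


class ScaledScores(keras.layers.Layer):
    """Illustrative layer: precompute the attention scaling factor in build()."""

    def __init__(self, num_query_heads, **kwargs):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads

    def build(self, inputs_shape):
        # Plain locals: nothing outside build() needs these values directly.
        hidden_dim = inputs_shape[-1]
        head_dim = hidden_dim // self.num_query_heads
        # Only the derived constant is kept, computed once rather than per call.
        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
        self.built = True

    def call(self, query, key):
        # query/key: [batch, seq, num_heads, head_dim]; equation is illustrative.
        scores = ops.einsum("bquh,bkuh->buqk", query, key)
        return scores / self._norm_factor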
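For reference, the equation="bqm,muh->bquh" projections touched above fold the head split into the dense layer itself: a [batch, seq, hidden_dim] input comes out as [batch, seq, num_query_heads, head_dim] in one EinsumDense. A quick shape check with keras.layers.EinsumDense; the head count and head dim below are arbitrary example values, not Llama's.

import numpy as np
import keras

num_query_heads, head_dim = 8, 16            # arbitrary example sizes
hidden_dim = num_query_heads * head_dim      # 128

query_dense = keras.layers.EinsumDense(
    equation="bqm,muh->bquh",                # m (hidden) -> u heads x h head dim
    output_shape=(None, num_query_heads, head_dim),
)

x = np.zeros((2, 5, hidden_dim), dtype="float32")  # [batch, seq, hidden_dim]
print(query_dense(x).shape)                        # (2, 5, 8, 16)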
