@@ -53,12 +53,13 @@ def build(self, inputs_shape):
         # u = num query heads
         # v = num key/value heads
         # h = head dim
-        self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self.num_query_heads
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
 
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self.num_query_heads, self._head_dim),
+            output_shape=(None, self.num_query_heads, head_dim),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
@@ -70,7 +71,7 @@ def build(self, inputs_shape):
             output_shape=(
                 None,
                 self.num_key_value_heads,
-                self._head_dim,
+                head_dim,
             ),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
@@ -83,7 +84,7 @@ def build(self, inputs_shape):
             output_shape=(
                 None,
                 self.num_key_value_heads,
-                self._head_dim,
+                head_dim,
             ),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
@@ -104,14 +105,12 @@ def build(self, inputs_shape):
 
         self._output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
-            output_shape=(None, self._hidden_dim),
+            output_shape=(None, hidden_dim),
             kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(
-            (None, None, self.num_query_heads, self._head_dim)
-        )
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))
 
         self.rotary_embedding_layer = RotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
@@ -189,9 +188,7 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
     def _compute_attention(self, query, key, value, attention_mask=None):
         attention_scores = ops.einsum(self._dot_product_equation, query, key)
 
-        norm_factor = ops.sqrt(ops.cast(self._head_dim, self.compute_dtype))
-
-        attention_scores = attention_scores / norm_factor
+        attention_scores = attention_scores / self._norm_factor
 
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
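
The cached value is safe to reuse because sqrt(head_dim) depends only on the layer's static configuration, so computing it once in build() and dividing by self._norm_factor in _compute_attention gives the same scores as deriving the factor on every call. A minimal standalone sketch of that equivalence (hypothetical helper and toy NumPy shapes, not code from this PR):

import math
import numpy as np

def scaled_scores(query, key, head_dim, norm_factor=None):
    # norm_factor may be precomputed once (as build() now does) or derived
    # on the fly; both paths divide by the same constant.
    if norm_factor is None:
        norm_factor = math.sqrt(head_dim)
    return (query @ key.T) / norm_factor

q = np.random.rand(4, 8)  # toy (seq_len, head_dim)
k = np.random.rand(4, 8)
cached = math.sqrt(8)     # precomputed once, like self._norm_factor
assert np.allclose(scaled_scores(q, k, 8), scaled_scores(q, k, 8, cached))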