17
17
from keras_nlp .utils .keras_utils import clone_initializer
18
18
19
19
20
- class CachedLlamaAttention (keras .layers .Layer ):
20
+ class LlamaAttention (keras .layers .Layer ):
21
21
"""A cached grounded query attention layer with sliding window."""
22
22
23
23
def __init__ (
@@ -31,18 +31,18 @@ def __init__(
31
31
** kwargs ,
32
32
):
33
33
super ().__init__ (** kwargs )
34
- self ._num_query_heads = num_query_heads
35
- self ._num_key_value_heads = num_key_value_heads
36
- self ._dropout = dropout
34
+ self .num_query_heads = num_query_heads
35
+ self .num_key_value_heads = num_key_value_heads
36
+ self .dropout = dropout
37
37
38
- self ._num_key_value_groups = num_query_heads // num_key_value_heads
39
- self ._rope_max_wavelength = rope_max_wavelength
38
+ self .num_key_value_groups = num_query_heads // num_key_value_heads
39
+ self .rope_max_wavelength = rope_max_wavelength
40
40
41
- self ._kernel_initializer = keras .initializers .get (
41
+ self .kernel_initializer = keras .initializers .get (
42
42
clone_initializer (kernel_initializer )
43
43
)
44
44
45
- self ._rope_scaling_factor = rope_scaling_factor
45
+ self .rope_scaling_factor = rope_scaling_factor
46
46
47
47
def build (self , inputs_shape ):
48
48
# Einsum variables:
@@ -54,12 +54,12 @@ def build(self, inputs_shape):
54
54
# v = num key/value heads
55
55
# h = head dim
56
56
self ._hidden_dim = inputs_shape [- 1 ]
57
- self ._head_dim = self ._hidden_dim // self ._num_query_heads
57
+ self ._head_dim = self ._hidden_dim // self .num_query_heads
58
58
59
59
self ._query_dense = keras .layers .EinsumDense (
60
60
equation = "bqm,muh->bquh" ,
61
- output_shape = (None , self ._num_query_heads , self ._head_dim ),
62
- kernel_initializer = self ._kernel_initializer ,
61
+ output_shape = (None , self .num_query_heads , self ._head_dim ),
62
+ kernel_initializer = self .kernel_initializer ,
63
63
dtype = self .dtype_policy ,
64
64
name = "query" ,
65
65
)
@@ -69,10 +69,10 @@ def build(self, inputs_shape):
69
69
equation = "bkm,mvh->bkvh" ,
70
70
output_shape = (
71
71
None ,
72
- self ._num_key_value_heads ,
72
+ self .num_key_value_heads ,
73
73
self ._head_dim ,
74
74
),
75
- kernel_initializer = self ._kernel_initializer ,
75
+ kernel_initializer = self .kernel_initializer ,
76
76
dtype = self .dtype_policy ,
77
77
name = "key" ,
78
78
)
@@ -82,10 +82,10 @@ def build(self, inputs_shape):
82
82
equation = "bkm,mvh->bkvh" ,
83
83
output_shape = (
84
84
None ,
85
- self ._num_key_value_heads ,
85
+ self .num_key_value_heads ,
86
86
self ._head_dim ,
87
87
),
88
- kernel_initializer = self ._kernel_initializer ,
88
+ kernel_initializer = self .kernel_initializer ,
89
89
dtype = self .dtype_policy ,
90
90
name = "value" ,
91
91
)
@@ -98,24 +98,24 @@ def build(self, inputs_shape):
98
98
)
99
99
100
100
self ._dropout_layer = keras .layers .Dropout (
101
- rate = self ._dropout ,
101
+ rate = self .dropout ,
102
102
dtype = self .dtype_policy ,
103
103
)
104
104
105
105
self ._output_dense = keras .layers .EinsumDense (
106
106
equation = "bquh,uhm->bqm" ,
107
107
output_shape = (None , self ._hidden_dim ),
108
- kernel_initializer = self ._kernel_initializer ,
108
+ kernel_initializer = self .kernel_initializer ,
109
109
dtype = self .dtype_policy ,
110
110
name = "attention_output" ,
111
111
)
112
112
self ._output_dense .build (
113
- (None , None , self ._num_query_heads , self ._head_dim )
113
+ (None , None , self .num_query_heads , self ._head_dim )
114
114
)
115
115
116
116
self .rotary_embedding_layer = RotaryEmbedding (
117
- max_wavelength = self ._rope_max_wavelength ,
118
- scaling_factor = self ._rope_scaling_factor ,
117
+ max_wavelength = self .rope_max_wavelength ,
118
+ scaling_factor = self .rope_scaling_factor ,
119
119
dtype = self .dtype_policy ,
120
120
)
121
121
@@ -162,8 +162,8 @@ def call(
162
162
163
163
# [batch_shape, seq_len, num_key_value_heads, head_dim]
164
164
# -> [batch_shape, seq_len, num_heads, head_dim]
165
- key = ops .repeat (key , repeats = self ._num_key_value_groups , axis = 2 )
166
- value = ops .repeat (value , repeats = self ._num_key_value_groups , axis = 2 )
165
+ key = ops .repeat (key , repeats = self .num_key_value_groups , axis = 2 )
166
+ value = ops .repeat (value , repeats = self .num_key_value_groups , axis = 2 )
167
167
168
168
attention_output = self ._compute_attention (
169
169
query , key , value , attention_mask
@@ -206,14 +206,14 @@ def get_config(self):
206
206
config = super ().get_config ()
207
207
config .update (
208
208
{
209
- "num_query_heads" : self ._num_query_heads ,
210
- "num_key_value_heads" : self ._num_key_value_heads ,
211
- "rope_max_wavelength" : self ._rope_max_wavelength ,
212
- "rope_scaling_factor" : self ._rope_scaling_factor ,
209
+ "num_query_heads" : self .num_query_heads ,
210
+ "num_key_value_heads" : self .num_key_value_heads ,
211
+ "rope_max_wavelength" : self .rope_max_wavelength ,
212
+ "rope_scaling_factor" : self .rope_scaling_factor ,
213
213
"kernel_initializer" : keras .initializers .serialize (
214
- self ._kernel_initializer
214
+ self .kernel_initializer
215
215
),
216
- "dropout" : self ._dropout ,
216
+ "dropout" : self .dropout ,
217
217
}
218
218
)
219
219
return config
0 commit comments