forked from ml-jku/hopfield-layers
activation.py
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Linear, Module, Parameter
from typing import Optional, Tuple
from .functional import hopfield_core_forward
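# Compatibility shim: the private helper _LinearWithBias exists only in some PyTorch releases;
# if it is unavailable, the output projection below falls back to a plain Linear layer.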
try:
from torch.nn.modules.linear import _LinearWithBias
except ImportError:
_LinearWithBias = None


class HopfieldCore(Module):
r"""Allows the model to jointly attend to information
from different representation subspaces.
    See references: "Hopfield Networks is All You Need" and
    "Attention Is All You Need" (on which this implementation is partly based).
    .. math::
        \text{HopfieldHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O
        \quad \text{where } \text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
kdim: total number of features in key. Default: None.
vdim: total number of features in value. Default: None.
Note: if kdim and vdim are None, they will be set to embed_dim such that
query, key, and value have the same number of features.
Examples::
>>> hopfield_attn = HopfieldCore(embed_dim, num_heads)
>>> attn_output, attn_output_weights, attn_matrix = hopfield_attn(query, key, value)
"""
__annotations__ = {
'bias_k': torch._jit_internal.Optional[torch.Tensor],
'bias_v': torch._jit_internal.Optional[torch.Tensor],
}

    def __init__(self,
embed_dim=None, # type: Optional[int]
num_heads=1, # type: int
dropout=0.0, # type: float
bias=True, # type: bool
add_bias_kv=False, # type: bool
add_zero_attn=False, # type: bool
kdim=None, # type: Optional[int]
vdim=None, # type: Optional[int]
head_dim=None, # type: Optional[int]
out_dim=None, # type: Optional[int]
disable_out_projection=False, # type: bool
key_as_static=False, # type: bool
query_as_static=False, # type: bool
value_as_static=False, # type: bool
value_as_connected=False, # type: bool
normalize_pattern=False, # type: bool
normalize_pattern_affine=False # type: bool
):
super(HopfieldCore, self).__init__()
assert (type(key_as_static) == bool) and (type(query_as_static) == bool) and (type(value_as_static) == bool)
self.key_as_static, self.query_as_static, self.value_as_static = key_as_static, query_as_static, value_as_static
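        # Count how many of query/key/value still need a trainable projection (non-static inputs).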
num_non_static = 3 - (self.key_as_static + self.query_as_static + self.value_as_static)
assert 0 <= num_non_static < 4
self.value_as_connected = value_as_connected
self.normalize_pattern, self.normalize_pattern_affine = normalize_pattern, normalize_pattern_affine
self.disable_out_projection = disable_out_projection
        # In the case of static-only execution, check the corresponding projections and normalizations.
self.static_execution = self._check_execution_mode()
if self.static_execution:
embed_dim, kdim, vdim = None, None, None
if embed_dim is None:
assert self.static_execution, r'static-only execution requires all projections to be deactivated.'
# Check and set all other properties, conditioned on <static_execution>.
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
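        # As in multi-head attention: if query, key and value share one embedding dimension (and the
        # value projection is not chained to the key projection), a single packed in_proj_weight can
        # be used for all non-static projections.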
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim and not self.value_as_connected
assert (not self.value_as_connected) or (self.kdim == self.vdim), r'key and value need to be of same dimension.'
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = None
if not self.static_execution:
if head_dim is None:
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads."
else:
assert head_dim > 0, "dimension of the association space has to be positive."
self.head_dim = head_dim
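        # Total width of the concatenated association space across all heads.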
self.virtual_hopfield_dim = None if (self.head_dim is None) else (self.num_heads * self.head_dim)
self.out_dim = embed_dim if out_dim is None else out_dim
assert disable_out_projection or (self.out_dim > 0), "output projection dimension has to be positive."
if normalize_pattern_affine:
assert normalize_pattern, "affine pattern normalization without pattern normalization has no effect."
            self.p_norm_weight = Parameter(torch.Tensor(self.head_dim))
            self.p_norm_bias = Parameter(torch.Tensor(self.head_dim))
else:
self.register_parameter('p_norm_weight', None)
self.register_parameter('p_norm_bias', None)
if self._qkv_same_embed_dim is False:
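            # Separate projection weights per input tensor (query/key/value differ in dimension).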
if query_as_static:
self.register_parameter('q_proj_weight', None)
else:
self.q_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, embed_dim))
if key_as_static:
self.register_parameter('k_proj_weight', None)
else:
self.k_proj_weight = Parameter(torch.Tensor(self.virtual_hopfield_dim, self.kdim))
if value_as_static:
self.register_parameter('v_proj_weight', None)
else:
self.v_proj_weight = Parameter(torch.Tensor(
self.virtual_hopfield_dim,
self.virtual_hopfield_dim if (value_as_connected and not key_as_static) else self.vdim))
self.register_parameter('in_proj_weight', None)
else:
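            # Shared embedding dimension: pack all non-static projections into one weight matrix.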
if num_non_static > 0:
self.in_proj_weight = Parameter(torch.empty(num_non_static * self.virtual_hopfield_dim, embed_dim))
else:
self.register_parameter('in_proj_weight', None)
self.register_parameter('q_proj_weight', None)
self.register_parameter('k_proj_weight', None)
self.register_parameter('v_proj_weight', None)
if bias and (num_non_static > 0):
self.in_proj_bias = Parameter(torch.empty(num_non_static * self.virtual_hopfield_dim))
else:
self.register_parameter('in_proj_bias', None)
if disable_out_projection:
self.register_parameter('out_proj', None)
else:
if bias and _LinearWithBias is not None:
self.out_proj = _LinearWithBias(self.virtual_hopfield_dim, self.out_dim)
else:
self.out_proj = Linear(self.virtual_hopfield_dim, self.out_dim, bias=bias)
self.bias_k, self.bias_v = None, None
if add_bias_kv:
if not key_as_static:
self.bias_k = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
if not value_as_static:
self.bias_v = Parameter(torch.empty(1, 1, self.virtual_hopfield_dim))
assert not (self.bias_k is None and self.bias_v is None), r'cannot set key/value bias if both are static.'
self.add_zero_attn = add_zero_attn
self._reset_parameters()

    def _check_execution_mode(self) -> bool:
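        # Static-only mode: query, key and value are used as-is, so no projections,
        # no pattern normalization and no output projection are required.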
return all((
self.key_as_static, self.query_as_static, self.value_as_static, not self.value_as_connected,
not self.normalize_pattern, not self.normalize_pattern_affine, self.disable_out_projection)
)

    def _reset_parameters(self):
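        # Affine pattern normalization starts as the identity; projection weights and key/value
        # biases are drawn from N(0, 0.02^2); all other biases start at zero.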
if self.p_norm_weight is not None:
nn.init.ones_(self.p_norm_weight)
nn.init.zeros_(self.p_norm_bias)
if self._qkv_same_embed_dim and (self.in_proj_weight is not None):
nn.init.normal_(self.in_proj_weight, mean=0.0, std=0.02)
else:
if self.q_proj_weight is not None:
nn.init.normal_(self.q_proj_weight, mean=0.0, std=0.02)
if self.k_proj_weight is not None:
nn.init.normal_(self.k_proj_weight, mean=0.0, std=0.02)
if self.v_proj_weight is not None:
nn.init.normal_(self.v_proj_weight, mean=0.0, std=0.02)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.0)
        if (not self.disable_out_projection) and (self.out_proj.bias is not None):
            nn.init.constant_(self.out_proj.bias, 0.0)
if self.bias_k is not None:
nn.init.normal_(self.bias_k, mean=0.0, std=0.02)
if self.bias_v is not None:
nn.init.normal_(self.bias_v, mean=0.0, std=0.02)

    def __setstate__(self, state):
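        # Nothing version-specific to restore; delegate unpickling to the default Module behavior.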
super(HopfieldCore, self).__setstate__(state)

    def forward(self,
query, # type: Tensor
key, # type: Tensor
value, # type: Tensor
key_padding_mask=None, # type: Optional[Tensor]
need_weights=True, # type: bool
attn_mask=None, # type: Optional[Tensor]
scaling=None, # type: Optional[Tensor]
update_steps_max=0, # type: Optional[int]
update_steps_eps=1e-4, # type: float
return_raw_associations=False, # type: bool
return_pattern_projections=False # type: bool
):
# type: (...) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
See "Hopfield Networks is All You Need" for more details in the setting of Hopfield networks.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored.
need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast
            across all batches, while a 3D mask allows specifying a different mask for each entry of the batch.
scaling: scaling of association heads, often represented as beta (one entry per head).
        update_steps_max: maximum count of association update steps (None means an unlimited number of steps).
update_steps_eps: minimum difference threshold between two consecutive association update steps.
return_raw_associations: return raw association (softmax) values, unmodified.
return_pattern_projections: return pattern projection values, unmodified.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero
          positions will be unchanged. If a BoolTensor is provided, positions with the value
          ``True`` will be ignored while positions with the value ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is only allowed to attend
          to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed
          to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions
          with ``True`` are not allowed to attend while positions with ``False`` will be unchanged. If a
          FloatTensor is provided, it will be added to the attention weight.
        - scaling: :math:`(num_heads,)`, where num_heads is the number of heads.
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
        - attn_raw: :math:`(N, num_heads, L, S)`, where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
if self.query_as_static and self.key_as_static:
assert query.shape[2] == key.shape[2], \
f'query shape[2] of {query.shape[2]} and key shape[2] of {key.shape[2]} need to be equal'
head_dim, embed_dim_to_check = query.shape[2], query.shape[2]
else:
assert self.query_as_static or (query.shape[2] == self.embed_dim), \
f'query shape[2] of {query.shape[2]} invalid, needs to be {self.embed_dim}.'
assert (not self.query_as_static) or (self.query_as_static and query.shape[2] == self.head_dim), \
f'query shape[2] of {query.shape[2]} invalid, needs to be {self.head_dim}'
assert self.key_as_static or (key.shape[2] == self.kdim), \
f'key shape[2] of {key.shape[2]} invalid, needs to be {self.kdim}.'
assert (not self.key_as_static) or (self.key_as_static and key.shape[2] == self.head_dim), \
f'key shape[2] of {key.shape[2]} invalid, needs to be {self.head_dim}'
head_dim, embed_dim_to_check = self.head_dim, self.head_dim if self.query_as_static else self.embed_dim
assert self.value_as_static or (value.shape[2] == self.vdim), \
f'value shape[2] of {value.shape[2]} invalid, needs to be {self.vdim}.'
assert any((
not self.value_as_static, self.value_as_static and value.shape[2] == self.head_dim,
self.disable_out_projection)
), f'value shape[2] of {value.shape[2]} invalid, needs to be {self.head_dim}'
out_weights, out_bias = None, None
if not self.disable_out_projection:
out_weights, out_bias = self.out_proj.weight, self.out_proj.bias
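        # Dispatch to the functional core: pass separate per-tensor projection weights if the
        # query/key/value dimensions differ (or the value projection is chained to the key),
        # otherwise pass the packed in_proj_weight.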
if not self._qkv_same_embed_dim:
return hopfield_core_forward(
query, key, value, embed_dim_to_check, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, out_weights, out_bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight,
key_as_static=self.key_as_static, query_as_static=self.query_as_static,
value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
normalize_pattern=self.normalize_pattern,
p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
head_dim=head_dim, scaling=scaling,
update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
else:
return hopfield_core_forward(
query, key, value, embed_dim_to_check, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, out_weights, out_bias,
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask,
key_as_static=self.key_as_static, query_as_static=self.query_as_static,
value_as_static=self.value_as_static, value_as_connected=self.value_as_connected,
normalize_pattern=self.normalize_pattern,
p_norm_weight=self.p_norm_weight, p_norm_bias=self.p_norm_bias,
head_dim=head_dim, scaling=scaling,
update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
return_raw_associations=return_raw_associations, return_projected_patterns=return_pattern_projections)
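

if __name__ == '__main__':
    # Minimal usage sketch (illustration only, not part of the original module). It mirrors the
    # docstring example above and assumes the module is executed inside its package, e.g. via
    # ``python -m <package>.activation``, so that the relative import of ``.functional`` resolves.
    embed_dim, num_heads = 16, 4
    hopfield_attn = HopfieldCore(embed_dim=embed_dim, num_heads=num_heads)

    target_len, source_len, batch_size = 3, 5, 2
    query = torch.randn(target_len, batch_size, embed_dim)  # (L, N, E)
    key = torch.randn(source_len, batch_size, embed_dim)    # (S, N, E)
    value = torch.randn(source_len, batch_size, embed_dim)  # (S, N, E)

    # The first element of the returned tuple is the association output of shape (L, N, out_dim).
    result = hopfield_attn(query, key, value)
    print(result[0].shape)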