-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnknn.py
72 lines (57 loc) · 1.95 KB
/
nknn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations
from typing import Annotated
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum
# Public API: only the NKNN module is exported by ``from ... import *``.
__all__ = [
'NKNN',
]
def _compute_omega(
s: Annotated[torch.Tensor, '*B', 'N'], k: int, t: float
) -> Annotated[torch.Tensor, '*B', 'N', 'K']:
alpha = F.softmax(s, dim=-1)
omega = torch.empty(*s.shape, k)
omega[..., 0] = F.softmax(alpha / t, dim=-1)
for i in range(1, k):
alpha = alpha + torch.log(1 - omega[..., i - 1])
omega[..., i] = F.softmax(alpha / t, dim=-1)
return omega
class NKNN(nn.Module):
    """Soft, differentiable selection of ``k`` weighted value combinations.

    For each query the module scores all ``N`` keys, turns the scores into
    ``k`` weight vectors with ``_compute_omega`` and returns the ``k``
    weighted sums of the values. When no separate feature size is given at
    construction time, the keys double as the values.
    """

    _k: int
    _temp: float
    _dim: int
    _feature: int
    # True when the module was built without a feature size, i.e. ``forward``
    # may be called without ``values`` and then reuses ``keys``.
    _no_values: bool = False

    def __init__(
        self, k: int, dim: int, temp: float, feature: int | None = None
    ) -> None:
        """Store the layer configuration.

        Args:
            k: Number of weight vectors (soft neighbours) to produce.
            dim: Size of the query/key embedding axis ``D``.
            temp: Temperature passed to ``_compute_omega``.
            feature: Size of the value feature axis ``F``. If ``None``, it
                defaults to ``dim`` and ``values`` may be omitted in
                ``forward``.
        """
        super().__init__()
        self._k = k
        self._temp = temp
        self._dim = dim
        # Decide on ``is None`` rather than truthiness so that the stored
        # feature size and the ``_no_values`` flag can never disagree.
        self._feature = dim if feature is None else feature
        self._no_values = feature is None

    def _similarity(
        self,
        query: Annotated[torch.Tensor, '*B', 'D'],
        key: Annotated[torch.Tensor, '*B', 'D', 'N'],
    ) -> Annotated[torch.Tensor, '*B', 'N']:
        """Return the negated dot product of ``query`` with every key, scaled by ``sqrt(dim)``.

        NOTE(review): because of the leading minus sign, larger dot products
        produce lower scores. Confirm this sign convention is intended.
        """
        dots = einsum(query, key, '... D, ... D N -> ... N')
        return -dots / (self._dim**0.5)

    def forward(
        self,
        query: Annotated[torch.Tensor, '*B', 'D'],
        keys: Annotated[torch.Tensor, '*B', 'D', 'N'],
        values: Annotated[torch.Tensor, '*B', 'F', 'N'] | None = None,
    ) -> Annotated[torch.Tensor, '*B', 'K', 'F']:
        """Compute the ``k`` softly selected value vectors for each query.

        Args:
            query: Query embeddings with last axis of size ``dim``.
            keys: Key embeddings with axes ``(..., dim, N)``.
            values: Value features with axes ``(..., feature, N)``. May be
                omitted only if the module was constructed without
                ``feature``; ``keys`` are then used as values.

        Returns:
            Tensor with trailing axes ``(k, feature)`` holding the weighted
            sums of the values.

        Raises:
            ValueError: If ``values`` is missing although a feature size was
                configured, or if any of the shapes are inconsistent.
        """
        # Explicit exceptions instead of ``assert`` so the checks survive
        # ``python -O``.
        if values is None:
            if not self._no_values:
                raise ValueError(
                    'values must be provided when the module was constructed '
                    'with an explicit feature size'
                )
            values = keys

        if not (query.shape[-1] == keys.shape[-2] == self._dim):
            raise ValueError(
                f'expected query.shape[-1] and keys.shape[-2] to equal '
                f'dim={self._dim}, got {query.shape[-1]} and {keys.shape[-2]}'
            )
        if values.shape[-2] != self._feature:
            raise ValueError(
                f'expected values.shape[-2] to equal feature={self._feature}, '
                f'got {values.shape[-2]}'
            )
        if keys.shape[-1] != values.shape[-1]:
            raise ValueError(
                f'keys and values must share the last axis size, '
                f'got {keys.shape[-1]} and {values.shape[-1]}'
            )

        sims = self._similarity(query, keys)
        omega = _compute_omega(s=sims, k=self._k, t=self._temp)
        # Weighted sum over the N candidates for each of the K weight vectors.
        k_nearest = einsum(omega, values, '... N K, ... F N -> ... K F')
        return k_nearest