-
Notifications
You must be signed in to change notification settings - Fork 5
/
codebook.py
executable file
·98 lines (80 loc) · 3.59 KB
/
codebook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from re import X
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
class VectorQuantizerEMA(nn.Module):
    """
    Vector quantizer whose codebook is learned with an exponential moving average (EMA).

    Adapted from https://github.com/devnkong/GOAT

    Inputs are passed through a non-affine BatchNorm before quantization, so the
    codebook (``_embedding``) lives in normalized space. A copy mapped back through
    the BatchNorm running statistics (``_embedding_output``) is exposed through
    :meth:`get_k` and :meth:`get_v`.

    Args:
        num_embeddings (int): Number of codewords K in the codebook.
        embedding_dim (int): Base feature size D. Inputs and codewords have width
            ``2 * embedding_dim``.
        decay (float, optional): EMA decay rate. Defaults to 0.99.
        epsilon (float, optional): Additive constant used for Laplace smoothing of
            the cluster sizes. Defaults to 1e-5.

    Attributes:
        _embedding_dim (int): Base feature size D.
        _num_embeddings (int): Number of codewords K.
        _decay (float): EMA decay rate.
        _epsilon (float): Laplace-smoothing constant.
        _embedding (torch.Tensor): Buffer of shape (K, 2*D); codebook in normalized space.
        _embedding_output (torch.Tensor): Buffer of shape (K, 2*D); codebook mapped back
            with the BatchNorm running mean/std.
        _ema_cluster_size (torch.Tensor): Buffer of shape (K,); EMA of codeword usage counts.
        _ema_w (torch.Tensor): Buffer of shape (K, 2*D); EMA of the summed normalized
            inputs assigned to each codeword.
        bn (nn.BatchNorm1d): Non-affine batch norm over the ``2 * embedding_dim`` features.
    """

    def __init__(self, num_embeddings, embedding_dim, decay=0.99, epsilon=1e-5):
        super().__init__()
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        width = 2 * embedding_dim

        # Registration order (and hence random-init order) matches the original so that
        # seeded initialisation and state_dict keys are unchanged.
        self.register_buffer("_embedding", torch.randn(num_embeddings, width))
        self.register_buffer("_embedding_output", torch.randn(num_embeddings, width))
        self.register_buffer("_ema_cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("_ema_w", torch.randn(num_embeddings, width))

        self._decay = decay
        self._epsilon = epsilon
        self.bn = nn.BatchNorm1d(width, affine=False)

    def get_k(self):
        """Return the de-normalized codebook, shape (K, 2 * embedding_dim)."""
        return self._embedding_output

    def get_v(self):
        """Return the first ``embedding_dim`` columns of the de-normalized codebook, shape (K, embedding_dim)."""
        return self._embedding_output[:, : self._embedding_dim]

    @torch.no_grad()
    def update(self, x):
        """
        Assign each input row to its nearest codeword; in training mode also update
        the codebook with EMA.

        Runs under ``torch.no_grad()``: the result is a non-differentiable index tensor
        and the buffers are overwritten directly, so recording an autograd graph
        through ``self.bn`` would only waste memory.

        Args:
            x (torch.Tensor): Inputs of shape (N, 2 * embedding_dim).

        Returns:
            torch.Tensor: int64 tensor of shape (N, 1) holding, for each row, the index
            of the codeword with the smallest squared Euclidean distance in normalized
            space.
        """
        # In training mode this call also updates the BatchNorm running statistics.
        inputs_normalized = self.bn(x)
        codebook = self._embedding

        # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b  -> shape (N, K).
        distances = (
            torch.sum(inputs_normalized**2, dim=1, keepdim=True)
            + torch.sum(codebook**2, dim=1)
            - 2 * torch.matmul(inputs_normalized, codebook.t())
        )

        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # One-hot assignment matrix (N, K). Its dtype follows the inputs so that the
        # matmul below also works after .double() / .half() (float32 was hard-coded).
        encodings = torch.zeros(
            encoding_indices.shape[0],
            self._num_embeddings,
            device=x.device,
            dtype=inputs_normalized.dtype,
        )
        encodings.scatter_(1, encoding_indices, 1)

        if self.training:
            decay = self._decay
            eps = self._epsilon

            # Buffers are replaced through `.data` as in the original: in-place ops
            # (e.g. copy_) would bump their autograd version counters and could break
            # a pending backward through tensors returned by get_k()/get_v().
            self._ema_cluster_size.data = self._ema_cluster_size * decay + (
                1 - decay
            ) * torch.sum(encodings, 0)

            # Laplace smoothing: keeps the total n unchanged and, whenever n > 0, makes
            # every count strictly positive so the division below is safe.
            n = torch.sum(self._ema_cluster_size)
            self._ema_cluster_size.data = (
                (self._ema_cluster_size + eps) / (n + self._num_embeddings * eps) * n
            )

            dw = torch.matmul(encodings.t(), inputs_normalized)
            self._ema_w.data = self._ema_w * decay + (1 - decay) * dw
            self._embedding.data = self._ema_w / self._ema_cluster_size.unsqueeze(1)

            # Map the normalized codebook back to input space with the BatchNorm
            # running statistics (same eps BatchNorm itself uses).
            running_std = torch.sqrt(self.bn.running_var + self.bn.eps).unsqueeze(dim=0)
            running_mean = self.bn.running_mean.unsqueeze(dim=0)
            self._embedding_output.data = self._embedding * running_std + running_mean

        return encoding_indices