forked from gned0/avalanche
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
234 lines (188 loc) · 8.19 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCELoss
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.regularization import cross_entropy_with_oh_targets
from avalanche._annotations import deprecated
class ICaRLLossPlugin(SupervisedPlugin):
    """
    Distillation-style loss used by ICaRL.

    The training target for a sample starts as its one-hot label vector.
    Once a previous model is available, the columns belonging to classes
    recorded from earlier experiences are overwritten with the sigmoid
    outputs of that previous model on the same sample. The loss is the
    binary cross entropy between the sigmoid of the current logits and
    this target.

    Doesn't work if classes observed in previous experiences might be
    observed again in future training experiences.
    """

    def __init__(self):
        super().__init__()
        # Loss applied to (sigmoid(logits), target matrix).
        self.criterion = BCELoss()

        # Class ids collected at the end of every finished experience.
        self.old_classes = []
        # Snapshot of the model taken after the last finished experience.
        self.old_model = None
        # Output of `old_model` on the current mini-batch, consumed by __call__.
        self.old_logits = None

    def before_forward(self, strategy, **kwargs):
        """Cache the previous model's output on the current mini-batch."""
        if self.old_model is None:
            return
        with torch.no_grad():
            self.old_logits = self.old_model(strategy.mb_x)

    def __call__(self, logits, targets):
        """
        Compute the loss for one mini-batch.

        :param logits: current model output, indexed as [sample, class].
        :param targets: integer class label of every sample.
        :return: the value returned by the `BCELoss` criterion.
        """
        predictions = torch.sigmoid(logits)

        n_samples = targets.shape[0]
        n_outputs = logits.shape[1]
        target_matrix = torch.zeros(
            n_samples, n_outputs, dtype=torch.float, device=logits.device
        )
        target_matrix[range(len(targets)), targets.long()] = 1

        if self.old_logits is not None:
            # Replace the old-class columns with the previous model's
            # probabilities, then drop the cache so it is used only once.
            distilled = torch.sigmoid(self.old_logits)
            target_matrix[:, self.old_classes] = distilled[:, self.old_classes]
            self.old_logits = None

        return self.criterion(predictions, target_matrix)

    def after_training_exp(self, strategy, **kwargs):
        """Refresh the model snapshot and record the classes just trained on."""
        if self.old_model is None:
            # First experience: create the snapshot container on the right device.
            snapshot = copy.deepcopy(strategy.model)
            self.old_model = snapshot.to(strategy.device)

        current_weights = strategy.model.state_dict()
        self.old_model.load_state_dict(current_weights)

        finished_classes = np.unique(strategy.experience.dataset.targets).tolist()
        self.old_classes.extend(finished_classes)
class SCRLoss(torch.nn.Module):
    """
    Supervised Contrastive Replay Loss as defined in Eq. 5 of
    https://arxiv.org/pdf/2103.13885.pdf.

    Author: Yonglong Tian (yonglong@mit.edu)
    Date: May 07, 2020
    Original GitHub repository: https://github.com/HobbitLong/SupContrast/
    LICENSE: BSD 2-Clause License
    """

    def __init__(self, temperature=0.07, contrast_mode="all", base_temperature=0.07):
        """
        :param temperature: divisor applied to the feature dot products.
        :param contrast_mode: "all" uses every view as an anchor, "one" uses
            only the first view of each sample.
        :param base_temperature: the final loss is scaled by
            `temperature / base_temperature`.
        """
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        features: [bsz, n_views, f_dim]
        `n_views` is the number of crops from each image, better
        be L2 normalized in f_dim dimension

        Anchors that have no positive pair in the batch (for instance a
        single view whose label is unique in the batch) contribute a loss
        of zero instead of turning the whole result into NaN.

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        Raises:
            ValueError: if `features` has fewer than 3 dimensions, if both
                `labels` and `mask` are given, if the number of labels does
                not match the batch size, or if `contrast_mode` is unknown.
        """
        device = features.device

        if len(features.shape) < 3:
            raise ValueError(
                "`features` needs to be [bsz, n_views, ...],"
                "at least 3 dimensions are required"
            )
        if len(features.shape) > 3:
            # reshape (rather than view) also accepts non-contiguous inputs.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError("Cannot define both `labels` and `mask`")
        elif labels is None and mask is None:
            # Unsupervised case: a sample is only positive with its own views.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError("Num of labels does not match num of features")
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: [bsz * n_views, f_dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == "one":
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == "all":
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError("Unknown mode: {}".format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T), self.temperature
        )
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0,
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # An anchor without any positive has a zero row in `mask`; dividing by
        # its zero count would give 0/0 = NaN, so the count is replaced by 1
        # and the (all-zero) numerator yields a zero contribution.
        positives_per_anchor = mask.sum(1)
        positives_per_anchor = torch.where(
            positives_per_anchor < 1e-6,
            torch.ones_like(positives_per_anchor),
            positives_per_anchor,
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / positives_per_anchor

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
class MaskedCrossEntropy(SupervisedPlugin):
    """
    Masked Cross Entropy

    Cross entropy evaluated only on a subset of the logit columns.
    This criterion can be used for instance in Class Incremental
    Learning Problems when no examplars are used
    (i.e LwF in Class Incremental Learning would need to use mask="new").
    """

    def __init__(self, classes=None, mask="seen", reduction="mean"):
        """
        param: classes: Initial value for current classes
        param: mask: "seen" cross entropy on the union of current and old classes
                     "new" cross entropy only on the current (new) classes
                     "old" cross entropy only on the old classes
                     "all" cross entropy on every logit column
        param: reduction: forwarded unchanged to `cross_entropy_with_oh_targets`
        """
        super().__init__()
        assert mask in ["seen", "new", "old", "all"]

        self.current_classes = set(classes) if classes is not None else set()
        self.old_classes = set()
        self.reduction = reduction
        self.mask = mask

    def __call__(self, logits, targets):
        """Return the cross entropy restricted to the currently active columns."""
        n_outputs = logits.shape[1]
        active_columns = self.current_mask(n_outputs)

        # One-hot over every output first, then keep the same columns in
        # both the targets and the logits.
        oh_targets = F.one_hot(targets, num_classes=n_outputs)
        kept_targets = oh_targets[:, active_columns]
        kept_logits = logits[:, active_columns]

        return cross_entropy_with_oh_targets(
            kept_logits,
            kept_targets.float(),
            reduction=self.reduction,
        )

    def current_mask(self, logit_shape):
        """
        Return the list of column indices selected by `self.mask`.

        param: logit_shape: number of logit columns, only used by mask="all"
        """
        mode = self.mask
        if mode == "all":
            return list(range(int(logit_shape)))
        if mode == "seen":
            return list(self.current_classes.union(self.old_classes))
        if mode == "new":
            return list(self.current_classes)
        if mode == "old":
            return list(self.old_classes)
        return None

    def _adaptation(self, new_classes):
        """Move the current classes into the old set and install `new_classes`."""
        merged_old = self.old_classes.union(self.current_classes)
        self.old_classes = merged_old
        self.current_classes = set(new_classes)

    def pre_adapt(self, agent, exp):
        """Update the class sets from the classes of experience `exp`."""
        self._adaptation(exp.classes_in_this_experience)

    @deprecated(0.7, "Please switch to the `pre_adapt`or `_adaptation` methods.")
    def adaptation(self, new_classes):
        """Deprecated alias of `_adaptation`."""
        self._adaptation(new_classes)

    @deprecated(0.7, "Please switch to the `pre_adapt` method.")
    def before_training_exp(self, strategy, **kwargs):
        """Deprecated hook: update the class sets from the strategy's experience."""
        self._adaptation(strategy.experience.classes_in_this_experience)
# Names exported by `from <this module> import *`.
__all__ = ["ICaRLLossPlugin", "SCRLoss", "MaskedCrossEntropy"]