-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
rcil.py
273 lines (271 loc) · 17.1 KB
/
rcil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
'''
Function:
Implementation of "Representation Compensation Networks for Continual Semantic Segmentation"
Author:
Zhenchao Jin
'''
import copy
import math
import torch
import functools
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from .mib import MIBRunner
from .base import BaseRunner
'''RCILRunner'''
class RCILRunner(BaseRunner):
    '''Runner for RCIL (Representation Compensation Networks for Continual Semantic Segmentation).

    Training adds three losses per batch: the segmentation loss, a pooled feature-distillation
    ("pod") loss between the history and current segmentor (channel pooling + multi-scale spatial
    pooling, see featuresdistillation) and MIB's output distillation (MIBRunner.featuresdistillation).
    convertsegmentors folds every two-branch conv/BN block of the segmentor into its branch-1 conv.
    '''
    def __init__(self, mode, cmd_args, runner_cfg):
        super(RCILRunner, self).__init__(
            mode=mode, cmd_args=cmd_args, runner_cfg=runner_cfg
        )
    '''_isdualbranchblock'''
    @staticmethod
    def _isdualbranchblock(module):
        '''Return True if module exposes conv2/bn2 plus a second branch conv2_branch2/bn2_branch2.'''
        return all(hasattr(module, attr) for attr in ('conv2', 'bn2', 'conv2_branch2', 'bn2_branch2'))
    '''_addzerobiases'''
    @staticmethod
    def _addzerobiases(segmentor):
        '''Give every branch-1 conv of segmentor's two-branch / parallel blocks a fresh all-zero bias Parameter.

        Any bias those convs already had is replaced.
        '''
        for module in segmentor.modules():
            if RCILRunner._isdualbranchblock(module):
                convs = [module.conv2]
            elif hasattr(module, 'parallel_convs_branch1'):
                convs = list(module.parallel_convs_branch1)
            else:
                continue
            for conv in convs:
                conv.bias = nn.Parameter(torch.zeros(conv.weight.shape[0]).to(conv.weight.device))
    '''_fuseconvbn'''
    @staticmethod
    def _fuseconvbn(conv2d, bn2d, conv_bias=None, channels=slice(None)):
        '''Fold (a channel slice of) a BatchNorm's running statistics and affine params into a conv.

        Arguments:
            conv2d: conv whose 4-D weight is scaled (it is not modified in place).
            bn2d: BatchNorm providing running_mean/running_var/weight/bias/eps.
            conv_bias: optional conv bias folded into the returned bias.
            channels: slice selecting the entries of bn2d that correspond to conv2d's output channels.
        Returns:
            (kernel, bias) of the fused conv. gamma is taken as |weight| + eps and both gamma and beta
            are halved, presumably so that summing two fused branches averages them.
        '''
        eps = bn2d.eps
        running_mean, running_var = bn2d.running_mean[channels], bn2d.running_var[channels]
        gamma = (bn2d.weight.abs()[channels] + eps) / 2.
        beta = bn2d.bias[channels] / 2.
        std = (running_var + eps).sqrt()
        scale = (gamma / std).reshape(-1, 1, 1, 1)
        bias = beta - running_mean * gamma / std
        if conv_bias is not None:
            bias = bias + scale.view(-1) * conv_bias.to(conv2d.weight.device).view(-1)
        return conv2d.weight * scale, bias
    '''_resetbntoidentity'''
    @staticmethod
    def _resetbntoidentity(bn2d):
        '''Overwrite bn2d (weight 1, bias 0, running mean 0, running var 1, eps 0) so it is the identity in eval mode.'''
        bn2d.weight.data.fill_(1.)
        bn2d.bias.data.zero_()
        bn2d.running_mean.data.zero_()
        bn2d.running_var.data.fill_(1.)
        bn2d.eps = 0
    '''_convertdualbranchblock'''
    @staticmethod
    def _convertdualbranchblock(module):
        '''Fuse conv2/bn2 and conv2_branch2/bn2_branch2 into conv2, reset bn2 to identity, freeze conv2 and bn2.'''
        k1, b1 = RCILRunner._fuseconvbn(module.conv2, module.bn2, module.conv2.bias.data)
        k2, b2 = RCILRunner._fuseconvbn(module.conv2_branch2, module.bn2_branch2, None)
        module.conv2.weight.data.copy_(k1 + k2)
        module.conv2.bias = nn.Parameter(b1 + b2)
        RCILRunner._resetbntoidentity(module.bn2)
        for layer in (module.bn2, module.conv2):
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False
    '''_convertparallelblock'''
    @staticmethod
    def _convertparallelblock(module):
        '''Fuse each parallel_convs_branch1[idx]/parallel_convs_branch2[idx] pair into parallel_convs_branch1[idx].

        Each conv uses its own channel slice of the single BN held in parallel_bn_branch{1,2}[0]. After
        all slices are fused, parallel_bn_branch1[0] is reset to identity and parallel_bn_branch1 is frozen.
        '''
        bn_branch1, bn_branch2 = module.parallel_bn_branch1[0], module.parallel_bn_branch2[0]
        offset = 0
        for idx, conv_branch1 in enumerate(module.parallel_convs_branch1):
            conv_branch2 = module.parallel_convs_branch2[idx]
            # the BN is indexed as if the parallel convs' outputs were concatenated in order; with 256
            # output channels per conv this is exactly [idx * 256, (idx + 1) * 256)
            num_channels = conv_branch1.weight.shape[0]
            channels = slice(offset, offset + num_channels)
            offset += num_channels
            k1, b1 = RCILRunner._fuseconvbn(conv_branch1, bn_branch1, conv_branch1.bias.data, channels)
            k2, b2 = RCILRunner._fuseconvbn(conv_branch2, bn_branch2, None, channels)
            conv_branch1.weight.data.copy_(k1 + k2)
            conv_branch1.bias = nn.Parameter(b1 + b2)
            conv_branch1.eval()
            for param in conv_branch1.parameters():
                param.requires_grad = False
        # reset only once every slice has been fused: each fusion above reads this BN's statistics
        RCILRunner._resetbntoidentity(bn_branch1)
        module.parallel_bn_branch1.eval()
        for param in module.parallel_bn_branch1.parameters():
            param.requires_grad = False
    '''convertsegmentors'''
    def convertsegmentors(self):
        '''Re-parameterize self.segmentor's two-branch blocks into single frozen branch-1 convs.

        Every branch-1 conv first gets a zero bias; then, without autograd, both branches of each
        dual-branch / parallel block are fused into the branch-1 conv, whose BN is reset to identity
        and frozen. For task_id > 1, self.history_segmentor only receives the zero biases.
        '''
        self._addzerobiases(self.segmentor)
        with torch.no_grad():
            for module in self.segmentor.modules():
                if self._isdualbranchblock(module):
                    self._convertdualbranchblock(module)
                elif hasattr(module, 'parallel_convs_branch1'):
                    self._convertparallelblock(module)
        # NOTE(review): presumably so the history segmentor's parameter layout matches a checkpoint
        # saved after conversion in the previous task - confirm
        if self.runner_cfg['task_id'] > 1:
            self._addzerobiases(self.history_segmentor)
    '''train'''
    def train(self, cur_epoch):
        '''Run one training epoch.

        Each batch is fed to self.segmentor. When self.history_segmentor exists, its outputs (computed
        under no_grad) drive the pod distillation and MIB distillation losses, which are added to the
        segmentation loss. For task_id > 0 the branch-1 convs/BNs of the two-branch and parallel blocks
        are frozen and their BNs kept in eval mode.
        '''
        # initialize
        losses_cfgs = copy.deepcopy(self.losses_cfgs)
        init_losses_log_dict = {
            'algorithm': self.runner_cfg['algorithm'], 'task_id': self.runner_cfg['task_id'],
            'epoch': self.scheduler.cur_epoch, 'iteration': self.scheduler.cur_iter, 'lr': self.scheduler.cur_lr
        }
        losses_log_dict = copy.deepcopy(init_losses_log_dict)
        self.segmentor.train()
        self.train_loader.sampler.set_epoch(cur_epoch)
        if self.runner_cfg['task_id'] > 0:
            # segmentor.train() above switched every BN to train mode; re-freeze the branch-1 layers
            for module in self.segmentor.modules():
                if self._isdualbranchblock(module):
                    conv, bn = module.conv2, module.bn2
                elif hasattr(module, 'parallel_convs_branch1'):
                    conv, bn = module.parallel_convs_branch1, module.parallel_bn_branch1
                else:
                    continue
                for param in list(conv.parameters()) + list(bn.parameters()):
                    param.requires_grad = False
                bn.eval()
        # loop-invariant: number of classes learned in previous tasks
        num_history_known_classes = None
        if self.history_segmentor is not None:
            num_history_known_classes = sum(self.runner_cfg['segmentor_cfg']['num_known_classes_list'][:-1])
        # start to iter
        for data_meta in self.train_loader:
            # --fetch data
            images = data_meta['image'].to(self.device, dtype=torch.float32)
            seg_targets = data_meta['seg_target'].to(self.device, dtype=torch.long)
            # --feed to history_segmentor
            if self.history_segmentor is not None:
                with torch.no_grad():
                    history_outputs = self.history_segmentor(images)
                history_distillation_feats = history_outputs['distillation_feats']
                history_distillation_feats.append(history_outputs['seg_logits'])
            # --forward to segmentor
            outputs = self.segmentor(images)
            # --calculate segmentation losses
            if self.history_segmentor is not None:
                seg_losses_cfgs = copy.deepcopy(losses_cfgs['segmentation_cl'])
                for seg_losses_cfg in seg_losses_cfgs.values():
                    for loss_cfg in seg_losses_cfg.values():
                        loss_cfg.update({'num_history_known_classes': num_history_known_classes})
            else:
                seg_losses_cfgs = copy.deepcopy(losses_cfgs['segmentation_init'])
            seg_total_loss, seg_losses_log_dict = self.segmentor.module.calculateseglosses(
                seg_logits=outputs['seg_logits'],
                seg_targets=seg_targets,
                losses_cfgs=seg_losses_cfgs,
            )
            # --calculate pod distillation losses
            pod_total_loss, pod_losses_log_dict = 0, {}
            if self.history_segmentor is not None:
                distillation_feats = outputs['distillation_feats']
                distillation_feats.append(outputs['seg_logits'])
                pod_total_loss, pod_losses_log_dict = self.featuresdistillation(
                    history_distillation_feats=history_distillation_feats,
                    distillation_feats=distillation_feats,
                    num_known_classes_list=self.runner_cfg['segmentor_cfg']['num_known_classes_list'],
                    dataset_type=self.runner_cfg['dataset_cfg']['type'],
                    **losses_cfgs['distillation_rcil']
                )
            # --calculate mib distillation losses
            kd_total_loss, kd_losses_log_dict = 0, {}
            if self.history_segmentor is not None:
                align_corners = self.segmentor.module.align_corners
                kd_total_loss, kd_losses_log_dict = MIBRunner.featuresdistillation(
                    history_distillation_feats=F.interpolate(history_outputs['seg_logits'], size=images.shape[2:], mode="bilinear", align_corners=align_corners),
                    distillation_feats=F.interpolate(outputs['seg_logits'], size=images.shape[2:], mode="bilinear", align_corners=align_corners),
                    **losses_cfgs['distillation_mib']
                )
            # --merge three losses
            loss_total = pod_total_loss + kd_total_loss + seg_total_loss
            # --perform back propagation
            # NOTE(review): `amp` (this looks like NVIDIA apex's amp.scale_loss API) is not imported in this
            # module, so this line raises NameError unless `amp` is provided elsewhere - add the import
            with amp.scale_loss(loss_total, self.optimizer) as scaled_loss_total:
                scaled_loss_total.backward()
            self.scheduler.step()
            # --set zero gradient
            self.scheduler.zerograd()
            # --logging training loss info
            seg_losses_log_dict.update(pod_losses_log_dict)
            seg_losses_log_dict.update(kd_losses_log_dict)
            seg_losses_log_dict.pop('loss_total')
            seg_losses_log_dict['loss_total'] = loss_total.item()
            losses_log_dict = self.loggingtraininginfo(seg_losses_log_dict, losses_log_dict, init_losses_log_dict)
    '''featuresdistillation'''
    def featuresdistillation(self, history_distillation_feats, distillation_feats, num_known_classes_list=None, dataset_type='VOCDataset', scale_factor=1.0, spp_scales=(4, 8, 12, 16, 20, 24)):
        '''Pod distillation loss: (channel-pooled + spatially-pooled loss) * scale_factor.

        Returns:
            (pod_total_loss, {'loss_pod': float}); the logged value is averaged across processes
            when a torch.distributed process group is initialized.
        '''
        pod_total_loss = self.featuresdistillationchannel(history_distillation_feats, distillation_feats, num_known_classes_list, dataset_type) + \
            self.featuresdistillationspatial(history_distillation_feats, distillation_feats, num_known_classes_list, dataset_type, spp_scales)
        pod_total_loss = pod_total_loss * scale_factor
        value = pod_total_loss.detach().clone()
        # logging only: average across workers, but do not crash in non-distributed runs
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(value.div_(dist.get_world_size()))
        pod_losses_log_dict = {'loss_pod': value.item()}
        return pod_total_loss, pod_losses_log_dict
    '''_rowl2norm'''
    @staticmethod
    def _rowl2norm(x):
        '''Per-sample L2 norm of x flattened to (x.shape[0], -1).

        Prefers torch.linalg.vector_norm; falls back to torch.frobenius_norm (deprecated, missing in
        newer PyTorch) only on builds that lack vector_norm. Both compute the same value here.
        '''
        flat = x.reshape(x.shape[0], -1)
        vector_norm = getattr(getattr(torch, 'linalg', None), 'vector_norm', None)
        if vector_norm is not None:
            return vector_norm(flat, ord=2, dim=-1)
        return torch.frobenius_norm(flat, dim=-1)
    '''_podlayerfactor'''
    @staticmethod
    def _podlayerfactor(dataset_type, is_last_layer):
        '''Return the pod loss weight of a layer for dataset_type (the last layer uses a smaller factor).

        Raises:
            ValueError: if dataset_type is neither 'ADE20kDataset' nor 'VOCDataset'.
        '''
        # dataset_type -> (factor for earlier layers, factor for the last layer)
        factors = {'ADE20kDataset': (5e-6, 5e-7), 'VOCDataset': (0.01, 0.0005)}
        if dataset_type not in factors:
            raise ValueError('RCIL pod distillation has no loss factor for dataset_type {}'.format(dataset_type))
        other_layer_factor, last_layer_factor = factors[dataset_type]
        return last_layer_factor if is_last_layer else other_layer_factor
    '''_collapsenewclasses'''
    @staticmethod
    def _collapsenewclasses(history_distillation, distillation, num_history_known_classes):
        '''Map distillation onto history_distillation's channel layout when their channel counts differ.

        Channels >= num_history_known_classes are summed into channel 0 (presumably background);
        channels 1..num_history_known_classes-1 are copied. Returns distillation unchanged otherwise.
        '''
        if history_distillation.shape[1] == distillation.shape[1]:
            return distillation
        collapsed = torch.zeros_like(history_distillation)
        collapsed[:, 0] = distillation[:, 0] + distillation[:, num_history_known_classes:].sum(dim=1)
        collapsed[:, 1:] = distillation[:, 1:num_history_known_classes]
        return collapsed
    '''featuresdistillationchannel'''
    @staticmethod
    def featuresdistillationchannel(history_distillation_feats, distillation_feats, num_known_classes_list=None, dataset_type='VOCDataset'):
        '''Channel-pooled distillation loss between history and current features.

        The last list entry (the segmentation logits appended in train) is excluded. Each remaining pair
        is squared, average-pooled with a window of 3 along dim 1 (via permute(0, 2, 1, 3); presumably
        the channel axis of NCHW features) and compared with a per-sample L2 norm. Layer losses are
        weighted by sqrt(num_known_classes / num_curtask_classes) and _podlayerfactor, then averaged.
        Returns a zero loss when no intermediate features are given.

        Raises:
            ValueError: from _podlayerfactor for an unsupported dataset_type.
        '''
        # assert and initialize
        assert len(history_distillation_feats) == len(distillation_feats)
        device = history_distillation_feats[0].device
        loss = torch.tensor(0.).to(device)
        num_known_classes = sum(num_known_classes_list)
        num_curtask_classes = num_known_classes_list[-1]
        num_history_known_classes = num_known_classes - num_curtask_classes
        class_ratio_weight = math.sqrt(num_known_classes / num_curtask_classes)
        history_distillation_feats, distillation_feats = history_distillation_feats[:-1], distillation_feats[:-1]
        num_layers = len(history_distillation_feats)
        if num_layers == 0:
            # avoid 0 / 0 = nan when only the logits were passed
            return loss
        # start to iter
        for idx, (history_distillation, distillation) in enumerate(zip(history_distillation_feats, distillation_feats)):
            distillation = RCILRunner._collapsenewclasses(history_distillation, distillation, num_history_known_classes)
            history_distillation, distillation = history_distillation ** 2, distillation ** 2
            history_distillation_p = F.avg_pool2d(history_distillation.permute(0, 2, 1, 3), (3, 1), stride=1, padding=(1, 0))
            distillation_p = F.avg_pool2d(distillation.permute(0, 2, 1, 3), (3, 1), stride=1, padding=(1, 0))
            layer_loss = RCILRunner._rowl2norm(history_distillation_p - distillation_p).mean()
            pckd_factor = RCILRunner._podlayerfactor(dataset_type, idx == num_layers - 1)
            loss = loss + layer_loss * class_ratio_weight * pckd_factor
        # summarize and return
        return loss / num_layers
    '''featuresdistillationspatial'''
    @staticmethod
    def featuresdistillationspatial(history_distillation_feats, distillation_feats, num_known_classes_list=None, dataset_type='VOCDataset', spp_scales=(4, 8, 12, 16, 20, 24)):
        '''Multi-scale spatially pooled distillation loss between history and current features (logits included).

        Each pair is squared and average-pooled with square windows of every size in spp_scales (stride 1,
        padding size // 2); pooled maps are compared with a per-sample L2 norm and averaged over scales.
        Layer losses are weighted by sqrt(num_known_classes / num_curtask_classes) and _podlayerfactor,
        then averaged over layers.

        Raises:
            ValueError: from _podlayerfactor for an unsupported dataset_type.
        '''
        # assert and initialize
        assert len(history_distillation_feats) == len(distillation_feats)
        device = history_distillation_feats[0].device
        loss = torch.tensor(0.).to(device)
        num_known_classes = sum(num_known_classes_list)
        num_curtask_classes = num_known_classes_list[-1]
        num_history_known_classes = num_known_classes - num_curtask_classes
        class_ratio_weight = math.sqrt(num_known_classes / num_curtask_classes)
        num_layers = len(history_distillation_feats)
        # start to iter
        for idx, (history_distillation, distillation) in enumerate(zip(history_distillation_feats, distillation_feats)):
            distillation = RCILRunner._collapsenewclasses(history_distillation, distillation, num_history_known_classes)
            history_distillation, distillation = history_distillation ** 2, distillation ** 2
            layer_loss = torch.tensor(0.).to(device)
            for spp_scale in spp_scales:
                history_distillation_affinity = F.avg_pool2d(history_distillation, (spp_scale, spp_scale), stride=1, padding=spp_scale // 2)
                distillation_affinity = F.avg_pool2d(distillation, (spp_scale, spp_scale), stride=1, padding=spp_scale // 2)
                layer_loss = layer_loss + RCILRunner._rowl2norm(history_distillation_affinity - distillation_affinity).mean()
            layer_loss = layer_loss / len(spp_scales)
            pckd_factor = RCILRunner._podlayerfactor(dataset_type, idx == num_layers - 1)
            loss = loss + layer_loss * class_ratio_weight * pckd_factor
        # summarize and return
        return loss / num_layers