-
Notifications
You must be signed in to change notification settings - Fork 27
/
main.py
373 lines (306 loc) · 16.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
# -*- coding: utf-8 -*-
#########################################################################
# This file is derived from Curious AI/mean-teacher, under the Creative Commons Attribution-NonCommercial
# Copyright Nicolas Turpault, Romain Serizel, Justin Salamon, Ankit Parag Shah, 2019, v1.0
# This software is distributed under the terms of the License MIT
#########################################################################
import argparse
import os
import time
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import nn
from utils import ramps
from DatasetDcase2019Task4 import DatasetDcase2019Task4
from DataLoad import DataLoadDf, ConcatDataset, MultiStreamBatchSampler
from utils.Scaler import Scaler
from TestModel import test_model
from evaluation_measures import get_f_measure_by_class, get_predictions, audio_tagging_results, compute_strong_metrics
from models.CRNN import CRNN
import config as cfg
from utils.utils import ManyHotEncoder, create_folder, SaveBest, to_cuda_if_available, weights_init, \
get_transforms, AverageMeterSet
from utils.Logger import LOG
def adjust_learning_rate(optimizer, rampup_value, rampdown_value):
    """Overwrite lr, betas and weight_decay of every param group from the ramp coefficients.

    LR warm-up to handle large minibatch sizes, from https://arxiv.org/abs/1706.02677
    (note: the only call site in this file, inside ``train``, is currently commented out).

    :param optimizer: torch.optim.Optimizer; each param group gets its 'lr', 'betas' and
        'weight_decay' entries replaced, so an Adam-style optimizer is expected.
    :param rampup_value: float, ramp-up coefficient (presumably in [0, 1] — TODO confirm).
    :param rampdown_value: float, ramp-down coefficient (presumably in [0, 1] — TODO confirm).
    """
    up = rampup_value
    down = rampdown_value
    # Same value for every group, so compute once and apply to all.
    settings = {
        'lr': up * down * cfg.max_learning_rate,
        'betas': (
            down * cfg.beta1_before_rampdown + (1. - down) * cfg.beta1_after_rampdown,
            (1. - up) * cfg.beta2_during_rampdup + up * cfg.beta2_after_rampup,
        ),
        'weight_decay': (1 - up) * cfg.weight_decay_during_rampup + cfg.weight_decay_after_rampup * up,
    }
    for group in optimizer.param_groups:
        group.update(settings)
def update_ema_variables(model, ema_model, alpha, global_step):
    """Update the teacher (EMA) parameters in place from the student parameters.

    For every parameter pair: ``ema_param <- alpha * ema_param + (1 - alpha) * param``.
    Only ``parameters()`` are averaged; buffers (e.g. BatchNorm running statistics)
    of ``ema_model`` are not touched here.

    :param model: torch.nn.Module, student model (read only).
    :param ema_model: torch.nn.Module, teacher model updated in place; its parameters are
        paired with the student's via ``zip``, so both models must share the same layout.
    :param alpha: float, EMA decay factor.
    :param global_step: int, number of optimizer steps taken so far.
    """
    # Use the true average until the exponential average is more correct:
    # for small global_step the effective decay is 1 - 1 / (global_step + 1).
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # add_(other, alpha=...) replaces the deprecated add_(Number, Tensor) overload.
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
def train(train_loader, model, optimizer, epoch, ema_model=None, weak_mask=None, strong_mask=None):
    """ One epoch of a Mean Teacher model

    :param train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch.
        Should yield 3 values: student input, teacher (EMA) input, labels
    :param model: torch.nn.Module, student model to be trained, should return a strong and a weak prediction
    :param optimizer: torch.optim.Optimizer, optimizer updating the student model
    :param epoch: int, the current epoch of training (used to compute the global step)
    :param ema_model: torch.nn.Module or None, teacher model (EMA of the student), should return a strong
        and a weak prediction. When None, the EMA losses, the consistency cost and the EMA update are skipped.
    :param weak_mask: mask the batch to get only the weak labeled data (used to calculate the loss)
    :param strong_mask: mask the batch to get only the strong labeled data (used to calculate the loss)
    :raises ValueError: if weak_mask, strong_mask and ema_model are all None (no loss term to optimize).
    """
    if ema_model is None and weak_mask is None and strong_mask is None:
        raise ValueError("train() needs at least one of weak_mask, strong_mask or ema_model to build a loss")

    class_criterion = nn.BCELoss()
    consistency_criterion = nn.MSELoss()
    [class_criterion, consistency_criterion] = to_cuda_if_available(
        [class_criterion, consistency_criterion])

    meters = AverageMeterSet()
    n_batches = len(train_loader)
    LOG.debug("Nb batches: {}".format(n_batches))
    start = time.time()
    rampup_length = n_batches * cfg.n_epoch // 2
    for i, (batch_input, ema_batch_input, target) in enumerate(train_loader):
        global_step = epoch * n_batches + i
        if global_step < rampup_length:
            rampup_value = ramps.sigmoid_rampup(global_step, rampup_length)
        else:
            rampup_value = 1.0

        # Todo check if this improves the performance
        # adjust_learning_rate(optimizer, rampup_value, rampdown_value)
        meters.update('lr', optimizer.param_groups[0]['lr'])

        [batch_input, ema_batch_input, target] = to_cuda_if_available([batch_input, ema_batch_input, target])
        LOG.debug(batch_input.mean())

        # Outputs (teacher first, then student, as in the original ordering)
        if ema_model is not None:
            strong_pred_ema, weak_pred_ema = ema_model(ema_batch_input)
            # The teacher only provides targets; no gradient flows through it.
            strong_pred_ema = strong_pred_ema.detach()
            weak_pred_ema = weak_pred_ema.detach()
        strong_pred, weak_pred = model(batch_input)

        # Collected and summed out-of-place so no loss tensor is mutated after being logged.
        loss_terms = []

        # Weak BCE Loss
        # Take the max in the time axis
        target_weak = target.max(-2)[0]
        if weak_mask is not None:
            weak_class_loss = class_criterion(weak_pred[weak_mask], target_weak[weak_mask])
            if i == 0:
                LOG.debug("target: {}".format(target.mean(-2)))
                LOG.debug("Target_weak: {}".format(target_weak))
                LOG.debug("Target_weak mask: {}".format(target_weak[weak_mask]))
                LOG.debug(weak_class_loss)
                LOG.debug("rampup_value: {}".format(rampup_value))
            meters.update('weak_class_loss', weak_class_loss.item())
            if ema_model is not None:
                ema_class_loss = class_criterion(weak_pred_ema[weak_mask], target_weak[weak_mask])
                meters.update('Weak EMA loss', ema_class_loss.item())
            loss_terms.append(weak_class_loss)

        # Strong BCE loss
        if strong_mask is not None:
            strong_class_loss = class_criterion(strong_pred[strong_mask], target[strong_mask])
            meters.update('Strong loss', strong_class_loss.item())
            if ema_model is not None:
                strong_ema_class_loss = class_criterion(strong_pred_ema[strong_mask], target[strong_mask])
                meters.update('Strong EMA loss', strong_ema_class_loss.item())
            loss_terms.append(strong_class_loss)

        # Teacher-student consistency cost
        if ema_model is not None:
            consistency_cost = cfg.max_consistency_cost * rampup_value
            meters.update('Consistency weight', consistency_cost)
            # Take consistency about strong predictions (all data)
            consistency_loss_strong = consistency_cost * consistency_criterion(strong_pred, strong_pred_ema)
            meters.update('Consistency strong', consistency_loss_strong.item())
            # Take consistency about weak predictions (all data)
            consistency_loss_weak = consistency_cost * consistency_criterion(weak_pred, weak_pred_ema)
            meters.update('Consistency weak', consistency_loss_weak.item())
            loss_terms.extend([consistency_loss_strong, consistency_loss_weak])

        loss = sum(loss_terms)
        # One host sync instead of one per check.
        loss_value = loss.item()
        assert not (np.isnan(loss_value) or loss_value > 1e5), 'Loss explosion: {}'.format(loss_value)
        assert not loss_value < 0, 'Loss problem, cannot be negative'
        meters.update('Loss', loss_value)

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if ema_model is not None:
            # global_step + 1 optimizer steps have been taken at this point.
            update_ema_variables(model, ema_model, 0.999, global_step + 1)

    epoch_time = time.time() - start

    LOG.info(
        'Epoch: {}\t'
        'Time {:.2f}\t'
        '{meters}'.format(
            epoch, epoch_time, meters=meters))
if __name__ == '__main__':
    LOG.info("MEAN TEACHER")
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-s", '--subpart_data', type=int, default=None, dest="subpart_data",
                        help="Number of files to be used. Useful when testing on small number of files.")
    parser.add_argument("-n", '--no_synthetic', dest='no_synthetic', action='store_true', default=False,
                        help="Not using synthetic labels during training")
    f_args = parser.parse_args()

    reduced_number_of_data = f_args.subpart_data
    no_synthetic = f_args.no_synthetic
    LOG.info("subpart_data = {}".format(reduced_number_of_data))
    LOG.info("Using synthetic data = {}".format(not no_synthetic))

    if no_synthetic:
        add_dir_model_name = "_no_synthetic"
    else:
        add_dir_model_name = "_with_synthetic"
    store_dir = os.path.join("stored_data", "MeanTeacher" + add_dir_model_name)
    saved_model_dir = os.path.join(store_dir, "model")
    saved_pred_dir = os.path.join(store_dir, "predictions")
    create_folder(store_dir)
    create_folder(saved_model_dir)
    create_folder(saved_pred_dir)

    pooling_time_ratio = cfg.pooling_time_ratio  # --> Be careful, it depends of the model time axis pooling

    # ##############
    # DATA
    # ##############
    dataset = DatasetDcase2019Task4(cfg.workspace,
                                    base_feature_dir=os.path.join(cfg.workspace, "dataset", "features"),
                                    save_log_feature=False)

    weak_df = dataset.initialize_and_get_df(cfg.weak, reduced_number_of_data)
    unlabel_df = dataset.initialize_and_get_df(cfg.unlabel, reduced_number_of_data)
    # Even if synthetic data is not used for training, it is used for validation purposes
    synthetic_df = dataset.initialize_and_get_df(cfg.synthetic, reduced_number_of_data, download=False)
    # NOTE(review): validation_df is not used below; the call is kept in case initialize_and_get_df
    # prepares data that test_model(cfg.validation, ...) relies on — TODO confirm.
    validation_df = dataset.initialize_and_get_df(cfg.validation, reduced_number_of_data)

    classes = cfg.classes
    many_hot_encoder = ManyHotEncoder(classes, n_frames=cfg.max_frames // pooling_time_ratio)

    transforms = get_transforms(cfg.max_frames)

    # Divide weak in train and valid
    train_weak_df = weak_df.sample(frac=0.8, random_state=26)
    valid_weak_df = weak_df.drop(train_weak_df.index).reset_index(drop=True)
    train_weak_df = train_weak_df.reset_index(drop=True)
    LOG.debug(valid_weak_df.event_labels.value_counts())

    # Divide synthetic in train and valid (split by file so one file's events stay together)
    filenames_train = synthetic_df.filename.drop_duplicates().sample(frac=0.8, random_state=26)
    # .copy(): the onset/offset columns are rewritten below, so work on an independent frame
    # rather than a slice of synthetic_df (avoids pandas' chained-assignment warning).
    train_synth_df = synthetic_df[synthetic_df.filename.isin(filenames_train)].copy()
    valid_synth_df = synthetic_df.drop(train_synth_df.index).reset_index(drop=True)

    # Put train_synth in frames so many_hot_encoder can work.
    # Not doing it for valid, because not using labels (when prediction) and event based metric expect sec.
    train_synth_df["onset"] = train_synth_df["onset"] * cfg.sample_rate // cfg.hop_length // pooling_time_ratio
    train_synth_df["offset"] = train_synth_df["offset"] * cfg.sample_rate // cfg.hop_length // pooling_time_ratio
    LOG.debug(valid_synth_df.event_label.value_counts())

    train_weak_data = DataLoadDf(train_weak_df, dataset.get_feature_file, many_hot_encoder.encode_strong_df,
                                 transform=transforms)
    unlabel_data = DataLoadDf(unlabel_df, dataset.get_feature_file, many_hot_encoder.encode_strong_df,
                              transform=transforms)
    train_synth_data = DataLoadDf(train_synth_df, dataset.get_feature_file, many_hot_encoder.encode_strong_df,
                                  transform=transforms)

    if not no_synthetic:
        list_dataset = [train_weak_data, unlabel_data, train_synth_data]
        batch_sizes = [cfg.batch_size // 4, cfg.batch_size // 2, cfg.batch_size // 4]
        # Synthetic (strongly labeled) data is the last stream of each batch
        strong_mask = slice(batch_sizes[0] + batch_sizes[1], cfg.batch_size)
    else:
        list_dataset = [train_weak_data, unlabel_data]
        batch_sizes = [cfg.batch_size // 4, 3 * cfg.batch_size // 4]
        strong_mask = None
    # Assume weak data is always the first one
    weak_mask = slice(batch_sizes[0])

    scaler = Scaler()
    scaler.calculate_scaler(ConcatDataset(list_dataset))

    LOG.debug(scaler.mean_)

    transforms = get_transforms(cfg.max_frames, scaler, augment_type="noise")
    for data_stream in list_dataset:
        data_stream.set_transform(transforms)

    concat_dataset = ConcatDataset(list_dataset)
    sampler = MultiStreamBatchSampler(concat_dataset,
                                      batch_sizes=batch_sizes)

    training_data = DataLoader(concat_dataset, batch_sampler=sampler)

    transforms_valid = get_transforms(cfg.max_frames, scaler=scaler)
    valid_synth_data = DataLoadDf(valid_synth_df, dataset.get_feature_file, many_hot_encoder.encode_strong_df,
                                  transform=transforms_valid)
    valid_weak_data = DataLoadDf(valid_weak_df, dataset.get_feature_file, many_hot_encoder.encode_weak,
                                 transform=transforms_valid)

    # Eval 2018
    # NOTE(review): eval_2018 is built but never used in this script — kept in case the
    # dataset initialization has side effects; TODO confirm and remove if not.
    eval_2018_df = dataset.initialize_and_get_df(cfg.eval2018, reduced_number_of_data)
    eval_2018 = DataLoadDf(eval_2018_df, dataset.get_feature_file, many_hot_encoder.encode_strong_df,
                           transform=transforms_valid)

    # ##############
    # Model
    # ##############
    crnn_kwargs = cfg.crnn_kwargs
    crnn = CRNN(**crnn_kwargs)
    crnn_ema = CRNN(**crnn_kwargs)

    crnn.apply(weights_init)
    crnn_ema.apply(weights_init)
    LOG.info(crnn)

    # The teacher is never trained by backprop: its weights only follow the student via EMA.
    for param in crnn_ema.parameters():
        param.detach_()

    # Move the models to the device once, before building the optimizer on their parameters.
    [crnn, crnn_ema] = to_cuda_if_available([crnn, crnn_ema])

    optim_kwargs = {"lr": 0.001, "betas": (0.9, 0.999)}
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, crnn.parameters()), **optim_kwargs)

    state = {
        'model': {"name": crnn.__class__.__name__,
                  'args': '',
                  "kwargs": crnn_kwargs,
                  'state_dict': crnn.state_dict()},
        'model_ema': {"name": crnn_ema.__class__.__name__,
                      'args': '',
                      "kwargs": crnn_kwargs,
                      'state_dict': crnn_ema.state_dict()},
        'optimizer': {"name": optimizer.__class__.__name__,
                      'args': '',
                      "kwargs": optim_kwargs,
                      'state_dict': optimizer.state_dict()},
        "pooling_time_ratio": pooling_time_ratio,
        "scaler": scaler.state_dict(),
        "many_hot_encoder": many_hot_encoder.state_dict()
    }

    save_best_cb = SaveBest("sup")

    # ##############
    # Train
    # ##############
    for epoch in range(cfg.n_epoch):
        crnn = crnn.train()
        crnn_ema = crnn_ema.train()

        train(training_data, crnn, optimizer, epoch, ema_model=crnn_ema, weak_mask=weak_mask, strong_mask=strong_mask)

        crnn = crnn.eval()
        LOG.info("\n ### Valid synthetic metric ### \n")
        predictions = get_predictions(crnn, valid_synth_data, many_hot_encoder.decode_strong, pooling_time_ratio,
                                      save_predictions=None)
        valid_events_metric = compute_strong_metrics(predictions, valid_synth_df)

        LOG.info("\n ### Valid weak metric ### \n")
        weak_metric = get_f_measure_by_class(crnn, len(classes),
                                             DataLoader(valid_weak_data, batch_size=cfg.batch_size))

        LOG.info("Weak F1-score per class: \n {}".format(pd.DataFrame(weak_metric * 100, many_hot_encoder.labels)))
        LOG.info("Weak F1-score macro averaged: {}".format(np.mean(weak_metric)))

        state['model']['state_dict'] = crnn.state_dict()
        state['model_ema']['state_dict'] = crnn_ema.state_dict()
        state['optimizer']['state_dict'] = optimizer.state_dict()
        state['epoch'] = epoch
        state['valid_metric'] = valid_events_metric.results()
        if cfg.checkpoint_epochs is not None and (epoch + 1) % cfg.checkpoint_epochs == 0:
            model_fname = os.path.join(saved_model_dir, "baseline_epoch_" + str(epoch))
            torch.save(state, model_fname)

        if cfg.save_best:
            if not no_synthetic:
                global_valid = valid_events_metric.results_class_wise_average_metrics()['f_measure']['f_measure']
                global_valid = global_valid + np.mean(weak_metric)
            else:
                global_valid = np.mean(weak_metric)
            if save_best_cb.apply(global_valid):
                model_fname = os.path.join(saved_model_dir, "baseline_best")
                torch.save(state, model_fname)

    if cfg.save_best:
        model_fname = os.path.join(saved_model_dir, "baseline_best")
        # NOTE(review): recent torch versions default torch.load to weights_only=True; if the saved
        # scaler/encoder state holds non-tensor objects this may need weights_only=False — confirm.
        state = torch.load(model_fname)
        LOG.info("testing model: {}".format(model_fname))
    else:
        LOG.info("testing model of last epoch: {}".format(cfg.n_epoch))

    # ##############
    # Validation
    # ##############
    predictions_fname = os.path.join(saved_pred_dir, "baseline_validation.tsv")
    test_model(state, cfg.validation, reduced_number_of_data, predictions_fname)

    # ##############
    # Evaluation
    # ##############
    predictions_eval2019_fname = os.path.join(saved_pred_dir, "baseline_eval2019.tsv")
    test_model(state, cfg.eval_desed, reduced_number_of_data, predictions_eval2019_fname)