# train_utils.py
import os
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from preprocessing import F0LoudnessPreprocessor, LoudnessPreprocessor
from encoders import SupervisedEncoder, UnsupervisedEncoder
from models import SupervisedAutoencoder, UnsupervisedAutoencoder
from decoders import DecoderWithoutLatent, DecoderWithLatent
from losses import SpectralLoss, MultiLoss
from callbacks import ModelCheckpoint, CustomWandbCallback
from metrics import f0_scaled_L1_loss
from dataloader import make_supervised_dataset, make_unsupervised_dataset

# -------------------------------------- Models -------------------------------------------------

def make_supervised_model(config):
    """Creates a supervised DDSP autoencoder from the config."""
    preprocessor = F0LoudnessPreprocessor(timesteps=config['data']['preprocessing_time'])
    if config['model']['encoder']:
        encoder = SupervisedEncoder()
        decoder = DecoderWithLatent(timesteps=config['model']['decoder_time'])
    else:
        encoder = None
        decoder = DecoderWithoutLatent(timesteps=config['model']['decoder_time'])
    assert config['loss']['type'] == 'spectral', 'The supervised DDSP can only be trained with the spectral loss.'
    loss = SpectralLoss(logmag_weight=config['loss']['logmag_weight'])
    model = SupervisedAutoencoder(preprocessor=preprocessor,
                                  encoder=encoder,
                                  decoder=decoder,
                                  add_reverb=config['model']['reverb'],
                                  loss_fn=loss,
                                  n_samples=config['data']['clip_dur'] * config['data']['sample_rate'],
                                  sample_rate=config['data']['sample_rate'],
                                  tracker_names=['spec_loss'])
    return model
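
# Illustrative only: a minimal config with the keys make_supervised_model reads.
# The key names come from the lookups above; the concrete values are assumptions.
#
# example_config = {
#     'data': {'preprocessing_time': 250, 'clip_dur': 4, 'sample_rate': 16000},
#     'model': {'encoder': True, 'decoder_time': 250, 'reverb': True},
#     'loss': {'type': 'spectral', 'logmag_weight': 1.0},
# }
# model = make_supervised_model(example_config)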
# TODO: expose encoder/decoder params in the config
# TODO: make metric fns configurable
# TODO: preprocessor?

def make_unsupervised_model(config):
    """Creates an unsupervised DDSP autoencoder from the config."""
    preprocessor = LoudnessPreprocessor(timesteps=config['data']['preprocessing_time'])
    encoder = UnsupervisedEncoder(timesteps=config['data']['preprocessing_time'])
    decoder = DecoderWithLatent(timesteps=config['model']['decoder_time'])
    loss = SpectralLoss() if config['loss']['type'] == 'spectral' else MultiLoss()
    if loss.name == 'spectral_loss':
        tracker_names = ['spec_loss']
    else:
        tracker_names = ['spec_loss', 'perc_loss', 'total_loss']
    metric_fns = {"F0_recons_L1": f0_scaled_L1_loss}
    model = UnsupervisedAutoencoder(encoder=encoder,
                                    decoder=decoder,
                                    preprocessor=preprocessor,
                                    loss_fn=loss,
                                    tracker_names=tracker_names,
                                    metric_fns=metric_fns,
                                    add_reverb=config['model']['reverb'])
    return model
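
# Note (assumption): judging by the tracker names, MultiLoss combines a spectral
# term with a perceptual term, so both components and their total are tracked
# separately during training.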

# -------------------------------------- Optimizer -------------------------------------------------

def make_optimizer(config):
    """Creates an Adam optimizer with an exponentially decaying learning rate."""
    scheduler = ExponentialDecay(config['optimizer']['lr'],
                                 decay_steps=config['optimizer']['decay_steps'],
                                 decay_rate=config['optimizer']['decay_rate'])
    optimizer = Adam(learning_rate=scheduler)
    return optimizer
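
# For reference, Keras' ExponentialDecay evaluates, at each optimizer step,
#     lr(step) = lr * decay_rate ** (step / decay_steps)
# so with illustrative values lr=1e-3, decay_steps=10000, decay_rate=0.98,
# the learning rate shrinks by a factor of 0.98 every 10k steps.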

# -------------------------------------- Callbacks -------------------------------------------------

def create_callbacks(config, monitor):
    """Creates the checkpoint (and, if enabled, W&B) callbacks and picks the save directory."""
    # The branching looks ugly, but it is necessary.
    if config['model']['dir']:  # if a directory is specified, save there
        model_dir = config['model']['dir']
        if not config['wandb']:
            callbacks = [ModelCheckpoint(save_dir=model_dir, monitor=monitor)]
        else:
            callbacks = [ModelCheckpoint(save_dir=model_dir, monitor=monitor),
                         CustomWandbCallback(config)]
    else:
        if config['wandb']['project_name'] is None:  # define a local save_dir
            model_dir = "model_checkpoints/{}".format(config['run_name'])
            callbacks = [ModelCheckpoint(save_dir=model_dir, monitor=monitor)]
        else:  # save to wandb.run.dir
            wandb_callback = CustomWandbCallback(config)
            model_dir = os.path.join(wandb_callback.wandb_run_dir, config['run_name'])
            callbacks = [ModelCheckpoint(save_dir=model_dir, monitor=monitor),
                         wandb_callback]
    return callbacks
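
# Note: config['wandb'] is expected to be either falsy (W&B disabled) or a dict
# with at least a 'project_name' key; 'project_name': None falls back to a local
# model_checkpoints/<run_name> directory.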

# -------------------------------------- Datasets -------------------------------------------------

def make_supervised_dataset_from_config(config):
    """Builds the supervised dataset from the config."""
    # Older configs do not specify mfcc_nfft; fall back to the previous default.
    mfcc_nfft = config['data'].get('mfcc_nfft', 1024)
    return make_supervised_dataset(config['data']['path'],
                                   mfcc=config['model']['encoder'],  # extract MFCCs or not
                                   mfcc_nfft=mfcc_nfft,  # FFT size used for the MFCCs
                                   batch_size=config['training']['batch_size'],
                                   sample_rate=config['data']['sample_rate'],
                                   normalize=config['data']['normalize'],
                                   conf_threshold=config['data']['confidence_threshold'])

def make_unsupervised_dataset_from_config(config):
    """Builds the unsupervised dataset from the config."""
    return make_unsupervised_dataset(config['data']['path'],
                                     batch_size=config['training']['batch_size'],
                                     sample_rate=config['data']['sample_rate'],
                                     normalize=config['data']['normalize'],
                                     frame_rate=config['data']['preprocessing_time'])
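
# Illustrative end-to-end usage; assumes a parsed `config` dict, a Keras-style
# model API, and a config['training']['epochs'] entry (all assumptions):
#
# model = make_supervised_model(config)
# optimizer = make_optimizer(config)
# train_set = make_supervised_dataset_from_config(config)
# callbacks = create_callbacks(config, monitor='spec_loss')
# model.compile(optimizer=optimizer)
# model.fit(train_set, epochs=config['training']['epochs'], callbacks=callbacks)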