@@ -61,27 +61,33 @@ def py():
     use_flashlight = True
     use_greedy = False
     epochs = 500
-    self_training_rounds = 4
+    self_training_rounds = 1
     train_small = True
     with_prior = True
     empirical_prior = True
-    aux_loss = True
+    prior_from_max = False
+    aux_loss = False
     alt_decoder = True
-    calc_last_pseudo_labels = True
-    tune_hyperparameters = True
+    calc_last_pseudo_labels = False
+    tune_hyperparameters = False
+    from_scratch = False

-    use_sum_criterion = False
+    use_sum_criterion = True
     horizontal_prior = True
     blank_prior = True
     prior_gradient = False
+    empirical_prior_full_sum = False
+    prior_from_max_full_sum = False
     LM_order = 2
-    top_k = 1
-    self_train_subset = None
+    top_k = 3
+    self_train_subset = 18000  # 18000
+
+    assert (empirical_prior_full_sum and empirical_prior) or not empirical_prior_full_sum

     if train_small:
         epochs = 50
         if self_training_rounds > 0:
-            self_epochs = 113  # 450, 225, 113, 75, 56, 45
+            self_epochs = 56  # 450, 225, 113, 75, 56, 45

     decoder_hyperparameters = None
     if use_greedy:
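The assert added above just encodes an implication (empirical_prior_full_sum requires empirical_prior). A minimal equivalent form, sketched with the same flag names:

    # empirical_prior_full_sum => empirical_prior
    assert not empirical_prior_full_sum or empirical_prior, "empirical_prior_full_sum requires empirical_prior"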
@@ -105,7 +111,7 @@ def py():
         if with_prior:
             decoder_hyperparameters["prior_weight"] = 0.3  # 0.2 if not using empirical prior

-        p0 = f"_p{str(decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else "") if with_prior else ""
+        p0 = f"_p{str(decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else ("-from_max" if prior_from_max else "")) if with_prior else ""
         p1 = "sum" if decoder_hyperparameters['log_add'] else "max"
         p2 = f"n{decoder_hyperparameters['nbest']}"
         p3 = f"b{decoder_hyperparameters['beam_size']}"
@@ -131,7 +137,7 @@ def py():
         else:
             str_add = ""

-        a0 = f"_p{str(alt_decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else "") if with_prior else ""
+        a0 = f"_p{str(alt_decoder_hyperparameters['prior_weight']).replace('.', '')}" + ("-emp" if empirical_prior else ("-from_max" if prior_from_max else "")) if with_prior else ""
         a1 = f"b{alt_decoder_hyperparameters['beam_size']}"
         a2 = f"w{str(alt_decoder_hyperparameters['lm_weight']).replace('.', '')}"
         a3 = "_tune" if tune_hyperparameters else ""
@@ -157,29 +163,49 @@ def py():
     } if self_training_rounds > 0 else None

     for am, lm, prior in [
-        (1.0, 0.0, 0.55)
+        (8.0, 0.01, 0.08)
     ]:
         if use_sum_criterion:
-            training_scales = {
-                "am": am,
-                "lm": lm,
-                "prior": prior
-            }
+            if am != 1.0 or lm != 1.0 or prior != 1.0:
+                scales_not_std = True
+                config_full_sum = {
+                    "am_scale": am,
+                    "lm_scale": lm,
+                    "prior_scale": prior
+                }
+            else:
+                scales_not_std = False
+                config_full_sum = {}

-            if list(training_scales.values()) == [1.0] * len(training_scales):
-                training_scales = None
+            if not horizontal_prior:
+                config_full_sum["horizontal_prior"] = horizontal_prior
+            if not blank_prior:
+                config_full_sum["blank_prior"] = blank_prior
+            if not prior_gradient:
+                config_full_sum["prior_gradient"] = prior_gradient
+            if top_k > 0:
+                config_full_sum["top_k"] = top_k
+            if empirical_prior_full_sum:
+                config_full_sum["empirical_prior"] = True
+            if prior_from_max_full_sum:
+                config_full_sum["max_prior"] = True
+
+            # This is to change the hash when we make changes in the loss function
+            config_full_sum["version"] = 1

             sum_str = f"-full_sum" + \
-                (f"_p{str(training_scales['prior']).replace('.', '')}_l{str(training_scales['lm']).replace('.', '')}_a{str(training_scales['am']).replace('.', '')}" if training_scales else "") + \
+                (f"_p{str(config_full_sum['prior_scale']).replace('.', '')}_l{str(config_full_sum['lm_scale']).replace('.', '')}_a{str(config_full_sum['am_scale']).replace('.', '')}" if scales_not_std else "") + \
                 (f"_LMorder{LM_order}" if LM_order > 2 else "") + \
                 (f"_topK{top_k}" if top_k > 0 else "") + \
+                ("_emp" if empirical_prior_full_sum else "") + \
+                ("_max_pr" if not empirical_prior_full_sum and prior_from_max_full_sum else "") + \
                 ("_wo_hor_pr" if not horizontal_prior else "") + \
                 ("_wo_blank_pr" if not blank_prior else "") + \
                 ("_wo_pr_grad" if not prior_gradient else "")

         alias_name = f"ctc-baseline" + \
             (sum_str if use_sum_criterion else "") + \
-            (f"-self_training_{self_training_rounds}" + (f"_s{self_train_subset}" if self_train_subset is not None else "") + (f"_e{self_epochs}" if self_epochs != 450 else "") if self_training_rounds > 0 else "") + \
+            (f"-self_training_{self_training_rounds}" + ("_from_scratch" if from_scratch else "") + (f"_s{self_train_subset}" if self_train_subset is not None else "") + (f"_e{self_epochs}" if self_epochs != 450 else "") if self_training_rounds > 0 else "") + \
             (f"-wo_aux_loss" if not aux_loss else "") + \
             (f"-ds100h" if train_small else "") + \
             f"-{vocab}" + \
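For reference, with the values set above (am=8.0, lm=0.01, prior=0.08, top_k=3, prior_gradient=False, LM_order=2, both full-sum prior flags off), the suffix built here comes out as "-full_sum_p008_l001_a80_topK3_wo_pr_grad". A small standalone check that replicates only the string formatting (a sketch, not part of the pipeline):

    def fmt(x):
        # mirrors str(x).replace(".", ""): "0.08" -> "008", "8.0" -> "80"
        return str(x).replace(".", "")

    am, lm, prior, top_k = 8.0, 0.01, 0.08, 3
    sum_str = "-full_sum" + f"_p{fmt(prior)}_l{fmt(lm)}_a{fmt(am)}" + f"_topK{top_k}" + "_wo_pr_grad"
    print(sum_str)  # -full_sum_p008_l001_a80_topK3_wo_pr_grad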
@@ -194,22 +220,20 @@ def py():
             model_config={"enc_conformer_layer": enc_conformer_layer_default, "feature_batch_norm": True},
             config_updates=config_updates,
             config_updates_self_training=config_updates_self_training,
+            config_full_sum=config_full_sum if use_sum_criterion else None,
             vocab=vocab,
             self_training_rounds=self_training_rounds,
             train_small=train_small,
             with_prior=with_prior,
             empirical_prior=empirical_prior,
+            prior_from_max=prior_from_max,
             use_sum_criterion=use_sum_criterion,
             aux_loss=aux_loss,
-            horizontal_prior=horizontal_prior,
-            blank_prior=blank_prior,
-            prior_gradient=prior_gradient,
             LM_order=LM_order,
-            top_k=top_k,
-            training_scales=training_scales if use_sum_criterion else None,
             self_train_subset=self_train_subset,
             calc_last_pseudo_labels=calc_last_pseudo_labels,
             tune_hyperparameters=tune_hyperparameters,
+            from_scratch=from_scratch,
         )


@@ -231,6 +255,7 @@ def train_exp(
     model_config: Optional[Dict[str, Any]] = None,
     config_updates: Optional[Dict[str, Any]] = None,
     config_updates_self_training: Optional[Dict[str, Any]] = None,
+    config_full_sum: Optional[Dict[str, Any]] = None,
     config_deletes: Optional[Sequence[str]] = None,
     post_config_updates: Optional[Dict[str, Any]] = None,
     epilog: Sequence[serialization.SerializerObject] = (),
@@ -244,17 +269,14 @@ def train_exp(
     train_small: bool = False,
     with_prior: bool = False,
     empirical_prior: bool = False,
+    prior_from_max: bool = False,
     use_sum_criterion: bool = False,
     aux_loss: bool = False,
-    horizontal_prior: bool = True,
-    blank_prior: bool = True,
-    prior_gradient: bool = True,
     LM_order: int = 2,
-    top_k: int = 0,
-    training_scales: Optional[Dict[str, float]] = None,
     self_train_subset: Optional[int] = None,
     calc_last_pseudo_labels: bool = False,
     tune_hyperparameters: bool = False,
+    from_scratch: bool = False,
 ) -> Optional[ModelWithCheckpoints]:
     """
     Train experiment
@@ -329,10 +351,11 @@ def train_exp(
         save_pseudo_labels=(pseudo_labels_ds, train_100_ds) if calc_last_pseudo_labels or self_training_rounds > 0 else None,
         calculate_pseudo_label_scores=True,  # NOTE: breaks hash
         recog_post_proc_funcs=recog_post_proc_funcs,
-        num_shards_recog=16,  # NOTE: breaks hash
+        # num_shards_recog=16,  # NOTE: breaks hash
         num_shards_pseudo=64,
         # num_shards_prior=64,
         is_last=self_training_rounds == 0,
+        prior_from_max=prior_from_max,
         empirical_prior=emp_prior if with_prior and empirical_prior else None,
     )

@@ -361,26 +384,19 @@ def train_exp(

         if use_sum_criterion:
             train_def = ctc_sum_training
+            config_self = dict_update_deep(config_self, config_full_sum)
             config_self["lm_path"] = get_count_based_n_gram(task.train_dataset.vocab, LM_order)

-            if not horizontal_prior:
-                config_self["horizontal_prior"] = horizontal_prior
-            if not blank_prior:
-                config_self["blank_prior"] = blank_prior
-            if training_scales:
-                config_self["am_scale"] = training_scales["am"]
-                config_self["lm_scale"] = training_scales["lm"]
-                config_self["prior_scale"] = training_scales["prior"]
-            if not prior_gradient:
-                config_self["prior_gradient"] = prior_gradient
-            if top_k > 0:
-                config_self["top_k"] = top_k
+            if config_self.get("empirical_prior", False):
+                config_self["empirical_prior"] = emp_prior

         # When testing on a smaller subset we only want one gpu
         if self_train_subset is not None:
             config_self["__num_processes"] = 1
             # config_self["learning_rate_piecewise_steps"] = [4_500, 9_000, 10_000]
             config_self["learning_rate_piecewise_steps"] = [2_250, 4_500, 5_000]
+            peak_lr = 1e-4
+            config_self["learning_rate_piecewise_values"] = [peak_lr * 1.001e-1, peak_lr, peak_lr * 3e-2, peak_lr * 3e-3]
         if not aux_loss:
             config_self.pop("aux_loss_layers")

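The small-subset branch above pins the schedule to steps [2_250, 4_500, 5_000] with values [~1e-5, 1e-4, 3e-6, 3e-7]. Assuming the usual piecewise-linear interpretation (first value at step 0, linear interpolation between the listed breakpoints), a standalone sketch of the resulting curve:

    import numpy as np

    peak_lr = 1e-4
    steps = [2_250, 4_500, 5_000]
    values = [peak_lr * 1.001e-1, peak_lr, peak_lr * 3e-2, peak_lr * 3e-3]

    def piecewise_lr(step):
        # linear warmup to peak_lr at step 2_250, then decay to 3e-6 and finally 3e-7
        return float(np.interp(step, [0] + steps, values))

    for s in (0, 1_000, 2_250, 4_500, 5_000):
        print(s, piecewise_lr(s))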
@@ -395,7 +411,10 @@ def train_exp(
             config_self["learning_rate_piecewise_values"] = [peak_lr * 1e-1, peak_lr, peak_lr * 3e-2, peak_lr * 3e-3]
             config_self["learning_rate_piecewise_steps"] = [20_000] + config_self["learning_rate_piecewise_steps"][1:]

-        init_checkpoint = model_with_checkpoint[i].get_last_fixed_epoch().checkpoint
+        if i == 0 and from_scratch:
+            init_checkpoint = None
+        else:
+            init_checkpoint = model_with_checkpoint[i].get_last_fixed_epoch().checkpoint

         model_with_checkpoint.append(train(
             prefix_self_training,
@@ -409,7 +428,7 @@ def train_exp(
             num_epochs=num_epochs,
             gpu_mem=gpu_mem,
             num_processes=num_processes,
-            time_rqmt=time_rqmt if time_rqmt else ((10 if self_train_subset else 156) if use_sum_criterion else 156),
+            time_rqmt=time_rqmt if time_rqmt else ((4 if self_train_subset else 156) if use_sum_criterion else 156),
         ))
         train_job = model_with_checkpoint[i + 1].get_training_job()
         if env_updates:
@@ -438,6 +457,7 @@ def train_exp(
             recog_post_proc_funcs=recog_post_proc_funcs,
             num_shards_recog=16,  # NOTE: breaks hash
             num_shards_prior=64,
+            prior_from_max=prior_from_max,
             empirical_prior=emp_prior if with_prior and empirical_prior else None,
             return_summary=True
         )
@@ -457,6 +477,7 @@ def train_exp(
             recog_post_proc_funcs=recog_post_proc_funcs,
             num_shards_recog=16,  # NOTE: breaks hash
             num_shards_prior=64,
+            prior_from_max=prior_from_max,
             empirical_prior=emp_prior if with_prior and empirical_prior else None,
             return_summary=True
         )
@@ -480,6 +501,7 @@ def train_exp(
             num_shards_pseudo=64,
             num_shards_prior=64,
             is_last=i + 1 == self_training_rounds,
+            prior_from_max=prior_from_max,
             empirical_prior=emp_prior if with_prior and empirical_prior else None,
         )

@@ -1004,28 +1026,44 @@ def ctc_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, target
 ctc_training: TrainDef[Model]
 ctc_training.learning_rate_control_error_measure = "ctc"

-def ctc_sum_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, lm_path: tk.Path):
+def ctc_sum_training(*, model: Model, data: rf.Tensor, data_spatial_dim: Dim, lm_path: tk.Path, seq_tags: rf.Tensor = None):
     """Function is run within RETURNN."""
     from returnn.config import get_global_config
     from .sum_criterion import sum_loss, safe_logsumexp

     # torch.autograd.set_detect_anomaly(True)

-    def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
+    def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor, use_max: bool = False, separate_eos: bool = False) -> torch.Tensor:
         lengths = lengths.to(log_probs.device)
         assert lengths.size(0) == log_probs.size(0), "Prior calculation batch lengths are not the same (full_sum)!"

-        mask = torch.arange(log_probs.size(1), device=log_probs.device).expand(log_probs.size(0), -1) < lengths.unsqueeze(1)
-        mask = torch.where(mask, 0.0, float("-inf"))
+        mask_bool = torch.arange(log_probs.size(1), device=log_probs.device).expand(log_probs.size(0), -1) < lengths.unsqueeze(1)
+        mask = torch.where(mask_bool, 0.0, float("-inf"))
         mask = mask.unsqueeze(-1).expand(-1, -1, log_probs.size(2))
         log_probs = log_probs + mask

         sum_frames = lengths.sum()
-        log_sum_probs = torch.full([log_probs.size(2) + 1,], float("-inf"), device=log_probs.device)
-        log_sum_probs[1:-1] = safe_logsumexp(safe_logsumexp(log_probs[:,:,1:], dim=0), dim=0)  # Sum over batch and time
-        log_sum_probs[0] = safe_logsumexp(log_probs[:,0,0], dim=0)  # BOS prob
-        log_sum_probs[-1] = safe_logsumexp(safe_logsumexp(log_probs[:,1:,0], dim=0), dim=0)  # EOS prob
-
+        if use_max:
+            if separate_eos:
+                raise NotImplementedError("Separate EOS not implemented for max prior")
+            else:
+                argmaxs = log_probs.argmax(dim=2)
+                argmaxs = argmaxs.flatten()
+                argmaxs = argmaxs[mask_bool.flatten()]
+                assert argmaxs.size(0) == sum_frames, f"Prior calculation frame count does not match (max) ({argmaxs.size(0)} != {sum_frames})"
+                sum_probs = argmaxs.bincount(minlength=log_probs.size(2))
+                sum_frames += (sum_probs == 0).sum()
+                sum_probs = torch.where(sum_probs == 0, 1, sum_probs)
+                log_sum_probs = sum_probs.log()
+        else:
+            if separate_eos:
+                log_sum_probs = torch.full((log_probs.size(2) + 1,), float("-inf"), device=log_probs.device)
+                log_sum_probs[1:-1] = safe_logsumexp(safe_logsumexp(log_probs[:,:,1:], dim=0), dim=0)  # Sum over batch and time
+                log_sum_probs[0] = safe_logsumexp(log_probs[:,0,0], dim=0)  # BOS prob
+                log_sum_probs[-1] = safe_logsumexp(safe_logsumexp(log_probs[:,1:,0], dim=0), dim=0)  # EOS prob
+            else:
+                log_sum_probs = safe_logsumexp(safe_logsumexp(log_probs, dim=0), dim=0)
+
         log_mean_probs = log_sum_probs - sum_frames.log()

         with torch.no_grad():
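A standalone sketch of the max-based prior added above: count frame-level argmaxes over all valid frames, smooth zero counts to one, and normalize (toy (B, T, V) batch only; the real code runs on the model's log-probs and keeps the safe_logsumexp path for the non-max branch):

    import torch

    B, T, V = 2, 5, 4
    log_probs = torch.randn(B, T, V).log_softmax(dim=-1)
    lengths = torch.tensor([5, 3])

    mask_bool = torch.arange(T).expand(B, -1) < lengths.unsqueeze(1)

    argmaxs = log_probs.argmax(dim=2).flatten()[mask_bool.flatten()]
    counts = argmaxs.bincount(minlength=V).float()
    n_frames = lengths.sum().float() + (counts == 0).sum()  # add one pseudo-frame per unseen label
    counts = torch.where(counts == 0, torch.ones_like(counts), counts)
    log_prior = counts.log() - n_frames.log()  # log relative frequency per output label
    print(log_prior.exp().sum())  # sums to 1 (up to float error)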
@@ -1047,6 +1085,8 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten
     horizontal_prior = config.bool("horizontal_prior", True)
     blank_prior = config.bool("blank_prior", True)
     prior_gradient = config.bool("prior_gradient", True)
+    empirical_prior = config.typed_value("empirical_prior", None)
+    max_prior = config.bool("max_prior", False)
     top_k = config.int("top_k", 0)
     use_prior = prior_scale > 0.0

@@ -1066,6 +1106,7 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten

     collected_outputs = {}
     logits, enc, enc_spatial_dim = model(data, in_spatial_dim=data_spatial_dim, collected_outputs=collected_outputs)
+
     if aux_loss_layers:
         for i, layer_idx in enumerate(aux_loss_layers):
             if layer_idx > len(model.encoder.layers):
@@ -1075,9 +1116,14 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten
             aux_log_probs = model.log_probs_wb_from_logits(aux_logits)
             aux_log_probs = aux_log_probs.raw_tensor
             if use_prior:
-                aux_log_prior = _calc_log_prior(aux_log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor)
-                if not prior_gradient:
-                    aux_log_prior = aux_log_prior.detach()
+                if empirical_prior is not None:
+                    aux_log_prior = np.loadtxt(empirical_prior, dtype="float32")
+                    aux_log_prior = torch.tensor(aux_log_prior, device=aux_log_probs.device)
+                    assert aux_log_prior.size(0) == aux_log_probs.size(2), "Empirical prior size does not match (full_sum)!"
+                else:
+                    aux_log_prior = _calc_log_prior(aux_log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor, use_max=max_prior)
+                    if not prior_gradient:
+                        aux_log_prior = aux_log_prior.detach()
             else:
                 aux_log_prior = None
             # (B, T, F) -> (T, B, F)
@@ -1106,13 +1152,32 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten
                 custom_inv_norm_factor=enc_spatial_dim.get_size_tensor(),
                 use_normalized_loss=use_normalized_loss,
             )
+
+    fixed_seqs = ["train-other-500/5756-305214-0041/5756-305214-0041"]  # MONICA DREW FRESH HOPE FROM HER SON'S WRITINGS THEY WERE FULL OF NOBLE THOUGHTS AND HIGH ASPIRATIONS
+    print_for_idx = []
+
+    seq_tags = seq_tags.raw_tensor
+    for seq in fixed_seqs:
+        if seq in seq_tags:
+            idx = np.where(seq_tags == seq)[0]
+            print("Found seq", seq, enc_spatial_dim.dyn_size_ext.raw_tensor[idx])
+            print_for_idx.append(idx[0])
+
+    # seq = seq_tags[0]
+    # idx = np.where(seq_tags == seq)[0]
+    # print_for_idx.append(idx[0])

     log_probs = model.log_probs_wb_from_logits(logits)
     log_probs = log_probs.raw_tensor
     if use_prior:
-        log_prior = _calc_log_prior(log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor)
-        if not prior_gradient:
-            log_prior = log_prior.detach()
+        if empirical_prior is not None:
+            log_prior = np.loadtxt(empirical_prior, dtype="float32")
+            log_prior = torch.tensor(log_prior, device=log_probs.device)
+            assert log_prior.size(0) == log_probs.size(2), "Empirical prior size does not match (full_sum)!"
+        else:
+            log_prior = _calc_log_prior(log_probs, enc_spatial_dim.dyn_size_ext.raw_tensor, use_max=max_prior)
+            if not prior_gradient:
+                log_prior = log_prior.detach()
     else:
         log_prior = None
     # (B, T, F) -> (T, B, F)
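Both prior branches above read the empirical prior with np.loadtxt and expect one value per output label (matching log_probs.size(2)). A sketch of producing a compatible file, assuming it stores natural-log label probabilities as plain text (how emp_prior is actually generated lies outside this diff):

    import numpy as np

    counts = np.array([120.0, 30.0, 5.0, 845.0])  # hypothetical per-label counts, length == vocab size incl. blank
    np.savetxt("prior.txt", np.log(counts / counts.sum()))  # one value per line

    loaded = np.loadtxt("prior.txt", dtype="float32")  # mirrors the loading code above
    assert loaded.shape[0] == counts.shape[0]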
@@ -1132,7 +1197,8 @@ def _calc_log_prior(log_probs: torch.Tensor, lengths: torch.Tensor) -> torch.Ten
         blank_idx=model.blank_idx,
         eos_idx=model.eos_idx,
         unk_idx=1,
-        device=log_probs.device
+        device=log_probs.device,
+        print_best_path_for_idx=print_for_idx
     )
     loss = rtf.TorchBackend.convert_to_tensor(loss, dims=[batch_dim], dtype="float32", name=f"full_sum")
     loss.mark_as_loss(