@@ -290,8 +290,11 @@ def py():
290
290
291
291
# Rescoring.
292
292
from .ctc import model_recog as model_recog_ctc_only , _ctc_model_def_blank_idx
293
- from i6_experiments .users .zeyer .decoding .lm_rescoring import lm_framewise_prior_rescore
294
- from i6_experiments .users .zeyer .decoding .prior_rescoring import Prior
293
+ from i6_experiments .users .zeyer .decoding .lm_rescoring import (
294
+ lm_framewise_prior_rescore ,
295
+ lm_labelwise_prior_rescore ,
296
+ )
297
+ from i6_experiments .users .zeyer .decoding .prior_rescoring import Prior , PriorRemoveLabelRenormJob
295
298
from i6_experiments .users .zeyer .datasets .utils .vocab import (
296
299
ExtractVocabLabelsJob ,
297
300
ExtractVocabSpecialLabelsJob ,
@@ -303,6 +306,13 @@ def py():
303
306
vocab_w_blank_file = ExtendVocabLabelsByNewLabelJob (
304
307
vocab = vocab_file , new_label = model_recog_ctc_only .output_blank_label , new_label_idx = _ctc_model_def_blank_idx
305
308
).out_vocab
309
+ log_prior_wo_blank = PriorRemoveLabelRenormJob (
310
+ prior_file = prior ,
311
+ prior_type = "prob" ,
312
+ vocab = vocab_w_blank_file ,
313
+ remove_label = model_recog_ctc_only .output_blank_label ,
314
+ out_prior_type = "log_prob" ,
315
+ ).out_prior
306
316
307
317
for beam_size , prior_scale , lm_scale in [
308
318
(16 , 0.5 , 1.0 ),
@@ -377,7 +387,7 @@ def py():
377
387
378
388
scales_results = {}
379
389
for lm_scale in np .linspace (0.0 , 1.0 , 11 ):
380
- for prior_scale_ in np .linspace (0.0 , 1.0 , 11 ):
390
+ for prior_scale_rel in np .linspace (0.0 , 1.0 , 11 ):
381
391
res = recog_model (
382
392
task = task ,
383
393
model = ctc_model ,
@@ -388,7 +398,7 @@ def py():
388
398
lm_framewise_prior_rescore ,
389
399
# framewise standard prior
390
400
prior = Prior (file = prior , type = "prob" , vocab = vocab_w_blank_file ),
391
- prior_scale = lm_scale * prior_scale_ ,
401
+ prior_scale = lm_scale * prior_scale_rel ,
392
402
lm = lm ,
393
403
lm_scale = lm_scale ,
394
404
lm_rescore_rqmt = {"cpu" : 4 , "mem" : 30 , "time" : 24 , "gpu_mem" : 24 },
@@ -398,14 +408,48 @@ def py():
398
408
],
399
409
)
400
410
tk .register_output (
401
- f"{ prefix } /rescore-beam{ beam_size } -lm_{ lm_out_name } -lmScale{ lm_scale } -priorScaleRel{ prior_scale } " ,
411
+ f"{ prefix } /rescore-beam{ beam_size } -lm_{ lm_out_name } -lmScale{ lm_scale } -priorScaleRel{ prior_scale_rel } " ,
402
412
res .output ,
403
413
)
404
- scales_results [(prior_scale_ , lm_scale )] = res .output
414
+ scales_results [(prior_scale_rel , lm_scale )] = res .output
405
415
_plot_scales (
406
416
f"rescore-beam{ beam_size } -lm_{ lm_out_name } -priorScaleRel" , scales_results , x_axis_name = "prior_scale_rel"
407
417
)
408
418
419
+ scales_results = {}
420
+ for lm_scale in np .linspace (0.0 , 1.0 , 3 ):
421
+ for prior_scale_rel in np .linspace (0.0 , 1.0 , 3 ):
422
+ res = recog_model (
423
+ task = task ,
424
+ model = ctc_model ,
425
+ recog_def = model_recog_ctc_only ,
426
+ config = {"beam_size" : beam_size },
427
+ recog_pre_post_proc_funcs_ext = [
428
+ functools .partial (
429
+ lm_labelwise_prior_rescore ,
430
+ # labelwise prior
431
+ prior = Prior (file = log_prior_wo_blank , type = "log_prob" , vocab = vocab_file ),
432
+ prior_scale = lm_scale * prior_scale_rel ,
433
+ lm = lm ,
434
+ lm_scale = lm_scale ,
435
+ lm_rescore_rqmt = {"cpu" : 4 , "mem" : 30 , "time" : 24 , "gpu_mem" : 24 },
436
+ vocab = vocab_file ,
437
+ vocab_opts_file = vocab_opts_file ,
438
+ )
439
+ ],
440
+ )
441
+ tk .register_output (
442
+ f"{ prefix } /rescore-beam{ beam_size } -lm_{ lm_out_name } -lmScale{ lm_scale } "
443
+ f"-labelPrior-priorScaleRel{ prior_scale_rel } " ,
444
+ res .output ,
445
+ )
446
+ scales_results [(prior_scale_rel , lm_scale )] = res .output
447
+ _plot_scales (
448
+ f"rescore-beam{ beam_size } -lm_{ lm_out_name } -labelPrior-priorScaleRel" ,
449
+ scales_results ,
450
+ x_axis_name = "prior_scale_rel" ,
451
+ )
452
+
409
453
410
454
_sis_prefix : Optional [str ] = None
411
455
0 commit comments