Skip to content

Commit 097d289

Browse files
committed
try ctc with labelwise prior
1 parent 8700988 commit 097d289

File tree

1 file changed

+50
-6
lines changed

1 file changed

+50
-6
lines changed

users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,11 @@ def py():
290290

291291
# Rescoring.
292292
from .ctc import model_recog as model_recog_ctc_only, _ctc_model_def_blank_idx
293-
from i6_experiments.users.zeyer.decoding.lm_rescoring import lm_framewise_prior_rescore
294-
from i6_experiments.users.zeyer.decoding.prior_rescoring import Prior
293+
from i6_experiments.users.zeyer.decoding.lm_rescoring import (
294+
lm_framewise_prior_rescore,
295+
lm_labelwise_prior_rescore,
296+
)
297+
from i6_experiments.users.zeyer.decoding.prior_rescoring import Prior, PriorRemoveLabelRenormJob
295298
from i6_experiments.users.zeyer.datasets.utils.vocab import (
296299
ExtractVocabLabelsJob,
297300
ExtractVocabSpecialLabelsJob,
@@ -303,6 +306,13 @@ def py():
303306
vocab_w_blank_file = ExtendVocabLabelsByNewLabelJob(
304307
vocab=vocab_file, new_label=model_recog_ctc_only.output_blank_label, new_label_idx=_ctc_model_def_blank_idx
305308
).out_vocab
309+
log_prior_wo_blank = PriorRemoveLabelRenormJob(
310+
prior_file=prior,
311+
prior_type="prob",
312+
vocab=vocab_w_blank_file,
313+
remove_label=model_recog_ctc_only.output_blank_label,
314+
out_prior_type="log_prob",
315+
).out_prior
306316

307317
for beam_size, prior_scale, lm_scale in [
308318
(16, 0.5, 1.0),
@@ -377,7 +387,7 @@ def py():
377387

378388
scales_results = {}
379389
for lm_scale in np.linspace(0.0, 1.0, 11):
380-
for prior_scale_ in np.linspace(0.0, 1.0, 11):
390+
for prior_scale_rel in np.linspace(0.0, 1.0, 11):
381391
res = recog_model(
382392
task=task,
383393
model=ctc_model,
@@ -388,7 +398,7 @@ def py():
388398
lm_framewise_prior_rescore,
389399
# framewise standard prior
390400
prior=Prior(file=prior, type="prob", vocab=vocab_w_blank_file),
391-
prior_scale=lm_scale * prior_scale_,
401+
prior_scale=lm_scale * prior_scale_rel,
392402
lm=lm,
393403
lm_scale=lm_scale,
394404
lm_rescore_rqmt={"cpu": 4, "mem": 30, "time": 24, "gpu_mem": 24},
@@ -398,14 +408,48 @@ def py():
398408
],
399409
)
400410
tk.register_output(
401-
f"{prefix}/rescore-beam{beam_size}-lm_{lm_out_name}-lmScale{lm_scale}-priorScaleRel{prior_scale}",
411+
f"{prefix}/rescore-beam{beam_size}-lm_{lm_out_name}-lmScale{lm_scale}-priorScaleRel{prior_scale_rel}",
402412
res.output,
403413
)
404-
scales_results[(prior_scale_, lm_scale)] = res.output
414+
scales_results[(prior_scale_rel, lm_scale)] = res.output
405415
_plot_scales(
406416
f"rescore-beam{beam_size}-lm_{lm_out_name}-priorScaleRel", scales_results, x_axis_name="prior_scale_rel"
407417
)
408418

419+
scales_results = {}
420+
for lm_scale in np.linspace(0.0, 1.0, 3):
421+
for prior_scale_rel in np.linspace(0.0, 1.0, 3):
422+
res = recog_model(
423+
task=task,
424+
model=ctc_model,
425+
recog_def=model_recog_ctc_only,
426+
config={"beam_size": beam_size},
427+
recog_pre_post_proc_funcs_ext=[
428+
functools.partial(
429+
lm_labelwise_prior_rescore,
430+
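# labelwise prior: blank label removed via PriorRemoveLabelRenormJob, given as log-probs over the vocab without blank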
431+
prior=Prior(file=log_prior_wo_blank, type="log_prob", vocab=vocab_file),
432+
prior_scale=lm_scale * prior_scale_rel,
433+
lm=lm,
434+
lm_scale=lm_scale,
435+
lm_rescore_rqmt={"cpu": 4, "mem": 30, "time": 24, "gpu_mem": 24},
436+
vocab=vocab_file,
437+
vocab_opts_file=vocab_opts_file,
438+
)
439+
],
440+
)
441+
tk.register_output(
442+
f"{prefix}/rescore-beam{beam_size}-lm_{lm_out_name}-lmScale{lm_scale}"
443+
f"-labelPrior-priorScaleRel{prior_scale_rel}",
444+
res.output,
445+
)
446+
scales_results[(prior_scale_rel, lm_scale)] = res.output
447+
_plot_scales(
448+
f"rescore-beam{beam_size}-lm_{lm_out_name}-labelPrior-priorScaleRel",
449+
scales_results,
450+
x_axis_name="prior_scale_rel",
451+
)
452+
409453

410454
_sis_prefix: Optional[str] = None
411455

0 commit comments

Comments
 (0)