Skip to content

Commit dd72c1a

Browse files
committed
better check
1 parent e258e29 commit dd72c1a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,6 @@ def finish(self, state: LMState):
710710
assert seq_len <= max_seq_len
711711
results = fl_decoder.decode(emissions_ptr, seq_len, model.wb_target_dim.dimension)
712712
hyps_per_batch = [result.tokens for result in results]
713-
assert all(len(hyp) == seq_len for hyp in hyps_per_batch)
714713
scores_per_batch = [result.score for result in results]
715714
best_word_seq = [
716715
model.wb_target_dim.vocab.id_to_label(label_idx) if label_idx >= 0 else str(label_idx)
@@ -725,6 +724,9 @@ def finish(self, state: LMState):
725724
f" LM recalc whole seq count {fl_lm._count_recalc_whole_seq}"
726725
f" mem usage {dev_s}: {' '.join(_collect_mem_stats())}"
727726
)
727+
assert all(
728+
len(hyp) == seq_len for hyp in hyps_per_batch
729+
), f"seq_len {seq_len}, hyps lens {[len(hyp) for hyp in hyps_per_batch]}"
728730
if len(results) >= n_best:
729731
hyps_per_batch = hyps_per_batch[:n_best]
730732
scores_per_batch = scores_per_batch[:n_best]

0 commit comments

Comments
 (0)