Skip to content

Commit

Permalink
fix when mask is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 20, 2025
1 parent 44c58ae commit 3601089
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,16 @@ def model_recog(
_seq_label_print("gather backrefs", seq_label)

got_new_label_cpu = rf.copy_to_device(got_new_label, "cpu")
(target_, lm_state_, seq_label_), packed_new_label_dim, packed_new_label_dim_map = _masked_select_tree(
(target, lm_state, seq_label),
mask=got_new_label,
mask_cpu=got_new_label_cpu,
dims=batch_dims + [beam_dim],
)
# packed_new_label_dim_map: old dim -> new dim. see _masked_select_prepare_dims
if got_new_label_cpu.raw_tensor.sum().item() > 0:
(target_, lm_state_, seq_label_), packed_new_label_dim, packed_new_label_dim_map = _masked_select_tree(
(target, lm_state, seq_label),
mask=got_new_label,
mask_cpu=got_new_label_cpu,
dims=batch_dims + [beam_dim],
)
# packed_new_label_dim_map: old dim -> new dim. see _masked_select_prepare_dims
assert packed_new_label_dim.get_dim_value() > 0

if packed_new_label_dim.get_dim_value() > 0:
print(
f"* feed target"
f" {[model.target_dim.vocab.id_to_label(l.item()) for l in target_.raw_tensor[:3].cpu()]}"
Expand Down

0 comments on commit 3601089

Please sign in to comment.