diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py index 337348287..2e50bdd17 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py @@ -740,8 +740,6 @@ def _ctc_model_rescore( targets_spatial_dim: Dim, ): """RescoreDef API""" - targets_beam_dim # noqa # unused here - import returnn.frontend as rf from returnn.tensor import Tensor, Dim @@ -754,18 +752,26 @@ def _ctc_model_rescore( batch_dims = targets.remaining_dims(targets_spatial_dim) - # Note: This requires quite a lot of memory, as we broadcast the log_probs over the beam dim. - # We could also do a loop over the beam dim (or over chunks of the beam dim) to avoid this. - - # Note: gradient does not matter (not used), thus no need use our ctc_loss_fixed_grad. - neg_log_prob = rf.ctc_loss( - logits=log_probs, - logits_normalized=True, - targets=targets, - input_spatial_dim=enc_spatial_dim, - targets_spatial_dim=targets_spatial_dim, - blank_index=model.blank_idx, - ) + # Note: Using ctc_loss directly requires quite a lot of memory, + # as we would broadcast the log_probs over the beam dim. + # Instead, we do a loop over the beam dim to avoid this. + neg_log_prob_ = [] + for beam_idx in range(targets_beam_dim.get_dim_value()): + targets_b = rf.gather(targets, axis=targets_beam_dim, indices=beam_idx) + targets_b_seq_lens = rf.gather(targets_spatial_dim.dyn_size_ext, axis=targets_beam_dim, indices=beam_idx) + targets_b_spatial_dim = Dim(targets_b_seq_lens, name=f"{targets_spatial_dim.name}_beam{beam_idx}") + targets_b, _ = rf.replace_dim(targets_b, in_dim=targets_spatial_dim, out_dim=targets_b_spatial_dim) + # Note: gradient does not matter (not used), thus no need use our ctc_loss_fixed_grad. + neg_log_prob = rf.ctc_loss( + logits=log_probs, + logits_normalized=True, + targets=targets_b, + input_spatial_dim=enc_spatial_dim, + targets_spatial_dim=targets_b_spatial_dim, + blank_index=model.blank_idx, + ) + neg_log_prob_.append(neg_log_prob) + neg_log_prob, _ = rf.stack(neg_log_prob_, out_dim=targets_beam_dim) log_prob_targets_seq = -neg_log_prob assert log_prob_targets_seq.dims_set == set(batch_dims) return log_prob_targets_seq