Skip to content

Commit

Permalink
fix gpu oom
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 19, 2025
1 parent 4b2a384 commit 8256873
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc_recog_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,6 @@ def _ctc_model_rescore(
targets_spatial_dim: Dim,
):
"""RescoreDef API"""
targets_beam_dim # noqa # unused here

import returnn.frontend as rf
from returnn.tensor import Tensor, Dim

Expand All @@ -754,18 +752,26 @@ def _ctc_model_rescore(

batch_dims = targets.remaining_dims(targets_spatial_dim)

# Note: This requires quite a lot of memory, as we broadcast the log_probs over the beam dim.
# We could also do a loop over the beam dim (or over chunks of the beam dim) to avoid this.

# Note: gradient does not matter (not used), thus no need to use our ctc_loss_fixed_grad.
neg_log_prob = rf.ctc_loss(
logits=log_probs,
logits_normalized=True,
targets=targets,
input_spatial_dim=enc_spatial_dim,
targets_spatial_dim=targets_spatial_dim,
blank_index=model.blank_idx,
)
# Note: Using ctc_loss directly requires quite a lot of memory,
# as we would broadcast the log_probs over the beam dim.
# Instead, we do a loop over the beam dim to avoid this.
neg_log_prob_ = []
for beam_idx in range(targets_beam_dim.get_dim_value()):
targets_b = rf.gather(targets, axis=targets_beam_dim, indices=beam_idx)
targets_b_seq_lens = rf.gather(targets_spatial_dim.dyn_size_ext, axis=targets_beam_dim, indices=beam_idx)
targets_b_spatial_dim = Dim(targets_b_seq_lens, name=f"{targets_spatial_dim.name}_beam{beam_idx}")
targets_b, _ = rf.replace_dim(targets_b, in_dim=targets_spatial_dim, out_dim=targets_b_spatial_dim)
# Note: gradient does not matter (not used), thus no need to use our ctc_loss_fixed_grad.
neg_log_prob = rf.ctc_loss(
logits=log_probs,
logits_normalized=True,
targets=targets_b,
input_spatial_dim=enc_spatial_dim,
targets_spatial_dim=targets_b_spatial_dim,
blank_index=model.blank_idx,
)
neg_log_prob_.append(neg_log_prob)
neg_log_prob, _ = rf.stack(neg_log_prob_, out_dim=targets_beam_dim)
log_prob_targets_seq = -neg_log_prob
assert log_prob_targets_seq.dims_set == set(batch_dims)
return log_prob_targets_seq
Expand Down

0 comments on commit 8256873

Please sign in to comment.