Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 16, 2025
1 parent c8d59a0 commit 863eb0c
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions users/zeyer/decoding/rescoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _returnn_rescore_config(
"""
Create config for rescoring.
"""
from returnn.tensor import Tensor, Dim
from returnn.tensor import Tensor, Dim, batch_dim
from i6_experiments.users.zeyer.serialization_v2 import ReturnnConfigWithNewSerialization

config = config.copy() if config else {}
Expand All @@ -187,6 +187,8 @@ def _returnn_rescore_config(
# Beam dim size unknown. Usually static size, but it's ok to leave this unknown here (right?).
beam_dim = Dim(Tensor("beam_size", dims=[], dtype="int32"), name="beam")

data_flat_spatial_dim = Dim(None, name="data_flat_spatial")

# Note: we should not put SPM/BPE directly here,
# because the recog output still has individual labels,
# so no SPM/BPE encoding on the text.
Expand All @@ -198,10 +200,11 @@ def _returnn_rescore_config(
"default_input": None, # no input
"target": "data_flat", # needed for get_model to know the target dim
"_beam_dim": beam_dim,
"_data_flat_spatial_dim": data_flat_spatial_dim,
"extern_data": {
# data_flat dyn dim is the flattened dim, no need to define dim tags now
"data_flat": {"shape": [None], "dtype": "int32", "vocab": vocab_opts},
"data_seq_lens": {"dims": [beam_dim], "dtype": "int32"},
"data_flat": {"dims": [batch_dim, data_flat_spatial_dim], "dtype": "int32", "vocab": vocab_opts},
"data_seq_lens": {"dims": [batch_dim, beam_dim], "dtype": "int32"},
},
}
)
Expand Down Expand Up @@ -280,7 +283,7 @@ def _returnn_score_step(*, model, extern_data: TensorDict, **_kwargs_unused):

targets_beam_dim = config.typed_value("_beam_dim")
targets_flat = extern_data["data_flat"]
targets_flat_time_dim = targets_flat.get_time_dim_tag()
targets_flat_time_dim = config.typed_value("_data_flat_spatial_dim")
targets_seq_lens = extern_data["data_seq_lens"] # [B, beam]
targets_spatial_dim = Dim(targets_seq_lens, name="targets_spatial")
targets = rf.pad_packed(targets_flat, in_dim=targets_flat_time_dim, dims=[targets_beam_dim, targets_spatial_dim])
Expand Down

0 comments on commit 863eb0c

Please sign in to comment.