diff --git a/users/zeyer/decoding/rescoring.py b/users/zeyer/decoding/rescoring.py index 3c138ba47..621bd0e9e 100644 --- a/users/zeyer/decoding/rescoring.py +++ b/users/zeyer/decoding/rescoring.py @@ -175,7 +175,7 @@ def _returnn_rescore_config( """ Create config for rescoring. """ - from returnn.tensor import Tensor, Dim + from returnn.tensor import Tensor, Dim, batch_dim from i6_experiments.users.zeyer.serialization_v2 import ReturnnConfigWithNewSerialization config = config.copy() if config else {} @@ -187,6 +187,8 @@ def _returnn_rescore_config( # Beam dim size unknown. Usually static size, but it's ok to leave this unknown here (right?). beam_dim = Dim(Tensor("beam_size", dims=[], dtype="int32"), name="beam") + data_flat_spatial_dim = Dim(None, name="data_flat_spatial") + # Note: we should not put SPM/BPE directly here, # because the recog output still has individual labels, # so no SPM/BPE encoding on the text. @@ -198,10 +200,11 @@ def _returnn_rescore_config( "default_input": None, # no input "target": "data_flat", # needed for get_model to know the target dim "_beam_dim": beam_dim, + "_data_flat_spatial_dim": data_flat_spatial_dim, "extern_data": { # data_flat dyn dim is the flattened dim, no need to define dim tags now - "data_flat": {"shape": [None], "dtype": "int32", "vocab": vocab_opts}, - "data_seq_lens": {"dims": [beam_dim], "dtype": "int32"}, + "data_flat": {"dims": [batch_dim, data_flat_spatial_dim], "dtype": "int32", "vocab": vocab_opts}, + "data_seq_lens": {"dims": [batch_dim, beam_dim], "dtype": "int32"}, }, } ) @@ -280,7 +283,7 @@ def _returnn_score_step(*, model, extern_data: TensorDict, **_kwargs_unused): targets_beam_dim = config.typed_value("_beam_dim") targets_flat = extern_data["data_flat"] - targets_flat_time_dim = targets_flat.get_time_dim_tag() + targets_flat_time_dim = config.typed_value("_data_flat_spatial_dim") targets_seq_lens = extern_data["data_seq_lens"] # [B, beam] targets_spatial_dim = Dim(targets_seq_lens, name="targets_spatial") targets = rf.pad_packed(targets_flat, in_dim=targets_flat_time_dim, dims=[targets_beam_dim, targets_spatial_dim])