diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/global_/decoder.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/global_/decoder.py
index 82d6e9279..4e06921f4 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/global_/decoder.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/global_/decoder.py
@@ -25,13 +25,13 @@ def decoder_default_initial_state(
     """Default initial state"""
     state = rf.State()
 
-    if self.trafo_att:
+    if self.trafo_att and not use_mini_att and not use_zero_att:
       state.trafo_att = self.trafo_att.default_initial_state(batch_dims=batch_dims)
-      att_dim = self.trafo_att.model_dim
-    else:
-      att_dim = self.att_num_heads * self.enc_out_dim
+      # att_dim = self.trafo_att.model_dim
+    # else:
+    #   att_dim = self.att_num_heads * self.enc_out_dim
 
-    state.att = rf.zeros(list(batch_dims) + [att_dim])
+    state.att = rf.zeros(list(batch_dims) + [self.att_dim])
     state.att.feature_dim_axis = len(state.att.dims) - 1
 
     if "lstm" in self.decoder_state:
diff --git a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/segmental/recog.py b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/segmental/recog.py
index dc8836ca1..01312f608 100644
--- a/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/segmental/recog.py
+++ b/users/schmitt/experiments/config/pipelines/global_vs_segmental_2022_23_rf/dependencies/returnn/network_builder_rf/segmental/recog.py
@@ -438,17 +438,27 @@ def get_score(
 
   ilm_eos_log_prob = rf.zeros(batch_dims, dtype="float32")
   if ilm_state is not None:
-    ilm_step_out, ilm_state = model.label_decoder.loop_step(
-      **att_enc_args,
-      enc_spatial_dim=enc_spatial_dim,
-      input_embed=input_embed_label_model,
-      segment_lens=segment_lens,
-      segment_starts=segment_starts,
-      center_positions=center_positions,
-      state=ilm_state,
-      use_mini_att=ilm_type == "mini_att",
-      use_zero_att=ilm_type == "zero_att",
-    )
+    if model.center_window_size is None:
+      ilm_step_out, ilm_state = model.label_decoder.loop_step(
+        **att_enc_args,
+        enc_spatial_dim=enc_spatial_dim,
+        input_embed=input_embed_label_model,
+        state=ilm_state,
+        use_mini_att=ilm_type == "mini_att",
+        use_zero_att=ilm_type == "zero_att",
+      )
+    else:
+      ilm_step_out, ilm_state = model.label_decoder.loop_step(
+        **att_enc_args,
+        enc_spatial_dim=enc_spatial_dim,
+        input_embed=input_embed_label_model,
+        segment_lens=segment_lens,
+        segment_starts=segment_starts,
+        center_positions=center_positions,
+        state=ilm_state,
+        use_mini_att=ilm_type == "mini_att",
+        use_zero_att=ilm_type == "zero_att",
+      )
     ilm_logits, _ = model.label_decoder.decode_logits(
       input_embed=input_embed_label_model,
       s=ilm_step_out["s"],
@@ -773,7 +783,19 @@ def model_recog(
 
   # ILM
   if ilm_type is not None:
-    ilm_state = model.label_decoder.default_initial_state(batch_dims=batch_dims_, use_mini_att=ilm_type == "mini_att")
+    if isinstance(model.label_decoder, GlobalAttDecoder):
+      ilm_state = model.label_decoder.decoder_default_initial_state(
+        batch_dims=batch_dims_,
+        use_mini_att=ilm_type == "mini_att",
+        use_zero_att=ilm_type == "zero_att",
+        enc_spatial_dim=enc_spatial_dim
+      )
+    else:
+      ilm_state = model.label_decoder.default_initial_state(
+        batch_dims=batch_dims_,
+        use_mini_att=ilm_type == "mini_att",
+        use_zero_att=ilm_type == "zero_att",
+      )
   else:
     ilm_state = None
 
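
Below is a minimal, self-contained sketch (not part of the patch and not the repository's actual API) of the dispatch this change introduces in model_recog: a global-attention label decoder gets its ILM state from decoder_default_initial_state, which additionally needs enc_spatial_dim and the use_zero_att flag, while any other decoder keeps the previous default_initial_state call. The stub classes and the helper get_ilm_initial_state are hypothetical names used only for illustration.

# Illustrative sketch only; the stub classes and helper below are hypothetical
# stand-ins, not the classes from network_builder_rf.

class GlobalAttDecoderStub:
  """Stands in for GlobalAttDecoder: its initial ILM state needs enc_spatial_dim."""
  def decoder_default_initial_state(self, *, batch_dims, use_mini_att, use_zero_att, enc_spatial_dim):
    return {"decoder": "global", "mini_att": use_mini_att, "zero_att": use_zero_att}


class SegmentalAttDecoderStub:
  """Stands in for the segmental label decoder: no enc_spatial_dim needed."""
  def default_initial_state(self, *, batch_dims, use_mini_att, use_zero_att):
    return {"decoder": "segmental", "mini_att": use_mini_att, "zero_att": use_zero_att}


def get_ilm_initial_state(label_decoder, ilm_type, batch_dims, enc_spatial_dim):
  # No ILM requested -> no state, mirroring the `else: ilm_state = None` branch.
  if ilm_type is None:
    return None
  # Global-attention decoders use decoder_default_initial_state(...) and need enc_spatial_dim.
  if isinstance(label_decoder, GlobalAttDecoderStub):
    return label_decoder.decoder_default_initial_state(
      batch_dims=batch_dims,
      use_mini_att=ilm_type == "mini_att",
      use_zero_att=ilm_type == "zero_att",
      enc_spatial_dim=enc_spatial_dim,
    )
  # Everything else keeps the previous default_initial_state(...) call.
  return label_decoder.default_initial_state(
    batch_dims=batch_dims,
    use_mini_att=ilm_type == "mini_att",
    use_zero_att=ilm_type == "zero_att",
  )


if __name__ == "__main__":
  print(get_ilm_initial_state(GlobalAttDecoderStub(), "zero_att", batch_dims=[], enc_spatial_dim="T"))
  print(get_ilm_initial_state(SegmentalAttDecoderStub(), "mini_att", batch_dims=[], enc_spatial_dim=None))
  print(get_ilm_initial_state(SegmentalAttDecoderStub(), None, batch_dims=[], enc_spatial_dim=None))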