Skip to content

Commit

Permalink
update
Browse the repository at this point in the history
  • Loading branch information
robin-p-schmitt committed Dec 4, 2024
1 parent 4d51834 commit 4692f63
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def decoder_default_initial_state(
"""Default initial state"""
state = rf.State()

if self.trafo_att:
if self.trafo_att and not use_mini_att and not use_zero_att:
state.trafo_att = self.trafo_att.default_initial_state(batch_dims=batch_dims)
att_dim = self.trafo_att.model_dim
else:
att_dim = self.att_num_heads * self.enc_out_dim
# att_dim = self.trafo_att.model_dim
# else:
# att_dim = self.att_num_heads * self.enc_out_dim

state.att = rf.zeros(list(batch_dims) + [att_dim])
state.att = rf.zeros(list(batch_dims) + [self.att_dim])
state.att.feature_dim_axis = len(state.att.dims) - 1

if "lstm" in self.decoder_state:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,17 +438,27 @@ def get_score(

ilm_eos_log_prob = rf.zeros(batch_dims, dtype="float32")
if ilm_state is not None:
ilm_step_out, ilm_state = model.label_decoder.loop_step(
**att_enc_args,
enc_spatial_dim=enc_spatial_dim,
input_embed=input_embed_label_model,
segment_lens=segment_lens,
segment_starts=segment_starts,
center_positions=center_positions,
state=ilm_state,
use_mini_att=ilm_type == "mini_att",
use_zero_att=ilm_type == "zero_att",
)
if model.center_window_size is None:
ilm_step_out, ilm_state = model.label_decoder.loop_step(
**att_enc_args,
enc_spatial_dim=enc_spatial_dim,
input_embed=input_embed_label_model,
state=ilm_state,
use_mini_att=ilm_type == "mini_att",
use_zero_att=ilm_type == "zero_att",
)
else:
ilm_step_out, ilm_state = model.label_decoder.loop_step(
**att_enc_args,
enc_spatial_dim=enc_spatial_dim,
input_embed=input_embed_label_model,
segment_lens=segment_lens,
segment_starts=segment_starts,
center_positions=center_positions,
state=ilm_state,
use_mini_att=ilm_type == "mini_att",
use_zero_att=ilm_type == "zero_att",
)
ilm_logits, _ = model.label_decoder.decode_logits(
input_embed=input_embed_label_model,
s=ilm_step_out["s"],
Expand Down Expand Up @@ -773,7 +783,19 @@ def model_recog(

# ILM
if ilm_type is not None:
ilm_state = model.label_decoder.default_initial_state(batch_dims=batch_dims_, use_mini_att=ilm_type == "mini_att")
if isinstance(model.label_decoder, GlobalAttDecoder):
ilm_state = model.label_decoder.decoder_default_initial_state(
batch_dims=batch_dims_,
use_mini_att=ilm_type == "mini_att",
use_zero_att=ilm_type == "zero_att",
enc_spatial_dim=enc_spatial_dim
)
else:
ilm_state = model.label_decoder.default_initial_state(
batch_dims=batch_dims_,
use_mini_att=ilm_type == "mini_att",
use_zero_att=ilm_type == "zero_att",
)
else:
ilm_state = None

Expand Down

0 comments on commit 4692f63

Please sign in to comment.