diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py b/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py index 326483aba..b00f45175 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py @@ -416,6 +416,10 @@ def _masked_scatter( ) -> T: if isinstance(s, Tensor): assert isinstance(backup, Tensor) + if in_dim not in s.dims: + s = rf.expand_dim(s, in_dim) + if in_dim not in backup.dims: + backup = rf.expand_dim(backup, in_dim) # Do the reverse of _masked_select above. # First replace the dims back. if any(d in reverse_dim_map for d in s.dims):