From 21a341310810dc06d59e164225cab1d16e03f7e6 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Jan 2025 18:48:23 +0100 Subject: [PATCH] fix masked scatter --- .../experiments/exp2024_04_23_baselines/recog_ext/ctc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py b/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py index 326483aba..b00f45175 100644 --- a/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py +++ b/users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py @@ -416,6 +416,10 @@ def _masked_scatter( ) -> T: if isinstance(s, Tensor): assert isinstance(backup, Tensor) + if in_dim not in s.dims: + s = rf.expand_dim(s, in_dim) + if in_dim not in backup.dims: + backup = rf.expand_dim(backup, in_dim) # Do the reverse of _masked_select above. # First replace the dims back. if any(d in reverse_dim_map for d in s.dims):