Skip to content

Commit db7d737

Browse files
committed
fix expand slice
1 parent 713ce8f commit db7d737

File tree

1 file changed

+19
-14
lines changed
  • users/zeyer/experiments/exp2024_04_23_baselines/recog_ext

1 file changed

+19
-14
lines changed

users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -472,17 +472,14 @@ def _masked_scatter(
472472
if any(d in reverse_dim_map for d in s.dims):
473473
for d in s.dims:
474474
if d in reverse_dim_map:
475-
s = _expand_slice(s, axis=d, expanded_size=reverse_dim_map[d])
475+
s = _expand_slice(s, old_dim=d, new_dim=reverse_dim_map[d])
476476
# We also might need to replace newly merged dims, both in s and backup.
477477
for d in s.dims:
478478
if d in merged_dim_map:
479-
s, _ = rf.slice(s, axis=d, size=merged_dim_map[d])
480-
# There is currently the implicit assumption that the backup might need extra padding,
481-
# while the s needs slicing...
482-
# (We think of the hist_dim, where s should only have more frames than backup, or the same.)
479+
s, _ = _expand_slice(s, old_dim=d, new_dim=merged_dim_map[d])
483480
for d in backup.dims:
484481
if d in merged_dim_map:
485-
backup = _expand_slice(backup, axis=d, expanded_size=merged_dim_map[d])
482+
backup = _expand_slice(backup, old_dim=d, new_dim=merged_dim_map[d])
486483
# The unpacking itself (reversing the masked_select, i.e. masked_scatter).
487484
s = rf.masked_scatter(s, backup, mask=mask, dims=dims, in_dim=in_dim)
488485
return s
@@ -497,14 +494,22 @@ def _masked_scatter(
497494
raise TypeError(f"_masked_scatter: unexpected type ({type(s)})")
498495

499496

500-
def _expand_slice(source: Tensor, axis: Dim, expanded_size: Dim) -> Tensor:
501-
res, _ = rf.pad(
502-
source,
503-
axes=[axis],
504-
padding=[(0, expanded_size.get_dim_value_tensor() - axis.get_dim_value_tensor())],
505-
out_dims=[expanded_size],
506-
value=0,
507-
)
497+
def _expand_slice(source: Tensor, old_dim: Dim, new_dim: Dim) -> Tensor:
498+
assert old_dim in source.dims
499+
old_size = old_dim.get_dim_value_tensor()
500+
new_size = new_dim.get_dim_value_tensor()
501+
if old_size == new_size:
502+
res, _ = rf.replace_dim(source, in_dim=old_dim, out_dim=new_dim)
503+
elif old_size < new_size:
504+
res, _ = rf.pad(
505+
source,
506+
axes=[old_dim],
507+
padding=[(0, new_size - old_size)],
508+
out_dims=[new_dim],
509+
value=0,
510+
)
511+
else:
512+
res, _ = rf.slice(source, axis=old_dim, size=new_dim)
508513
return res
509514

510515

0 commit comments

Comments
 (0)