Skip to content

Commit 657ecca

Browse files
committed
expand slice, stupid fix
1 parent 719c2bd commit 657ecca

File tree

1 file changed

+10
-5
lines changed
  • users/zeyer/experiments/exp2024_04_23_baselines/recog_ext

1 file changed

+10
-5
lines changed

users/zeyer/experiments/exp2024_04_23_baselines/recog_ext/ctc.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ def _masked_scatter(
471471
if any(d in reverse_dim_map for d in s.dims):
472472
for d in s.dims:
473473
if d in reverse_dim_map:
474-
s = _expand_slice(s, old_dim=d, new_dim=reverse_dim_map[d])
474+
s = _expand_slice(s, old_dim=d, new_dim=reverse_dim_map[d], expect_expand=True)
475475
# We also might need to replace newly merged dims, both in s and backup.
476476
for d in s.dims:
477477
if d in merged_dim_map:
@@ -493,21 +493,26 @@ def _masked_scatter(
493493
raise TypeError(f"_masked_scatter: unexpected type ({type(s)})")
494494

495495

496-
def _expand_slice(source: Tensor, old_dim: Dim, new_dim: Dim) -> Tensor:
496+
def _expand_slice(source: Tensor, old_dim: Dim, new_dim: Dim, *, expect_expand: Optional[bool] = None) -> Tensor:
497497
assert old_dim in source.dims
498-
old_size = old_dim.get_dim_value_tensor()
499-
new_size = new_dim.get_dim_value_tensor()
498+
old_size = old_dim.get_dim_value()
499+
new_size = new_dim.get_dim_value()
500500
if old_size == new_size:
501501
res, _ = rf.replace_dim(source, in_dim=old_dim, out_dim=new_dim)
502502
elif old_size < new_size:
503503
res, _ = rf.pad(
504504
source,
505505
axes=[old_dim],
506-
padding=[(0, new_size - old_size)],
506+
padding=[(0, new_dim.get_dim_value_tensor() - old_dim.get_dim_value_tensor())],
507507
out_dims=[new_dim],
508508
value=0,
509509
)
510510
else:
511+
if expect_expand is True:
512+
raise ValueError(
513+
f"expected expand, but got reduce (slice): {old_size} -> {new_size},"
514+
f" for {old_dim=} {new_dim=}, in {source=}"
515+
)
511516
res, _ = rf.slice(source, axis=old_dim, size=new_dim)
512517
return res
513518

0 commit comments

Comments
 (0)