@@ -472,17 +472,14 @@ def _masked_scatter(
472
472
if any (d in reverse_dim_map for d in s .dims ):
473
473
for d in s .dims :
474
474
if d in reverse_dim_map :
475
- s = _expand_slice (s , axis = d , expanded_size = reverse_dim_map [d ])
475
+ s = _expand_slice (s , old_dim = d , new_dim = reverse_dim_map [d ])
476
476
# We also might need to replace newly merged dims, both in s and backup.
477
477
for d in s .dims :
478
478
if d in merged_dim_map :
479
- s , _ = rf .slice (s , axis = d , size = merged_dim_map [d ])
480
- # There is currently the implicit assumption that the backup might need extra padding,
481
- # while the s needs slicing...
482
- # (We think of the hist_dim, where s should only have more frames than backup, or the same.)
479
+ s , _ = _expand_slice (s , old_dim = d , new_dim = merged_dim_map [d ])
483
480
for d in backup .dims :
484
481
if d in merged_dim_map :
485
- backup = _expand_slice (backup , axis = d , expanded_size = merged_dim_map [d ])
482
+ backup = _expand_slice (backup , old_dim = d , new_dim = merged_dim_map [d ])
486
483
# The unpacking itself (reversing the masked_select, i.e. masked_scatter).
487
484
s = rf .masked_scatter (s , backup , mask = mask , dims = dims , in_dim = in_dim )
488
485
return s
@@ -497,14 +494,22 @@ def _masked_scatter(
497
494
raise TypeError (f"_masked_scatter: unexpected type ({ type (s )} )" )
498
495
499
496
500
- def _expand_slice (source : Tensor , axis : Dim , expanded_size : Dim ) -> Tensor :
501
- res , _ = rf .pad (
502
- source ,
503
- axes = [axis ],
504
- padding = [(0 , expanded_size .get_dim_value_tensor () - axis .get_dim_value_tensor ())],
505
- out_dims = [expanded_size ],
506
- value = 0 ,
507
- )
497
+ def _expand_slice (source : Tensor , old_dim : Dim , new_dim : Dim ) -> Tensor :
498
+ assert old_dim in source .dims
499
+ old_size = old_dim .get_dim_value_tensor ()
500
+ new_size = new_dim .get_dim_value_tensor ()
501
+ if old_size == new_size :
502
+ res , _ = rf .replace_dim (source , in_dim = old_dim , out_dim = new_dim )
503
+ elif old_size < new_size :
504
+ res , _ = rf .pad (
505
+ source ,
506
+ axes = [old_dim ],
507
+ padding = [(0 , new_size - old_size )],
508
+ out_dims = [new_dim ],
509
+ value = 0 ,
510
+ )
511
+ else :
512
+ res , _ = rf .slice (source , axis = old_dim , size = new_dim )
508
513
return res
509
514
510
515
0 commit comments