@@ -471,7 +471,7 @@ def _masked_scatter(
471
471
if any (d in reverse_dim_map for d in s .dims ):
472
472
for d in s .dims :
473
473
if d in reverse_dim_map :
474
- s = _expand_slice (s , old_dim = d , new_dim = reverse_dim_map [d ])
474
+ s = _expand_slice (s , old_dim = d , new_dim = reverse_dim_map [d ], expect_expand = True )
475
475
# We also might need to replace newly merged dims, both in s and backup.
476
476
for d in s .dims :
477
477
if d in merged_dim_map :
@@ -493,21 +493,26 @@ def _masked_scatter(
493
493
raise TypeError (f"_masked_scatter: unexpected type ({ type (s )} )" )
494
494
495
495
496
- def _expand_slice (source : Tensor , old_dim : Dim , new_dim : Dim ) -> Tensor :
496
+ def _expand_slice (source : Tensor , old_dim : Dim , new_dim : Dim , * , expect_expand : Optional [ bool ] = None ) -> Tensor :
497
497
assert old_dim in source .dims
498
- old_size = old_dim .get_dim_value_tensor ()
499
- new_size = new_dim .get_dim_value_tensor ()
498
+ old_size = old_dim .get_dim_value ()
499
+ new_size = new_dim .get_dim_value ()
500
500
if old_size == new_size :
501
501
res , _ = rf .replace_dim (source , in_dim = old_dim , out_dim = new_dim )
502
502
elif old_size < new_size :
503
503
res , _ = rf .pad (
504
504
source ,
505
505
axes = [old_dim ],
506
- padding = [(0 , new_size - old_size )],
506
+ padding = [(0 , new_dim . get_dim_value_tensor () - old_dim . get_dim_value_tensor () )],
507
507
out_dims = [new_dim ],
508
508
value = 0 ,
509
509
)
510
510
else :
511
+ if expect_expand is True :
512
+ raise ValueError (
513
+ f"expected expand, but got reduce (slice): { old_size } -> { new_size } ,"
514
+ f" for { old_dim = } { new_dim = } , in { source = } "
515
+ )
511
516
res , _ = rf .slice (source , axis = old_dim , size = new_dim )
512
517
return res
513
518
0 commit comments