Skip to content

Commit

Permalink
fix masked scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 20, 2025
1 parent f17a530 commit 21a3413
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,10 @@ def _masked_scatter(
) -> T:
if isinstance(s, Tensor):
assert isinstance(backup, Tensor)
if in_dim not in s.dims:
s = rf.expand_dim(s, in_dim)
if in_dim not in backup.dims:
backup = rf.expand_dim(backup, in_dim)
# Do the reverse of _masked_select above.
# First replace the dims back.
if any(d in reverse_dim_map for d in s.dims):
Expand Down

0 comments on commit 21a3413

Please sign in to comment.