Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 15, 2025
1 parent 3fa3ca2 commit e0d4767
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,8 @@ def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) ->
def _batchify(*args):
assert len(args) == batch_size
if isinstance(args[0], Tensor):
return rf.stack(args, out_dim=batch_dim_)
x, _ = rf.stack(args, out_dim=batch_dim_)
return x
if isinstance(args[0], Dim):
assert all(args[0] == s for s in args)
return args[0]
Expand Down

0 comments on commit e0d4767

Please sign in to comment.