Skip to content

Commit

Permalink
Replace squeeze with sum in _inf_max_helper (#2083)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2083

This commit replaces `squeeze` with `sum` over length-1 dimensions in `_inf_max_helper` for PyTorch 1.13 compatibility.

Reviewed By: Balandat

Differential Revision: D51030343

fbshipit-source-id: 87ce5c5d71812a70553688d55c82aac56ddebbec
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Nov 6, 2023
1 parent 6bde5d4 commit 787cc7f
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion botorch/utils/safe_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,11 @@ def _inf_max_helper(
y_inf.sum(dim=dim, keepdim=True),
M_no_inf + max_fun(y_no_inf, dim=dim, keepdim=True),
)
return res if keepdim else res.squeeze(dim)
# NOTE: Using `sum` instead of `squeeze` because PyTorch < 2.0 does not support
# tuple `dim` arguments. `sum` and `squeeze` are equivalent here because the
# `dim` dimensions have length one after the reductions in the previous lines.
# TODO: Replace `sum` with `squeeze` once PyTorch >= 2.0 is required.
return res if keepdim else res.sum(dim=dim)


def _any(x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False) -> Tensor:
Expand Down

0 comments on commit 787cc7f

Please sign in to comment.