diff --git a/botorch/utils/safe_math.py b/botorch/utils/safe_math.py index 7c4c30c984..4ec2892e90 100644 --- a/botorch/utils/safe_math.py +++ b/botorch/utils/safe_math.py @@ -180,7 +180,11 @@ def _inf_max_helper( y_inf.sum(dim=dim, keepdim=True), M_no_inf + max_fun(y_no_inf, dim=dim, keepdim=True), ) - return res if keepdim else res.squeeze(dim) + # NOTE: Using `sum` instead of `squeeze` because PyTorch < 2.0 does not support + # tuple `dim` arguments. `sum` and `squeeze` are equivalent here because the + # `dim` dimensions have length one after the reductions in the previous lines. + # TODO: Replace `sum` with `squeeze` once PyTorch >= 2.0 is required. + return res if keepdim else res.sum(dim=dim) def _any(x: Tensor, dim: Union[int, Tuple[int, ...]], keepdim: bool = False) -> Tensor: