Skip to content

Commit

Permalink
Robustify prune_inferior_points tests against sorting order
Browse files Browse the repository at this point in the history
Our nightly CI started failing, likely due to a sorting order change introduced in pytorch/pytorch#127936

This change robustifies the tests against the point order (and also fixes a torch deprecation warning)
  • Loading branch information
Balandat committed Sep 21, 2024
1 parent ab5ffea commit 9295acc
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
5 changes: 3 additions & 2 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,9 @@ def prune_inferior_points(
marginalize_dim=marginalize_dim,
)
if infeas.any():
# set infeasible points to worse than worst objective
# across all samples
# set infeasible points to worse than worst objective across all samples
# Use clone() here to avoid deprecated `index_put_`` on an expanded tensor
obj_vals = obj_vals.clone()
obj_vals[infeas] = obj_vals.min() - 1

is_best = torch.argmax(obj_vals, dim=-1)
Expand Down
13 changes: 7 additions & 6 deletions test/acquisition/multi_objective/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,14 @@ def test_prune_inferior_points_multi_objective(self):
X_pruned = prune_inferior_points_multi_objective(
model=mm, X=X, ref_point=ref_point, max_frac=2 / 3
)
if self.device.type == "cuda":
# sorting has different order on cuda
self.assertTrue(
torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]])
# sorting has different order on cuda
X_expected = X[1:3] if self.device.type == "cuda" else X[:2]
self.assertTrue(
torch.equal(
torch.sort(X_pruned, stable=True).values,
torch.sort(X_expected, stable=True).values,
)
else:
self.assertTrue(torch.equal(X_pruned, X[:2]))
)
# test that zero-probability is in fact pruned
samples[2, 0, 0] = 10
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
Expand Down
19 changes: 9 additions & 10 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,14 @@ def test_prune_inferior_points(self):
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
mm = MockModel(MockPosterior(samples=samples))
X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
if self.device.type == "cuda":
# sorting has different order on cuda
self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
else:
self.assertTrue(torch.equal(X_pruned, X[:2]))
# sorting has different order on cuda
X_expected = X[1:3] if self.device.type == "cuda" else X[:2]
self.assertTrue(
torch.equal(
torch.sort(X_pruned, stable=True).values,
torch.sort(X_expected, stable=True).values,
)
)
# test that zero-probability is in fact pruned
samples[2, 0, 0] = 10
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
Expand All @@ -289,11 +292,7 @@ def test_prune_inferior_points(self):
device=self.device,
dtype=dtype,
)
mm = MockModel(
MockPosterior(
samples=samples,
)
)
mm = MockModel(MockPosterior(samples=samples))
X_pruned = prune_inferior_points(
model=mm,
X=X,
Expand Down

0 comments on commit 9295acc

Please sign in to comment.