Skip to content

Commit

Permalink
Robustify prune_inferior_points tests against sorting order (#2548)
Browse files Browse the repository at this point in the history
Summary:
Our nightly CI started failing, likely due to a sorting order change introduced in pytorch/pytorch#127936

This change robustifies the tests against the point order (and also fixes a torch deprecation warning)

Pull Request resolved: #2548

Reviewed By: sdaulton

Differential Revision: D63260870

Pulled By: Balandat
  • Loading branch information
Balandat authored and facebook-github-bot committed Sep 23, 2024
1 parent 161a9a8 commit 96476af
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/multi_objective/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def prune_inferior_points_multi_objective(
probs = pareto_mask.to(dtype=X.dtype).mean(dim=0)
idcs = probs.nonzero().view(-1)
if idcs.shape[0] > max_points:
counts, order_idcs = torch.sort(probs, descending=True)
counts, order_idcs = torch.sort(probs, stable=True, descending=True)
idcs = order_idcs[:max_points]
effective_n_w = obj_vals.shape[-2] // X.shape[-2]
idcs = (idcs / effective_n_w).long().unique()
Expand Down
7 changes: 4 additions & 3 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,15 +335,16 @@ def prune_inferior_points(
marginalize_dim=marginalize_dim,
)
if infeas.any():
# set infeasible points to worse than worst objective
# across all samples
# set infeasible points to worse than worst objective across all samples
# Use clone() here to avoid deprecated `index_put_` on an expanded tensor
obj_vals = obj_vals.clone()
obj_vals[infeas] = obj_vals.min() - 1

is_best = torch.argmax(obj_vals, dim=-1)
idcs, counts = torch.unique(is_best, return_counts=True)

if len(idcs) > max_points:
counts, order_idcs = torch.sort(counts, descending=True)
counts, order_idcs = torch.sort(counts, stable=True, descending=True)
idcs = order_idcs[:max_points]

return X[idcs]
Expand Down
13 changes: 7 additions & 6 deletions test/acquisition/multi_objective/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,14 @@ def test_prune_inferior_points_multi_objective(self):
X_pruned = prune_inferior_points_multi_objective(
model=mm, X=X, ref_point=ref_point, max_frac=2 / 3
)
if self.device.type == "cuda":
# sorting has different order on cuda
self.assertTrue(
torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]])
# sorting has different order on cuda
X_expected = X[1:3] if self.device.type == "cuda" else X[:2]
self.assertTrue(
torch.equal(
torch.sort(X_pruned, stable=True).values,
torch.sort(X_expected, stable=True).values,
)
else:
self.assertTrue(torch.equal(X_pruned, X[:2]))
)
# test that zero-probability is in fact pruned
samples[2, 0, 0] = 10
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
Expand Down
19 changes: 9 additions & 10 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,14 @@ def test_prune_inferior_points(self):
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
mm = MockModel(MockPosterior(samples=samples))
X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
if self.device.type == "cuda":
# sorting has different order on cuda
self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
else:
self.assertTrue(torch.equal(X_pruned, X[:2]))
# sorting has different order on cuda
X_expected = X[1:3] if self.device.type == "cuda" else X[:2]
self.assertTrue(
torch.equal(
torch.sort(X_pruned, stable=True).values,
torch.sort(X_expected, stable=True).values,
)
)
# test that zero-probability is in fact pruned
samples[2, 0, 0] = 10
with mock.patch.object(MockPosterior, "rsample", return_value=samples):
Expand All @@ -289,11 +292,7 @@ def test_prune_inferior_points(self):
device=self.device,
dtype=dtype,
)
mm = MockModel(
MockPosterior(
samples=samples,
)
)
mm = MockModel(MockPosterior(samples=samples))
X_pruned = prune_inferior_points(
model=mm,
X=X,
Expand Down

0 comments on commit 96476af

Please sign in to comment.