From 9295acc13b5965d78c496ac9ccfc860dd788d126 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Sat, 21 Sep 2024 09:21:43 -0700 Subject: [PATCH] Robustify prune_inferior_points tests against sorting order Our nightly CI started failing, likely due to a sorting order change introduced in https://github.com/pytorch/pytorch/pull/127936 This change robustifies the tests against the point order (and also fixes a torch deprecation warning) --- botorch/acquisition/utils.py | 5 +++-- .../acquisition/multi_objective/test_utils.py | 13 +++++++------ test/acquisition/test_utils.py | 19 +++++++++---------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 198228409a..d0ab28d906 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -335,8 +335,9 @@ def prune_inferior_points( marginalize_dim=marginalize_dim, ) if infeas.any(): - # set infeasible points to worse than worst objective - # across all samples + # set infeasible points to worse than worst objective across all samples + # Use clone() here to avoid deprecated `index_put_`` on an expanded tensor + obj_vals = obj_vals.clone() obj_vals[infeas] = obj_vals.min() - 1 is_best = torch.argmax(obj_vals, dim=-1) diff --git a/test/acquisition/multi_objective/test_utils.py b/test/acquisition/multi_objective/test_utils.py index acdfddbc95..786c72ad9c 100644 --- a/test/acquisition/multi_objective/test_utils.py +++ b/test/acquisition/multi_objective/test_utils.py @@ -130,13 +130,14 @@ def test_prune_inferior_points_multi_objective(self): X_pruned = prune_inferior_points_multi_objective( model=mm, X=X, ref_point=ref_point, max_frac=2 / 3 ) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue( - torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]]) + # sorting has different order on cuda + X_expected = X[1:3] if self.device.type == "cuda" else X[:2] + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X_expected, stable=True).values, ) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index c8a6484cca..c9552da886 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -270,11 +270,14 @@ def test_prune_inferior_points(self): with mock.patch.object(MockPosterior, "rsample", return_value=samples): mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0))) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + # sorting has different order on cuda + X_expected = X[1:3] if self.device.type == "cuda" else X[:2] + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X_expected, stable=True).values, + ) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): @@ -289,11 +292,7 @@ def test_prune_inferior_points(self): device=self.device, dtype=dtype, ) - mm = MockModel( - MockPosterior( - samples=samples, - ) - ) + mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points( model=mm, X=X,