diff --git a/botorch/acquisition/multi_objective/utils.py b/botorch/acquisition/multi_objective/utils.py index 369c0e6a5c..30448b587b 100644 --- a/botorch/acquisition/multi_objective/utils.py +++ b/botorch/acquisition/multi_objective/utils.py @@ -154,7 +154,7 @@ def prune_inferior_points_multi_objective( probs = pareto_mask.to(dtype=X.dtype).mean(dim=0) idcs = probs.nonzero().view(-1) if idcs.shape[0] > max_points: - counts, order_idcs = torch.sort(probs, descending=True) + counts, order_idcs = torch.sort(probs, stable=True, descending=True) idcs = order_idcs[:max_points] effective_n_w = obj_vals.shape[-2] // X.shape[-2] idcs = (idcs / effective_n_w).long().unique() diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 198228409a..ae4f054321 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -335,15 +335,16 @@ def prune_inferior_points( marginalize_dim=marginalize_dim, ) if infeas.any(): - # set infeasible points to worse than worst objective - # across all samples + # set infeasible points to worse than worst objective across all samples + # Use clone() here to avoid deprecated `index_put_` on an expanded tensor + obj_vals = obj_vals.clone() obj_vals[infeas] = obj_vals.min() - 1 is_best = torch.argmax(obj_vals, dim=-1) idcs, counts = torch.unique(is_best, return_counts=True) if len(idcs) > max_points: - counts, order_idcs = torch.sort(counts, descending=True) + counts, order_idcs = torch.sort(counts, stable=True, descending=True) idcs = order_idcs[:max_points] return X[idcs] diff --git a/test/acquisition/multi_objective/test_utils.py b/test/acquisition/multi_objective/test_utils.py index acdfddbc95..793c936b61 100644 --- a/test/acquisition/multi_objective/test_utils.py +++ b/test/acquisition/multi_objective/test_utils.py @@ -46,7 +46,7 @@ def test_get_default_partitioning_alpha(self): class DummyMCMultiOutputObjective(MCMultiOutputObjective): - def forward(self, samples: Tensor) -> Tensor: + def forward(self, samples: Tensor, X: Tensor | None) -> Tensor: return samples @@ -130,13 +130,12 @@ def test_prune_inferior_points_multi_objective(self): X_pruned = prune_inferior_points_multi_objective( model=mm, X=X, ref_point=ref_point, max_frac=2 / 3 ) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue( - torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]]) + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X[:2], stable=True).values, ) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): @@ -276,10 +275,7 @@ def test_random_search_optimizer(self): input_dim = 3 num_initial = 5 tkwargs = {"device": self.device} - optimizer_kwargs = { - "pop_size": 1000, - "max_tries": 5, - } + optimizer_kwargs = {"pop_size": 1000, "max_tries": 5} for ( dtype, @@ -350,10 +346,7 @@ def test_sample_optimal_points(self): input_dim = 3 num_initial = 5 tkwargs = {"device": self.device} - optimizer_kwargs = { - "pop_size": 100, - "max_tries": 1, - } + optimizer_kwargs = {"pop_size": 100, "max_tries": 1} num_samples = 2 num_points = 1 diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index c8a6484cca..61845a387a 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -270,11 +270,12 @@ def test_prune_inferior_points(self): with mock.patch.object(MockPosterior, "rsample", return_value=samples): mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0))) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X[:2], stable=True).values, + ) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): @@ -289,11 +290,7 @@ def test_prune_inferior_points(self): device=self.device, dtype=dtype, ) - mm = MockModel( - MockPosterior( - samples=samples, - ) - ) + mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points( model=mm, X=X,