diff --git a/ax/models/torch/botorch_modular/list_surrogate.py b/ax/models/torch/botorch_modular/list_surrogate.py index fb4c67de017..39602aab926 100644 --- a/ax/models/torch/botorch_modular/list_surrogate.py +++ b/ax/models/torch/botorch_modular/list_surrogate.py @@ -9,7 +9,6 @@ import inspect from typing import Any, Dict, List, Optional, Type -from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger @@ -18,6 +17,8 @@ from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform from botorch.utils.datasets import SupervisedDataset +from gpytorch.kernels import Kernel +from gpytorch.likelihoods.likelihood import Likelihood from gpytorch.mlls import ExactMarginalLogLikelihood from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood @@ -66,8 +67,10 @@ class ListSurrogate(Surrogate): mll_options: Dict[str, Any] submodel_outcome_transforms: Dict[str, OutcomeTransform] submodel_input_transforms: Dict[str, InputTransform] - # TODO: Allow passing down `covar_module_class`, `covar_module_options`, - # `likelihood_class`, and `likelihood_options`. + submodel_covar_module_class: Dict[str, Type[Kernel]] + submodel_covar_module_options: Dict[str, Dict[str, Any]] + submodel_likelihood_class: Dict[str, Type[Likelihood]] + submodel_likelihood_options: Dict[str, Dict[str, Any]] _model: Optional[Model] = None # Special setting for surrogates instantiated via `Surrogate.from_botorch`, # to avoid re-constructing the underlying BoTorch model on `Surrogate.fit` @@ -84,6 +87,10 @@ def __init__( mll_options: Optional[Dict[str, Any]] = None, submodel_outcome_transforms: Optional[Dict[str, OutcomeTransform]] = None, submodel_input_transforms: Optional[Dict[str, InputTransform]] = None, + submodel_covar_module_class: Optional[Dict[str, Type[Kernel]]] = None, + submodel_covar_module_options: Optional[Dict[str, Dict[str, Any]]] = None, + submodel_likelihood_class: Optional[Dict[str, Type[Likelihood]]] = None, + submodel_likelihood_options: Optional[Dict[str, Dict[str, Any]]] = None, ) -> None: if not bool(botorch_submodel_class_per_outcome) ^ bool(botorch_submodel_class): raise ValueError( # pragma: no cover @@ -99,6 +106,10 @@ def __init__( self.submodel_options = submodel_options or {} self.submodel_outcome_transforms = submodel_outcome_transforms or {} self.submodel_input_transforms = submodel_input_transforms or {} + self.submodel_covar_module_class = submodel_covar_module_class or {} + self.submodel_covar_module_options = submodel_covar_module_options or {} + self.submodel_likelihood_class = submodel_likelihood_class or {} + self.submodel_likelihood_options = submodel_likelihood_options or {} super().__init__( botorch_model_class=ModelListGP, mll_class=mll_class, @@ -175,21 +186,27 @@ def construct( # way to filter the arguments. See the comment in `Surrogate.construct` # regarding potential use of a `ModelFactory` in the future. model_cls_args = inspect.getfullargspec(model_cls).args + covar_module_class = self.submodel_covar_module_class.get(m) + covar_module_options = self.submodel_covar_module_options.get(m) + likelihood_class = self.submodel_likelihood_class.get(m) + likelihood_options = self.submodel_likelihood_options.get(m) outcome_transform = self.submodel_outcome_transforms.get(m) input_transform = self.submodel_input_transforms.get(m) - for input_name, input_obj in ( - ("outcome_transform", outcome_transform), - ("input_transform", input_transform), - ): - if input_obj is not None: - if input_name not in model_cls_args: - raise UserInputError( - f"The model class {model_cls} does not support an " - f"{input_name} argument." - ) - formatted_model_inputs[input_name] = input_obj + + self._set_formatted_inputs( + formatted_model_inputs=formatted_model_inputs, + inputs=[ + ["covar_module", covar_module_class, covar_module_options, None], + ["likelihood", likelihood_class, likelihood_options, None], + ["outcome_transform", None, None, outcome_transform], + ["input_transform", None, None, input_transform], + ], + dataset=dataset, + botorch_model_class_args=model_cls_args, + ) # pyre-ignore[45]: Py raises informative error if model is abstract. submodels.append(model_cls(**formatted_model_inputs)) + self._model = ModelListGP(*submodels) def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]: @@ -206,4 +223,8 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]: "mll_options": self.mll_options, "submodel_outcome_transforms": self.submodel_outcome_transforms, "submodel_input_transforms": self.submodel_input_transforms, + "submodel_covar_module_class": self.submodel_covar_module_class, + "submodel_covar_module_options": self.submodel_covar_module_options, + "submodel_likelihood_class": self.submodel_likelihood_class, + "submodel_likelihood_options": self.submodel_likelihood_options, } diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index fa905a42fdc..779ce2d8246 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -10,13 +10,14 @@ import inspect import warnings from logging import Logger -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type import torch from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError from ax.models.model_utils import best_in_sample_point +from ax.models.torch.botorch_modular.utils import fit_botorch_model from ax.models.torch.utils import ( _to_inequality_constraints, pick_best_out_of_sample_point_acqf_class, @@ -28,10 +29,7 @@ from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import checked_cast, checked_cast_optional, not_none -from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_model -from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP -from botorch.models.gpytorch import GPyTorchModel -from botorch.models.model import Model, ModelList +from botorch.models.model import Model from botorch.models.pairwise_gp import PairwiseGP from botorch.models.transforms.input import InputTransform from botorch.models.transforms.outcome import OutcomeTransform @@ -42,7 +40,6 @@ from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor - NOT_YET_FIT_MSG = ( "Underlying BoTorch `Model` has not yet received its training_data. " "Please fit the model first." @@ -52,27 +49,6 @@ logger: Logger = get_logger(__name__) -def fit_botorch_model( - model: Union[Model, ModelList, ModelListGP], - mll_class: Optional[Type[MarginalLogLikelihood]] = None, - mll_options: Optional[Dict[str, Any]] = None, -) -> None: - """Fit a BoTorch model.""" - models = model.models if isinstance(model, (ModelListGP, ModelList)) else [model] - for m in models: - # TODO: Support deterministic models when we support `ModelList` - if isinstance(m, SaasFullyBayesianSingleTaskGP): - fit_fully_bayesian_model_nuts(m, disable_progbar=True) - elif isinstance(m, GPyTorchModel) or isinstance(m, PairwiseGP): - mll_options = mll_options or {} - mll = not_none(mll_class)(likelihood=m.likelihood, model=m, **mll_options) - fit_gpytorch_model(mll) - else: - raise ValueError( - f"Model of type {m.__class__.__name__} is currently not supported." - ) - - class Surrogate(Base): """ **All classes in 'botorch_modular' directory are under @@ -244,18 +220,36 @@ def construct(self, datasets: List[SupervisedDataset], **kwargs: Any) -> None: self._training_data = [dataset] - # TODO: Can we warn if the elements of `input_constructor_kwargs` - # are not used? formatted_model_inputs = self.botorch_model_class.construct_inputs( training_data=dataset, **input_constructor_kwargs ) + self._set_formatted_inputs( + formatted_model_inputs=formatted_model_inputs, + inputs=[ + [ + "covar_module", + self.covar_module_class, + self.covar_module_options, + None, + ], + ["likelihood", self.likelihood_class, self.likelihood_options, None], + ["outcome_transform", None, None, self.outcome_transform], + ["input_transform", None, None, self.input_transform], + ], + dataset=dataset, + botorch_model_class_args=botorch_model_class_args, + ) + # pyre-ignore [45] + self._model = self.botorch_model_class(**formatted_model_inputs) - for input_name, input_class, input_options, input_object in ( - ("covar_module", self.covar_module_class, self.covar_module_options, None), - ("likelihood", self.likelihood_class, self.likelihood_options, None), - ("outcome_transform", None, None, self.outcome_transform), - ("input_transform", None, None, self.input_transform), - ): + def _set_formatted_inputs( + self, + formatted_model_inputs: Dict[str, Any], + inputs: List[List[Any]], + dataset: SupervisedDataset, + botorch_model_class_args: Any, + ) -> None: + for input_name, input_class, input_options, input_object in inputs: if input_class is None and input_object is None: continue if input_name not in botorch_model_class_args: @@ -271,14 +265,10 @@ def construct(self, datasets: List[SupervisedDataset], **kwargs: Any) -> None: raise RuntimeError(f"Got both a class and an object for {input_name}.") if input_class is not None: input_options = input_options or {} - # pyre-ignore [45] formatted_model_inputs[input_name] = input_class(**input_options) else: formatted_model_inputs[input_name] = input_object - # pyre-ignore [45] - self._model = self.botorch_model_class(**formatted_model_inputs) - def fit( self, datasets: List[SupervisedDataset], diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 3fb54d8f231..16e44462e87 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import warnings -from typing import Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch from ax.core.search_space import SearchSpaceDigest @@ -19,16 +19,20 @@ from botorch.acquisition.multi_objective.monte_carlo import ( qNoisyExpectedHypervolumeImprovement, ) +from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_model +from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP from botorch.models.gp_regression_fidelity import ( FixedNoiseMultiFidelityGP, SingleTaskMultiFidelityGP, ) from botorch.models.gp_regression_mixed import MixedSingleTaskGP -from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel -from botorch.models.model import Model +from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel +from botorch.models.model import Model, ModelList from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP +from botorch.models.pairwise_gp import PairwiseGP from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset +from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor @@ -251,3 +255,24 @@ def _get_shared_rows(Xs: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: same = (X_shared == X.unsqueeze(-2)).all(dim=-1).any(dim=-1) idcs_shared.append(torch.arange(same.shape[-1], device=X_shared.device)[same]) return X_shared, idcs_shared + + +def fit_botorch_model( + model: Union[Model, ModelList, ModelListGP], + mll_class: Type[MarginalLogLikelihood], + mll_options: Optional[Dict[str, Any]] = None, +) -> None: + """Fit a BoTorch model.""" + models = model.models if isinstance(model, (ModelListGP, ModelList)) else [model] + for m in models: + # TODO: Support deterministic models when we support `ModelList` + if isinstance(m, SaasFullyBayesianSingleTaskGP): + fit_fully_bayesian_model_nuts(m, disable_progbar=True) + elif isinstance(m, (GPyTorchModel, PairwiseGP)): + mll_options = mll_options or {} + mll = mll_class(likelihood=m.likelihood, model=m, **mll_options) + fit_gpytorch_model(mll) + else: + raise NotImplementedError( + f"Model of type {m.__class__.__name__} is currently not supported." + ) diff --git a/ax/models/torch/tests/test_list_surrogate.py b/ax/models/torch/tests/test_list_surrogate.py index d43d8ab2a33..97e783fc006 100644 --- a/ax/models/torch/tests/test_list_surrogate.py +++ b/ax/models/torch/tests/test_list_surrogate.py @@ -24,10 +24,16 @@ from botorch.models.transforms.input import Normalize from botorch.models.transforms.outcome import Standardize from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset -from gpytorch.mlls import ExactMarginalLogLikelihood - +from gpytorch.constraints import GreaterThan, Interval +from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel # noqa: F401 +from gpytorch.likelihoods import ( # noqa: F401 + GaussianLikelihood, + Likelihood, # noqa: F401 +) +from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood SURROGATE_PATH = f"{Surrogate.__module__}" +UTILS_PATH = f"{choose_model_class.__module__}" CURRENT_PATH = f"{__name__}" ACQUISITION_PATH = f"{Acquisition.__module__}" RANK = "rank" @@ -244,8 +250,8 @@ def test_construct_per_outcome_error_raises(self, mock_MTGP_construct_inputs): @patch(f"{CURRENT_PATH}.ModelListGP.load_state_dict", return_value=None) @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood") - @patch(f"{SURROGATE_PATH}.fit_gpytorch_model") - @patch(f"{SURROGATE_PATH}.fit_fully_bayesian_model_nuts") + @patch(f"{UTILS_PATH}.fit_gpytorch_model") + @patch(f"{UTILS_PATH}.fit_fully_bayesian_model_nuts") def test_fit(self, mock_fit_nuts, mock_fit_gpytorch, mock_MLL, mock_state_dict): default_class = self.botorch_submodel_class_per_outcome surrogates = [ @@ -310,7 +316,7 @@ def test_fit(self, mock_fit_nuts, mock_fit_gpytorch, mock_MLL, mock_state_dict): ) # Fitting with unknown model should raise with self.assertRaisesRegex( - ValueError, + NotImplementedError, "Model of type GenericDeterministicModel is currently not supported.", ): fit_botorch_model( @@ -330,7 +336,7 @@ def test_with_botorch_transforms(self): submodel_outcome_transforms=outcome_transforms, submodel_input_transforms=input_transforms, ) - with self.assertRaisesRegex(UserInputError, "The model class"): + with self.assertRaisesRegex(UserInputError, "The BoTorch model class"): surrogate.construct( datasets=self.supervised_training_data, metric_names=self.outcomes, @@ -366,3 +372,46 @@ def test_serialize_attributes_as_kwargs(self): ): expected.pop(attr_name) self.assertEqual(self.surrogate._serialize_attributes_as_kwargs(), expected) + + def test_construct_custom_model(self): + noise_con1, noise_con2 = Interval(1e-6, 1e-1), GreaterThan(1e-4) + surrogate = ListSurrogate( + botorch_submodel_class=SingleTaskGP, + mll_class=LeaveOneOutPseudoLikelihood, + submodel_covar_module_class={ + "outcome_1": RBFKernel, + "outcome_2": MaternKernel, + }, + submodel_covar_module_options={ + "outcome_1": {"ard_num_dims": 1}, + "outcome_2": {"ard_num_dims": 3}, + }, + submodel_likelihood_class={ + "outcome_1": GaussianLikelihood, + "outcome_2": GaussianLikelihood, + }, + submodel_likelihood_options={ + "outcome_1": {"noise_constraint": noise_con1}, + "outcome_2": {"noise_constraint": noise_con2}, + }, + ) + surrogate.construct( + datasets=self.supervised_training_data, + metric_names=self.outcomes, + ) + self.assertEqual(len(surrogate._model.models), 2) + self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood) + for i, m in enumerate(surrogate._model.models): + self.assertEqual(type(m.likelihood), GaussianLikelihood) + if i == 0: + self.assertEqual(type(m.covar_module), RBFKernel) + self.assertEqual(m.covar_module.ard_num_dims, 1) + self.assertEqual( + m.likelihood.noise_covar.raw_noise_constraint, noise_con1 + ) + else: + self.assertEqual(type(m.covar_module), MaternKernel) + self.assertEqual(m.covar_module.ard_num_dims, 3) + self.assertEqual( + m.likelihood.noise_covar.raw_noise_constraint, noise_con2 + ) diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 6fd3dda7b09..efcdf540d41 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -12,6 +12,7 @@ from ax.exceptions.core import UnsupportedError, UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.models.torch.botorch_modular.utils import fit_botorch_model from ax.models.torch_base import TorchOptConfig from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase @@ -29,7 +30,7 @@ from gpytorch.likelihoods import ( # noqa: F401 FixedNoiseGaussianLikelihood, GaussianLikelihood, - Likelihood, + Likelihood, # noqa: F401 ) from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood from torch import Tensor @@ -38,6 +39,7 @@ ACQUISITION_PATH = f"{Acquisition.__module__}" CURRENT_PATH = f"{__name__}" SURROGATE_PATH = f"{Surrogate.__module__}" +UTILS_PATH = f"{fit_botorch_model.__module__}" class SingleTaskGPWithDifferentConstructor(SingleTaskGP): @@ -93,7 +95,7 @@ def test_init(self, mock_Likelihood, mock_Kernel): self.assertEqual(surrogate.botorch_model_class, botorch_model_class) self.assertEqual(surrogate.mll_class, self.mll_class) - @patch(f"{SURROGATE_PATH}.fit_gpytorch_model") + @patch(f"{UTILS_PATH}.fit_gpytorch_model") def test_mll_options(self, _): mock_mll = MagicMock(self.mll_class) surrogate = Surrogate( @@ -259,8 +261,8 @@ def test_construct_custom_model(self): self.assertEqual(surrogate._model.covar_module.ard_num_dims, 1) @patch(f"{CURRENT_PATH}.SingleTaskGP.load_state_dict", return_value=None) - @patch(f"{SURROGATE_PATH}.fit_fully_bayesian_model_nuts") - @patch(f"{SURROGATE_PATH}.fit_gpytorch_model") + @patch(f"{UTILS_PATH}.fit_fully_bayesian_model_nuts") + @patch(f"{UTILS_PATH}.fit_gpytorch_model") @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood") def test_fit(self, mock_MLL, mock_fit_gpytorch, mock_fit_saas, mock_state_dict): for mock_fit, botorch_model_class in zip( @@ -412,8 +414,8 @@ def test_best_out_of_sample_point( self.assertTrue(torch.equal(acqf_value, torch.tensor([1.0]))) @patch(f"{CURRENT_PATH}.SingleTaskGP.load_state_dict", return_value=None) - @patch(f"{SURROGATE_PATH}.fit_fully_bayesian_model_nuts") - @patch(f"{SURROGATE_PATH}.fit_gpytorch_model") + @patch(f"{UTILS_PATH}.fit_fully_bayesian_model_nuts") + @patch(f"{UTILS_PATH}.fit_gpytorch_model") @patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood") def test_update(self, mock_MLL, mock_fit_gpytorch, mock_fit_saas, mock_state_dict): for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: