Skip to content

Commit

Permalink
Change ListSurrogate to allow passing down covar_module and likelihood
Browse files Browse the repository at this point in the history
Summary: See title.

Reviewed By: Balandat

Differential Revision: D37330410

fbshipit-source-id: 544413cb35da60512c2c8defb371a652fc2a66a0
  • Loading branch information
dme65 authored and facebook-github-bot committed Jul 18, 2022
1 parent 798efdd commit aad4cda
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 68 deletions.
49 changes: 35 additions & 14 deletions ax/models/torch/botorch_modular/list_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import inspect
from typing import Any, Dict, List, Optional, Type

from ax.exceptions.core import UserInputError
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
Expand All @@ -18,6 +17,8 @@
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.utils.datasets import SupervisedDataset
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood

Expand Down Expand Up @@ -66,8 +67,10 @@ class ListSurrogate(Surrogate):
mll_options: Dict[str, Any]
submodel_outcome_transforms: Dict[str, OutcomeTransform]
submodel_input_transforms: Dict[str, InputTransform]
# TODO: Allow passing down `covar_module_class`, `covar_module_options`,
# `likelihood_class`, and `likelihood_options`.
submodel_covar_module_class: Dict[str, Type[Kernel]]
submodel_covar_module_options: Dict[str, Dict[str, Any]]
submodel_likelihood_class: Dict[str, Type[Likelihood]]
submodel_likelihood_options: Dict[str, Dict[str, Any]]
_model: Optional[Model] = None
# Special setting for surrogates instantiated via `Surrogate.from_botorch`,
# to avoid re-constructing the underlying BoTorch model on `Surrogate.fit`
Expand All @@ -84,6 +87,10 @@ def __init__(
mll_options: Optional[Dict[str, Any]] = None,
submodel_outcome_transforms: Optional[Dict[str, OutcomeTransform]] = None,
submodel_input_transforms: Optional[Dict[str, InputTransform]] = None,
submodel_covar_module_class: Optional[Dict[str, Type[Kernel]]] = None,
submodel_covar_module_options: Optional[Dict[str, Dict[str, Any]]] = None,
submodel_likelihood_class: Optional[Dict[str, Type[Likelihood]]] = None,
submodel_likelihood_options: Optional[Dict[str, Dict[str, Any]]] = None,
) -> None:
if not bool(botorch_submodel_class_per_outcome) ^ bool(botorch_submodel_class):
raise ValueError( # pragma: no cover
Expand All @@ -99,6 +106,10 @@ def __init__(
self.submodel_options = submodel_options or {}
self.submodel_outcome_transforms = submodel_outcome_transforms or {}
self.submodel_input_transforms = submodel_input_transforms or {}
self.submodel_covar_module_class = submodel_covar_module_class or {}
self.submodel_covar_module_options = submodel_covar_module_options or {}
self.submodel_likelihood_class = submodel_likelihood_class or {}
self.submodel_likelihood_options = submodel_likelihood_options or {}
super().__init__(
botorch_model_class=ModelListGP,
mll_class=mll_class,
Expand Down Expand Up @@ -175,21 +186,27 @@ def construct(
# way to filter the arguments. See the comment in `Surrogate.construct`
# regarding potential use of a `ModelFactory` in the future.
model_cls_args = inspect.getfullargspec(model_cls).args
covar_module_class = self.submodel_covar_module_class.get(m)
covar_module_options = self.submodel_covar_module_options.get(m)
likelihood_class = self.submodel_likelihood_class.get(m)
likelihood_options = self.submodel_likelihood_options.get(m)
outcome_transform = self.submodel_outcome_transforms.get(m)
input_transform = self.submodel_input_transforms.get(m)
for input_name, input_obj in (
("outcome_transform", outcome_transform),
("input_transform", input_transform),
):
if input_obj is not None:
if input_name not in model_cls_args:
raise UserInputError(
f"The model class {model_cls} does not support an "
f"{input_name} argument."
)
formatted_model_inputs[input_name] = input_obj

self._set_formatted_inputs(
formatted_model_inputs=formatted_model_inputs,
inputs=[
["covar_module", covar_module_class, covar_module_options, None],
["likelihood", likelihood_class, likelihood_options, None],
["outcome_transform", None, None, outcome_transform],
["input_transform", None, None, input_transform],
],
dataset=dataset,
botorch_model_class_args=model_cls_args,
)
# pyre-ignore[45]: Py raises informative error if model is abstract.
submodels.append(model_cls(**formatted_model_inputs))

self._model = ModelListGP(*submodels)

def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
Expand All @@ -206,4 +223,8 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
"mll_options": self.mll_options,
"submodel_outcome_transforms": self.submodel_outcome_transforms,
"submodel_input_transforms": self.submodel_input_transforms,
"submodel_covar_module_class": self.submodel_covar_module_class,
"submodel_covar_module_options": self.submodel_covar_module_options,
"submodel_likelihood_class": self.submodel_likelihood_class,
"submodel_likelihood_options": self.submodel_likelihood_options,
}
68 changes: 29 additions & 39 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import inspect
import warnings
from logging import Logger
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.botorch_modular.utils import fit_botorch_model
from ax.models.torch.utils import (
_to_inequality_constraints,
pick_best_out_of_sample_point_acqf_class,
Expand All @@ -28,10 +29,7 @@
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast, checked_cast_optional, not_none
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_model
from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model, ModelList
from botorch.models.model import Model
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
Expand All @@ -42,7 +40,6 @@
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor


NOT_YET_FIT_MSG = (
"Underlying BoTorch `Model` has not yet received its training_data. "
"Please fit the model first."
Expand All @@ -52,27 +49,6 @@
logger: Logger = get_logger(__name__)


def fit_botorch_model(
    model: Union[Model, ModelList, ModelListGP],
    mll_class: Optional[Type[MarginalLogLikelihood]] = None,
    mll_options: Optional[Dict[str, Any]] = None,
) -> None:
    """Fit a BoTorch model.

    Model-list containers are unpacked and each submodel is fit on its own:
    fully Bayesian SAAS models via NUTS, and GPyTorch / pairwise models by
    maximizing the marginal log likelihood built from ``mll_class``.
    ``mll_class`` must be provided for any non-fully-Bayesian submodel.
    """
    if isinstance(model, (ModelListGP, ModelList)):
        submodels = model.models
    else:
        submodels = [model]
    for submodel in submodels:
        # TODO: Support deterministic models when we support `ModelList`
        if isinstance(submodel, SaasFullyBayesianSingleTaskGP):
            fit_fully_bayesian_model_nuts(submodel, disable_progbar=True)
        elif isinstance(submodel, (GPyTorchModel, PairwiseGP)):
            options = mll_options or {}
            mll = not_none(mll_class)(
                likelihood=submodel.likelihood, model=submodel, **options
            )
            fit_gpytorch_model(mll)
        else:
            raise ValueError(
                f"Model of type {submodel.__class__.__name__} is currently not supported."
            )


class Surrogate(Base):
"""
**All classes in 'botorch_modular' directory are under
Expand Down Expand Up @@ -244,18 +220,36 @@ def construct(self, datasets: List[SupervisedDataset], **kwargs: Any) -> None:

self._training_data = [dataset]

# TODO: Can we warn if the elements of `input_constructor_kwargs`
# are not used?
formatted_model_inputs = self.botorch_model_class.construct_inputs(
training_data=dataset, **input_constructor_kwargs
)
self._set_formatted_inputs(
formatted_model_inputs=formatted_model_inputs,
inputs=[
[
"covar_module",
self.covar_module_class,
self.covar_module_options,
None,
],
["likelihood", self.likelihood_class, self.likelihood_options, None],
["outcome_transform", None, None, self.outcome_transform],
["input_transform", None, None, self.input_transform],
],
dataset=dataset,
botorch_model_class_args=botorch_model_class_args,
)
# pyre-ignore [45]
self._model = self.botorch_model_class(**formatted_model_inputs)

for input_name, input_class, input_options, input_object in (
("covar_module", self.covar_module_class, self.covar_module_options, None),
("likelihood", self.likelihood_class, self.likelihood_options, None),
("outcome_transform", None, None, self.outcome_transform),
("input_transform", None, None, self.input_transform),
):
def _set_formatted_inputs(
self,
formatted_model_inputs: Dict[str, Any],
inputs: List[List[Any]],
dataset: SupervisedDataset,
botorch_model_class_args: Any,
) -> None:
for input_name, input_class, input_options, input_object in inputs:
if input_class is None and input_object is None:
continue
if input_name not in botorch_model_class_args:
Expand All @@ -271,14 +265,10 @@ def construct(self, datasets: List[SupervisedDataset], **kwargs: Any) -> None:
raise RuntimeError(f"Got both a class and an object for {input_name}.")
if input_class is not None:
input_options = input_options or {}
# pyre-ignore [45]
formatted_model_inputs[input_name] = input_class(**input_options)
else:
formatted_model_inputs[input_name] = input_object

# pyre-ignore [45]
self._model = self.botorch_model_class(**formatted_model_inputs)

def fit(
self,
datasets: List[SupervisedDataset],
Expand Down
31 changes: 28 additions & 3 deletions ax/models/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
from ax.core.search_space import SearchSpaceDigest
Expand All @@ -19,16 +19,20 @@
from botorch.acquisition.multi_objective.monte_carlo import (
qNoisyExpectedHypervolumeImprovement,
)
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_model
from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
from botorch.models.gp_regression_fidelity import (
FixedNoiseMultiFidelityGP,
SingleTaskMultiFidelityGP,
)
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.model import Model
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel
from botorch.models.model import Model, ModelList
from botorch.models.multitask import FixedNoiseMultiTaskGP, MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from torch import Tensor


Expand Down Expand Up @@ -251,3 +255,24 @@ def _get_shared_rows(Xs: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
same = (X_shared == X.unsqueeze(-2)).all(dim=-1).any(dim=-1)
idcs_shared.append(torch.arange(same.shape[-1], device=X_shared.device)[same])
return X_shared, idcs_shared


def fit_botorch_model(
    model: Union[Model, ModelList, ModelListGP],
    mll_class: Type[MarginalLogLikelihood],
    mll_options: Optional[Dict[str, Any]] = None,
) -> None:
    """Fit a BoTorch model.

    Unpacks model-list containers and fits each submodel individually:
    fully Bayesian SAAS models are fit via NUTS, while GPyTorch and
    pairwise models are fit by maximizing an ``mll_class`` MLL.

    Args:
        model: A single BoTorch model or a model-list container whose
            ``models`` attribute holds the submodels to fit.
        mll_class: Marginal log likelihood class used for non-fully-Bayesian
            submodels; instantiated per-submodel with its likelihood.
        mll_options: Optional keyword arguments forwarded to ``mll_class``.

    Raises:
        NotImplementedError: If a submodel is of an unsupported type
            (e.g. a deterministic model).
    """
    # Model-list containers expose submodels via `.models`; wrap a single
    # model in a list so both cases share one fitting loop.
    models = model.models if isinstance(model, (ModelListGP, ModelList)) else [model]
    for m in models:
        # TODO: Support deterministic models when we support `ModelList`
        if isinstance(m, SaasFullyBayesianSingleTaskGP):
            # SAAS models are fit with NUTS sampling rather than MLL optimization.
            fit_fully_bayesian_model_nuts(m, disable_progbar=True)
        elif isinstance(m, (GPyTorchModel, PairwiseGP)):
            mll_options = mll_options or {}
            mll = mll_class(likelihood=m.likelihood, model=m, **mll_options)
            fit_gpytorch_model(mll)
        else:
            raise NotImplementedError(
                f"Model of type {m.__class__.__name__} is currently not supported."
            )
61 changes: 55 additions & 6 deletions ax/models/torch/tests/test_list_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.utils.datasets import FixedNoiseDataset, SupervisedDataset
from gpytorch.mlls import ExactMarginalLogLikelihood

from gpytorch.constraints import GreaterThan, Interval
from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel # noqa: F401
from gpytorch.likelihoods import ( # noqa: F401
GaussianLikelihood,
Likelihood, # noqa: F401
)
from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood

SURROGATE_PATH = f"{Surrogate.__module__}"
UTILS_PATH = f"{choose_model_class.__module__}"
CURRENT_PATH = f"{__name__}"
ACQUISITION_PATH = f"{Acquisition.__module__}"
RANK = "rank"
Expand Down Expand Up @@ -244,8 +250,8 @@ def test_construct_per_outcome_error_raises(self, mock_MTGP_construct_inputs):

@patch(f"{CURRENT_PATH}.ModelListGP.load_state_dict", return_value=None)
@patch(f"{CURRENT_PATH}.ExactMarginalLogLikelihood")
@patch(f"{SURROGATE_PATH}.fit_gpytorch_model")
@patch(f"{SURROGATE_PATH}.fit_fully_bayesian_model_nuts")
@patch(f"{UTILS_PATH}.fit_gpytorch_model")
@patch(f"{UTILS_PATH}.fit_fully_bayesian_model_nuts")
def test_fit(self, mock_fit_nuts, mock_fit_gpytorch, mock_MLL, mock_state_dict):
default_class = self.botorch_submodel_class_per_outcome
surrogates = [
Expand Down Expand Up @@ -310,7 +316,7 @@ def test_fit(self, mock_fit_nuts, mock_fit_gpytorch, mock_MLL, mock_state_dict):
)
# Fitting with unknown model should raise
with self.assertRaisesRegex(
ValueError,
NotImplementedError,
"Model of type GenericDeterministicModel is currently not supported.",
):
fit_botorch_model(
Expand All @@ -330,7 +336,7 @@ def test_with_botorch_transforms(self):
submodel_outcome_transforms=outcome_transforms,
submodel_input_transforms=input_transforms,
)
with self.assertRaisesRegex(UserInputError, "The model class"):
with self.assertRaisesRegex(UserInputError, "The BoTorch model class"):
surrogate.construct(
datasets=self.supervised_training_data,
metric_names=self.outcomes,
Expand Down Expand Up @@ -366,3 +372,46 @@ def test_serialize_attributes_as_kwargs(self):
):
expected.pop(attr_name)
self.assertEqual(self.surrogate._serialize_attributes_as_kwargs(), expected)

def test_construct_custom_model(self):
    """Check that per-outcome covar modules and likelihoods (with options)
    are passed through ``ListSurrogate.construct`` to the submodels.
    """
    # Two distinct noise constraints so we can tell the submodels apart below.
    noise_con1, noise_con2 = Interval(1e-6, 1e-1), GreaterThan(1e-4)
    surrogate = ListSurrogate(
        botorch_submodel_class=SingleTaskGP,
        mll_class=LeaveOneOutPseudoLikelihood,
        # Different kernel class and ard_num_dims per outcome.
        submodel_covar_module_class={
            "outcome_1": RBFKernel,
            "outcome_2": MaternKernel,
        },
        submodel_covar_module_options={
            "outcome_1": {"ard_num_dims": 1},
            "outcome_2": {"ard_num_dims": 3},
        },
        submodel_likelihood_class={
            "outcome_1": GaussianLikelihood,
            "outcome_2": GaussianLikelihood,
        },
        submodel_likelihood_options={
            "outcome_1": {"noise_constraint": noise_con1},
            "outcome_2": {"noise_constraint": noise_con2},
        },
    )
    # NOTE(review): relies on fixtures set up elsewhere in this test class —
    # presumably two-outcome training data matching "outcome_1"/"outcome_2".
    surrogate.construct(
        datasets=self.supervised_training_data,
        metric_names=self.outcomes,
    )
    self.assertEqual(len(surrogate._model.models), 2)
    self.assertEqual(surrogate.mll_class, LeaveOneOutPseudoLikelihood)
    # Submodel order is assumed to follow the outcome order: index 0 is
    # "outcome_1" (RBF), index 1 is "outcome_2" (Matern).
    for i, m in enumerate(surrogate._model.models):
        self.assertEqual(type(m.likelihood), GaussianLikelihood)
        if i == 0:
            self.assertEqual(type(m.covar_module), RBFKernel)
            self.assertEqual(m.covar_module.ard_num_dims, 1)
            self.assertEqual(
                m.likelihood.noise_covar.raw_noise_constraint, noise_con1
            )
        else:
            self.assertEqual(type(m.covar_module), MaternKernel)
            self.assertEqual(m.covar_module.ard_num_dims, 3)
            self.assertEqual(
                m.likelihood.noise_covar.raw_noise_constraint, noise_con2
            )
Loading

0 comments on commit aad4cda

Please sign in to comment.