Skip to content

Commit 70876f5

Browse files
Add BoTorch_Modular to config, giving custom GP and acq func ability
Now users can customize their GP (kernel, MLL, GP class) as well as the acq func in the config without writing any code. Giving users a lot more control and power from any language with little setup
1 parent 0f8df06 commit 70876f5

File tree

6 files changed

+186
-3
lines changed

6 files changed

+186
-3
lines changed

boa/config/converters.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55

66
import ax.early_stopping.strategies as early_stopping_strats
77
import ax.global_stopping.strategies as global_stopping_strats
8+
import botorch.acquisition
9+
import botorch.models
10+
import gpytorch.kernels
11+
import gpytorch.mlls
812
from ax.modelbridge.generation_node import GenerationStep
913
from ax.modelbridge.registry import Models
14+
from ax.models.torch.botorch_modular.surrogate import Surrogate
1015
from ax.service.utils.instantiation import TParameterRepresentation
1116
from ax.service.utils.scheduler_options import SchedulerOptions
1217

@@ -49,6 +54,36 @@ def _gen_strat_converter(gs: Optional[dict] = None) -> dict:
4954
gs["steps"][i] = step
5055
steps.append(step)
5156
continue
57+
if "model_kwargs" in step:
58+
if "botorch_acqf_class" in step["model_kwargs"] and not isinstance(
59+
step["model_kwargs"]["botorch_acqf_class"], botorch.acquisition.AcquisitionFunction
60+
):
61+
step["model_kwargs"]["botorch_acqf_class"] = getattr(
62+
botorch.acquisition, step["model_kwargs"]["botorch_acqf_class"]
63+
)
64+
65+
if "surrogate" in step["model_kwargs"]:
66+
if "mll_class" in step["model_kwargs"]["surrogate"] and not isinstance(
67+
step["model_kwargs"]["surrogate"]["mll_class"], gpytorch.mlls.MarginalLogLikelihood
68+
):
69+
step["model_kwargs"]["surrogate"]["mll_class"] = getattr(
70+
gpytorch.mlls, step["model_kwargs"]["surrogate"]["mll_class"]
71+
)
72+
if "botorch_model_class" in step["model_kwargs"]["surrogate"] and not isinstance(
73+
step["model_kwargs"]["surrogate"]["botorch_model_class"], botorch.models.model.Model
74+
):
75+
step["model_kwargs"]["surrogate"]["botorch_model_class"] = getattr(
76+
botorch.models, step["model_kwargs"]["surrogate"]["botorch_model_class"]
77+
)
78+
if "covar_module_class" in step["model_kwargs"]["surrogate"] and not isinstance(
79+
step["model_kwargs"]["surrogate"]["covar_module_class"], gpytorch.kernels.Kernel
80+
):
81+
step["model_kwargs"]["surrogate"]["covar_module_class"] = getattr(
82+
gpytorch.kernels, step["model_kwargs"]["surrogate"]["covar_module_class"]
83+
)
84+
85+
step["model_kwargs"]["surrogate"] = Surrogate(**step["model_kwargs"]["surrogate"])
86+
5287
try:
5388
step["model"] = Models[step["model"]]
5489
except KeyError:

boa/registry.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
1-
from ax.storage.json_store.registry import CORE_DECODER_REGISTRY, CORE_ENCODER_REGISTRY
1+
from __future__ import annotations
2+
3+
from typing import Type
4+
5+
import botorch.acquisition
6+
import gpytorch.kernels
7+
from ax.storage.botorch_modular_registry import (
8+
ACQUISITION_FUNCTION_REGISTRY,
9+
CLASS_TO_REGISTRY,
10+
CLASS_TO_REVERSE_REGISTRY,
11+
)
12+
from ax.storage.json_store.registry import (
13+
CORE_CLASS_DECODER_REGISTRY,
14+
CORE_CLASS_ENCODER_REGISTRY,
15+
CORE_DECODER_REGISTRY,
16+
CORE_ENCODER_REGISTRY,
17+
botorch_modular_to_dict,
18+
class_from_json,
19+
)
220

321

422
def config_to_dict(inst):
@@ -15,3 +33,25 @@ def _add_common_encodes_and_decodes():
1533
CORE_ENCODER_REGISTRY[BOAConfig] = config_to_dict
1634
# CORE_DECODER_REGISTRY[BOAConfig.__name__] = BOAConfig
1735
CORE_DECODER_REGISTRY[MetricType.__name__] = MetricType
36+
37+
CORE_CLASS_DECODER_REGISTRY["Type[Kernel]"] = class_from_json
38+
CORE_CLASS_ENCODER_REGISTRY[gpytorch.kernels.Kernel] = botorch_modular_to_dict
39+
40+
KERNEL_REGISTRY = {getattr(gpytorch.kernels, kernel): kernel for kernel in gpytorch.kernels.__all__}
41+
42+
REVERSE_KERNEL_REGISTRY: dict[str, Type[gpytorch.kernels.Kernel]] = {v: k for k, v in KERNEL_REGISTRY.items()}
43+
44+
CLASS_TO_REGISTRY[gpytorch.kernels.Kernel] = KERNEL_REGISTRY
45+
CLASS_TO_REVERSE_REGISTRY[gpytorch.kernels.Kernel] = REVERSE_KERNEL_REGISTRY
46+
47+
for acq_func_name in botorch.acquisition.__all__:
48+
acq_func = getattr(botorch.acquisition, acq_func_name)
49+
if acq_func not in ACQUISITION_FUNCTION_REGISTRY:
50+
ACQUISITION_FUNCTION_REGISTRY[acq_func] = acq_func_name
51+
52+
REVERSE_ACQUISITION_FUNCTION_REGISTRY: dict[str, Type[botorch.acquisition.AcquisitionFunction]] = {
53+
v: k for k, v in ACQUISITION_FUNCTION_REGISTRY.items()
54+
}
55+
56+
CLASS_TO_REGISTRY[botorch.acquisition.AcquisitionFunction] = ACQUISITION_FUNCTION_REGISTRY
57+
CLASS_TO_REVERSE_REGISTRY[botorch.acquisition.AcquisitionFunction] = REVERSE_ACQUISITION_FUNCTION_REGISTRY

tests/1unit_tests/test_generation_strategy.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import botorch.acquisition
2+
import botorch.models
3+
import gpytorch.kernels
4+
import gpytorch.mlls
15
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
26
from ax.modelbridge.registry import Models
37

@@ -41,3 +45,32 @@ def test_auto_gen_use_saasbo(saasbo_config, tmp_path):
4145
assert "SAASBO" in gs.name
4246
else:
4347
assert "FullyBayesian" in gs.name
48+
49+
50+
def test_modular_botorch(gen_strat_modular_botorch_config, tmp_path):
51+
controller = Controller(
52+
config=gen_strat_modular_botorch_config,
53+
wrapper=ScriptWrapper(config=gen_strat_modular_botorch_config, experiment_dir=tmp_path),
54+
)
55+
exp = get_experiment(
56+
config=controller.config, runner=WrappedJobRunner(wrapper=controller.wrapper), wrapper=controller.wrapper
57+
)
58+
gs = get_generation_strategy(config=controller.config, experiment=exp)
59+
cfg_botorch_modular = gen_strat_modular_botorch_config.orig_config["generation_strategy"]["steps"][-1]
60+
step = gs._steps[-1]
61+
assert step.model == Models.BOTORCH_MODULAR
62+
mdl_kw = step.model_kwargs
63+
assert mdl_kw["botorch_acqf_class"] == getattr(
64+
botorch.acquisition, cfg_botorch_modular["model_kwargs"]["botorch_acqf_class"]
65+
)
66+
assert mdl_kw["acquisition_options"] == cfg_botorch_modular["model_kwargs"]["acquisition_options"]
67+
68+
assert mdl_kw["surrogate"].mll_class == getattr(
69+
gpytorch.mlls, cfg_botorch_modular["model_kwargs"]["surrogate"]["mll_class"]
70+
)
71+
assert mdl_kw["surrogate"].botorch_model_class == getattr(
72+
botorch.models, cfg_botorch_modular["model_kwargs"]["surrogate"]["botorch_model_class"]
73+
)
74+
assert mdl_kw["surrogate"].covar_module_class == getattr(
75+
gpytorch.kernels, cfg_botorch_modular["model_kwargs"]["surrogate"]["covar_module_class"]
76+
)

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def gen_strat1_config():
8585
return BOAConfig.from_jsonlike(file=config_path)
8686

8787

88+
@pytest.fixture
89+
def gen_strat_modular_botorch_config():
90+
config_path = TEST_DIR / f"scripts/other_langs/r_package_streamlined/config_modular_botorch.yaml"
91+
return BOAConfig.from_jsonlike(file=config_path)
92+
93+
8894
@pytest.fixture
8995
def synth_config():
9096
config_path = TEST_CONFIG_DIR / "test_config_synth.yaml"
@@ -233,3 +239,9 @@ def r_streamlined(tmp_path_factory, cd_to_root_and_back_session):
233239
config_path = TEST_DIR / f"scripts/other_langs/r_package_streamlined/config.yaml"
234240

235241
yield cli_main(split_shell_command(f"--config-path {config_path} -td"), standalone_mode=False)
242+
243+
244+
@pytest.fixture(scope="session")
245+
def r_streamlined_botorch_modular(tmp_path_factory, cd_to_root_and_back_session):
246+
config_path = TEST_DIR / f"scripts/other_langs/r_package_streamlined/config_modular_botorch.yaml"
247+
return cli_main(split_shell_command(f"--config-path {config_path} -td"), standalone_mode=False)

tests/integration_tests/test_cli.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,12 @@ def test_calling_command_line_test_script_doesnt_error_out_and_produces_correct_
7575

7676
# parametrize the test to use the full version (all scripts) or the light version (only run_model.R)
7777
# or parametrize the test to use the streamlined version (doesn't use trial_status.json, only use output.json)
78+
# the botorch modular version is the same as the streamlined version, but also uses botorch modular
79+
# which uses a custom kernel, acquisition function, mll and botorch model class
80+
# (which can customize the GP process even more)
7881
@pytest.mark.parametrize(
7982
"r_scripts_run",
80-
["r_full", "r_light", "r_streamlined"],
83+
["r_full", "r_light", "r_streamlined", "r_streamlined_botorch_modular"],
8184
)
8285
@pytest.mark.skipif(not R_INSTALLED, reason="requires R to be installed")
8386
def test_calling_command_line_r_test_scripts(r_scripts_run, request):
@@ -92,7 +95,7 @@ def test_calling_command_line_r_test_scripts(r_scripts_run, request):
9295
assert "param_names" in data
9396
assert "metric_properties" in data
9497

95-
if "r_streamlined" == r_scripts_run:
98+
if r_scripts_run in ("r_streamlined", "r_streamlined_botorch_modular"):
9699
with cd_and_cd_back(scheduler.wrapper.config_path.parent):
97100

98101
pre_num_trials = len(scheduler.experiment.trials)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
objective:
2+
metrics:
3+
- name: metric
4+
scheduler:
5+
n_trials: 15
6+
7+
parameters:
8+
x0:
9+
'bounds': [ 0, 1 ]
10+
'type': 'range'
11+
'value_type': 'float'
12+
x1:
13+
'bounds': [ 0, 1]
14+
'type': 'range'
15+
'value_type': 'float'
16+
x2:
17+
'bounds': [ 0, 1 ]
18+
'type': 'range'
19+
'value_type': 'float'
20+
x3:
21+
'bounds': [ 0, 1]
22+
'type': 'range'
23+
'value_type': 'float'
24+
x4:
25+
'bounds': [ 0, 1 ]
26+
'type': 'range'
27+
'value_type': 'float'
28+
x5:
29+
'bounds': [ 0, 1]
30+
'type': 'range'
31+
'value_type': 'float'
32+
33+
script_options:
34+
# notice here that this is a shell command
35+
# this is what BOA will do to launch your script
36+
# it will also pass as a command line argument the current trial directory
37+
# that is being parameterized
38+
39+
# This can either be a relative path or absolute path
40+
# (by default when BOA launches from a config file
41+
# it uses the config file directory as your working directory)
42+
# here config.yaml and run_model.R are in the same directory
43+
run_model: Rscript run_model.R
44+
exp_name: "r_streamlined_botorch_modular"
45+
46+
generation_strategy:
47+
steps:
48+
- model: SOBOL
49+
num_trials: 5
50+
- model: BOTORCH_MODULAR
51+
num_trials: -1 # No limitation on how many trials should be produced from this step
52+
model_kwargs:
53+
surrogate:
54+
botorch_model_class: SingleTaskGP # BoTorch model class name
55+
56+
covar_module_class: RBFKernel # GPyTorch kernel class name
57+
mll_class: LeaveOneOutPseudoLikelihood
58+
botorch_acqf_class: qUpperConfidenceBound # BoTorch acquisition function class name
59+
acquisition_options:
60+
beta: 0.5

0 commit comments

Comments
 (0)