Update mixed-discrete Krg instantiation: switch to Gower
jbussemaker committed Sep 24, 2023
1 parent fb52ffa commit 2fcf5df
Showing 3 changed files with 71 additions and 8 deletions.
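The substantive change in models.py is the categorical kernel used for mixed-discrete Kriging: MixIntKernelType.EXP_HOMO_HSPHERE is replaced by MixIntKernelType.GOWER. The Gower kernel computes a mixed distance directly over continuous and categorical variables, and SMT applies PLS dimension reduction to all variables for it, whereas the hypersphere-based kernels (EXP_HOMO_HSPHERE/HOMO_HSPHERE) exclude categorical variables from PLS. A minimal standalone sketch of the resulting instantiation, assuming SMT 2.x (import paths may differ between SMT versions):

from smt.surrogate_models import KRG, MixHrcKernelType, MixIntKernelType
from smt.utils.design_space import CategoricalVariable, DesignSpace, FloatVariable

design_space = DesignSpace([
    CategoricalVariable(['A', 'B', 'C']),  # one categorical variable
    FloatVariable(0., 1.),                 # one continuous variable
])
model = KRG(
    design_space=design_space,
    categorical_kernel=MixIntKernelType.GOWER,        # was: EXP_HOMO_HSPHERE
    hierarchical_kernel=MixHrcKernelType.ALG_KERNEL,
    print_global=False,
)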
60 changes: 56 additions & 4 deletions sb_arch_opt/algo/arch_sbo/models.py
@@ -169,18 +169,28 @@ def get_kriging_model(multi=True, kpls_n_comp: int = None, **kwargs):
def get_md_kriging_model(self, kpls_n_comp: int = None, multi=True, **kwargs_) -> Tuple['SurrogateModel', Normalization]:
check_dependencies()
normalization = self.get_md_normalization()
norm_ds_spec = self.create_smt_design_space_spec(self.problem.design_space, md_normalize=True)
design_space = self.problem.design_space
norm_ds_spec = self.create_smt_design_space_spec(design_space, md_normalize=True)

kwargs = dict(
print_global=False,
design_space=norm_ds_spec.design_space,
categorical_kernel=MixIntKernelType.EXP_HOMO_HSPHERE,
categorical_kernel=MixIntKernelType.GOWER,
hierarchical_kernel=MixHrcKernelType.ALG_KERNEL,
)
if norm_ds_spec.is_mixed_discrete:
kwargs['n_start'] = kwargs.get('n_start', 5)
kwargs.update(kwargs_)

# Disable KPLS if the nr of requested components is too high
if kpls_n_comp is not None:
n_dim_apply_pls = design_space.n_var

# PLS is not applied to categorical variables for EHH/HH kernels (see KrgBased._matrix_data_corr)
if IS_SMT_21 and kwargs['categorical_kernel'] not in [MixIntKernelType.CONT_RELAX, MixIntKernelType.GOWER]:
n_dim_apply_pls = design_space.n_var - np.sum(design_space.is_cat_mask)

if kpls_n_comp > n_dim_apply_pls:
kpls_n_comp = None

if kpls_n_comp is not None:
if not IS_SMT_21:
kwargs['categorical_kernel'] = MixIntKernelType.CONT_RELAX
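The new guard above prevents requesting more KPLS components than there are PLS-eligible dimensions; a worked example with hypothetical numbers:

# 5 design variables of which 3 categorical, with an EHH/HH categorical kernel:
n_var, n_cat = 5, 3
n_dim_apply_pls = n_var - n_cat    # PLS only sees the 2 non-categorical dimensions
kpls_n_comp = 3
if kpls_n_comp > n_dim_apply_pls:  # 3 > 2: more components than dimensions
    kpls_n_comp = None             # KPLS is disabled; plain Kriging is used
# With GOWER or CONT_RELAX, PLS applies to all 5 dimensions and the request is kept.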
@@ -198,6 +208,48 @@ def get_md_kriging_model(self, kpls_n_comp: int = None, multi=True, **kwargs_) -> Tuple['SurrogateModel', Normalization]:

return surrogate, normalization

@staticmethod
def get_n_theta(problem: ArchOptProblemBase, surrogate: 'SurrogateModel') -> int:

def _get_n_theta(model: 'SurrogateModel') -> int:
if isinstance(model, KrgBased):
if hasattr(model, 'optimal_theta') and len(model.optimal_theta):
return len(model.optimal_theta)

n_train = 2
if isinstance(model, KPLS):
n_train = model.options['n_comp']+1
n_theta = 0

def _override(theta):
nonlocal n_theta
# No need to actually train the model: we only want to know how many hyperparams we have
n_theta = len(theta)
raise RuntimeError

model = copy.deepcopy(model)
model.options['n_start'] = 1
model.set_training_values(np.zeros((n_train, problem.n_var)), np.zeros((n_train, 1)))
model._reduced_likelihood_function = _override
try:
model.train()
except RuntimeError:
pass
return n_theta

raise RuntimeError(f'Not a Kriging model: {surrogate!r}')

if isinstance(surrogate, MultiSurrogateModel):
if len(surrogate._models) == 0:
n_single = _get_n_theta(surrogate._surrogate)
else:
n_single = _get_n_theta(surrogate._models[0])

ny = problem.n_obj + problem.n_ieq_constr
return n_single * ny

return _get_n_theta(surrogate)


class SBArchOptDesignSpace(BaseDesignSpace):
"""SMT design space implementation using SBArchOpt's design space logic"""
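The new get_n_theta counts a model's Kriging hyperparameters without paying for a real training run: it deep-copies the model, replaces _reduced_likelihood_function with a stand-in that records len(theta) and immediately raises, then triggers train(). A standalone sketch of the same trick (the helper name count_theta and the all-zeros dummy data are illustrative only):

import copy

import numpy as np
from smt.surrogate_models import KRG

def count_theta(model: KRG, n_var: int) -> int:
    n_theta = 0

    def _capture(theta):
        nonlocal n_theta
        n_theta = len(theta)  # record the nr of hyperparameters...
        raise RuntimeError    # ...and abort before any actual training work

    model = copy.deepcopy(model)  # leave the caller's model untouched
    model.options['n_start'] = 1  # one optimizer start is enough to reach theta
    model.set_training_values(np.zeros((2, n_var)), np.zeros((2, 1)))
    model._reduced_likelihood_function = _capture
    try:
        model.train()
    except RuntimeError:
        pass
    return n_theta

For a KPLS model the dummy sample would need n_comp+1 points instead of 2, which is exactly the case get_n_theta handles via model.options['n_comp'].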
9 changes: 6 additions & 3 deletions sb_arch_opt/problems/problems_base.py
@@ -26,7 +26,7 @@
import numpy as np
from typing import Optional, Tuple
from pymoo.core.problem import Problem
from pymoo.core.variable import Real, Integer
from pymoo.core.variable import Real, Integer, Choice
from sb_arch_opt.problem import ArchOptProblemBase
from sb_arch_opt.pareto_front import CachedParetoFrontMixin
from sb_arch_opt.sampling import HierarchicalExhaustiveSampling
@@ -148,7 +148,7 @@ class MixedDiscretizerProblemBase(NoHierarchyProblemBase):
"""Problem class that turns an existing test problem into a mixed-discrete problem by mapping the first n (if not
given: all) variables to integers with a given number of options."""

def __init__(self, problem: Problem, n_opts=10, n_vars_int: int = None):
def __init__(self, problem: Problem, n_opts=10, n_vars_int: int = None, cat=False):
self.problem = problem
self.n_opts = n_opts
if n_vars_int is None:
@@ -160,7 +160,10 @@ def __init__(self, problem: Problem, n_opts=10, n_vars_int: int = None):
self._xl_orig = problem.xl
self._xu_orig = problem.xu

des_vars = [Integer(bounds=(0, n_opts-1)) if i < n_vars_int else Real(bounds=(problem.xl[i], problem.xu[i]))
def _get_var():
return Choice(options=list(range(n_opts))) if cat else Integer(bounds=(0, n_opts-1))

des_vars = [_get_var() if i < n_vars_int else Real(bounds=(problem.xl[i], problem.xu[i]))
for i in range(problem.n_var)]
super().__init__(des_vars, n_obj=problem.n_obj, n_ieq_constr=problem.n_ieq_constr)
self.callback = problem.callback
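With the new cat flag, the discretized variables become pymoo Choice (categorical) variables instead of Integer ones, so downstream models treat them with a categorical kernel rather than as ordered integers. Hypothetical usage (Himmelblau stands in for any continuous pymoo problem; import path per pymoo 0.6.x):

from pymoo.problems.single import Himmelblau

int_problem = MixedDiscretizerProblemBase(Himmelblau(), n_opts=10)            # Integer variables
cat_problem = MixedDiscretizerProblemBase(Himmelblau(), n_opts=10, cat=True)  # Choice variables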
10 changes: 9 additions & 1 deletion sb_arch_opt/tests/algo/test_arch_sbo.py
@@ -339,7 +339,7 @@ def test_smt_krg_features():
hierarchical_kernel=MixHrcKernelType.ALG_KERNEL,
)
pls_kwargs = dict(
categorical_kernel=MixIntKernelType.CONT_RELAX,
categorical_kernel=MixIntKernelType.GOWER,
hierarchical_kernel=MixHrcKernelType.ALG_KERNEL,
)

@@ -369,7 +369,12 @@ def _try_model(problem: ArchOptProblemBase, pls: bool = False, cont_relax: bool
else:
model = KRG(design_space=model_ds, **kwargs)

n_theta = ModelFactory.get_n_theta(problem, model)
assert n_theta > 0

model = MultiSurrogateModel(model)
n_theta_multi = ModelFactory.get_n_theta(problem, model)
assert n_theta_multi == n_theta*(problem.n_obj+problem.n_ieq_constr)

x_train = HierarchicalSampling().do(problem, 50).get('X')
y_train = problem.evaluate(x_train, return_as_dictionary=True)['F']
@@ -381,6 +386,9 @@
assert not throws_error
except (TypeError, ValueError):
assert throws_error
return

assert ModelFactory.get_n_theta(problem, model) == n_theta_multi

with disable_int_fix():
# Continuous
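The new assertions pin down how the hyperparameter count scales: a MultiSurrogateModel trains one underlying Kriging per output, so get_n_theta returns the single-model count times n_obj + n_ieq_constr. A sketch of that check against a factory-built model (the MDBranin import path and the ModelFactory(problem) construction are assumptions; any mixed-discrete test problem would do):

from sb_arch_opt.algo.arch_sbo.models import ModelFactory
from sb_arch_opt.problems.md_mo import MDBranin

problem = MDBranin()
surrogate, _ = ModelFactory(problem).get_md_kriging_model()
n_theta = ModelFactory.get_n_theta(problem, surrogate)
# One Kriging per output: the total count divides evenly over the outputs
assert n_theta % (problem.n_obj + problem.n_ieq_constr) == 0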
