fix: Return acquisition_functions still in use
eddiebergman committed Aug 29, 2024
1 parent 205f503 commit 03729ca
Showing 7 changed files with 537 additions and 21 deletions.
2 changes: 1 addition & 1 deletion neps/sampling/distributions.py → neps/distributions.py
@@ -225,6 +225,6 @@ def log_prob(self, value):


@dataclass
class TorchDistributionWithDomain:
class DistributionOverDomain:
distribution: Distribution
domain: Domain
213 changes: 213 additions & 0 deletions neps/optimizers/bayesian_optimization/acquisition_functions/_ehvi.py
@@ -0,0 +1,213 @@
# from abc import ABC, abstractmethod
from itertools import product

import torch
from torch import Tensor
from torch.distributions import Normal
from torch.nn import Module

# class MultiObjectiveBaseAcqusition(ABC):
# def __init__(self, surrogate_models: dict):
# self.surrogate_models = surrogate_models
#
# def propose_location(self, *args):
# """Propose new locations for subsequent sampling
# This method should be overriden by respective acquisition function implementations."""
# raise NotImplementedError
#
# def optimize(self):
# """This is the method that user should call for the Bayesian optimisation main loop."""
# raise NotImplementedError
#
# @abstractmethod
# def eval(self, x, asscalar: bool = False):
# """Evaluate the acquisition function at point x2. This should be overridden by respective acquisition
# function implementations"""
# raise NotImplementedError
#
# def __call__(self, *args, **kwargs):
# return self.eval(*args, **kwargs)
#
# def reset_surrogate_model(self, surrogate_models: dict):
# for objective, surrogate_model in surrogate_models.items():
# self.surrogate_models[objective] = surrogate_model
#


class ExpectedHypervolumeImprovement(Module): # , MultiObjectiveBaseAcqusition):
def __init__(
self,
model,
ref_point,
partitioning,
) -> None:
r"""Expected Hypervolume Improvement supporting m>=2 outcomes.
Implementation from BOtorch, adapted from
https://github.com/pytorch/botorch/blob/353f37649fa8d90d881e8ea20c11986b15723ef1/botorch/acquisition/multi_objective/analytic.py#L78
This computes EHVI using the algorithm from [Yang2019]_, but
additionally computes gradients via auto-differentiation as proposed by
[Daulton2020qehvi]_.
Note: this is currently inefficient in two ways due to the binary partitioning
algorithm that we use for the box decomposition:
- We have more boxes in our decomposition
- If we used a box decomposition that used `inf` as the upper bound for
the last dimension *in all hypercells*, then we could reduce the number
of terms we need to compute from 2^m to 2^(m-1). [Yang2019]_ do this
by using DKLV17 and LKF17 for the box decomposition.
TODO: Use DKLV17 and LKF17 for the box decomposition as in [Yang2019]_ for
greater efficiency.
TODO: Add support for outcome constraints.
Example:
>>> model = SingleTaskGP(train_X, train_Y)
>>> ref_point = [0.0, 0.0]
>>> EHVI = ExpectedHypervolumeImprovement(model, ref_point, partitioning)
>>> ehvi = EHVI(test_X)
Args:
model: A fitted model.
ref_point: A list with `m` elements representing the reference point (in the
outcome space) w.r.t. which the hypervolume is computed. This is a
reference point for the objective values (i.e. after applying any
objective transformation to the samples).
partitioning: A `NondominatedPartitioning` module that provides the non-
dominated front and a partitioning of the non-dominated space in hyper-
rectangles.
"""
# TODO: we could refactor this __init__ logic into a
# HypervolumeAcquisitionFunction Mixin
if len(ref_point) != partitioning.num_outcomes:
raise ValueError(
"The length of the reference point must match the number of outcomes. "
f"Got ref_point with {len(ref_point)} elements, but expected "
f"{partitioning.num_outcomes}."
)
ref_point = torch.tensor(
ref_point,
dtype=partitioning.pareto_Y.dtype,
device=partitioning.pareto_Y.device,
)
better_than_ref = (partitioning.pareto_Y > ref_point).all(dim=1)
if not better_than_ref.any() and partitioning.pareto_Y.shape[0] > 0:
raise ValueError(
"At least one pareto point must be better than the reference point."
)
super().__init__()
self.model = model
self.register_buffer("ref_point", ref_point)
self.partitioning = partitioning
cell_bounds = self.partitioning.get_hypercell_bounds()
self.register_buffer("cell_lower_bounds", cell_bounds[0])
self.register_buffer("cell_upper_bounds", cell_bounds[1])
# create indexing tensor of shape `2^m x m`
self._cross_product_indices = torch.tensor(
list(product(*[[0, 1] for _ in range(ref_point.shape[0])])),
dtype=torch.long,
device=ref_point.device,
)
self.normal = Normal(0, 1)

def psi(self, lower: Tensor, upper: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
r"""Compute Psi function.
For each cell i and outcome k:
Psi(lower_{i,k}, upper_{i,k}, mu_k, sigma_k) = (
sigma_k * PDF((upper_{i,k} - mu_k) / sigma_k)
+ (mu_k - lower_{i,k}) * (1 - CDF((upper_{i,k} - mu_k) / sigma_k))
)
See Equation 19 in [Yang2019]_ for more details.
Args:
lower: A `num_cells x m`-dim tensor of lower cell bounds
upper: A `num_cells x m`-dim tensor of upper cell bounds
mu: A `batch_shape x 1 x m`-dim tensor of means
sigma: A `batch_shape x 1 x m`-dim tensor of standard deviations (clamped).
Returns:
A `batch_shape x num_cells x m`-dim tensor of values.
"""
u = (upper - mu) / sigma
return sigma * self.normal.log_prob(u).exp() + (mu - lower) * (
1 - self.normal.cdf(u)
)

def nu(self, lower: Tensor, upper: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
r"""Compute Nu function.
For each cell i and outcome k:
nu(lower_{i,k}, upper_{i,k}, mu_k, sigma_k) = (
upper_{i,k} - lower_{i,k}
) * (1 - CDF((upper_{i,k} - mu_k) / sigma_k))
See Equation 25 in [Yang2019]_ for more details.
Args:
lower: A `num_cells x m`-dim tensor of lower cell bounds
upper: A `num_cells x m`-dim tensor of upper cell bounds
mu: A `batch_shape x 1 x m`-dim tensor of means
sigma: A `batch_shape x 1 x m`-dim tensor of standard deviations (clamped).
Returns:
A `batch_shape x num_cells x m`-dim tensor of values.
"""
return (upper - lower) * (1 - self.normal.cdf((upper - mu) / sigma))

def forward(self, X: Tensor) -> Tensor:
# Query each per-objective surrogate; the first element of predict()'s return
# is used as the mean and the second as the standard deviation.
posterior = [[_m.predict(_x) for _m in self.model] for _x in X]
# Stack into `batch_shape x 1 x m` tensors of means and standard deviations.
mu = torch.tensor([[_m[0].item() for _m in _p] for _p in posterior])[:, None, :]
sigma = torch.tensor([[_s[1].item() for _s in _p] for _p in posterior])[
:, None, :
]

# clamp here, since upper_bounds will contain `inf`s, which
# are not differentiable
cell_upper_bounds = self.cell_upper_bounds.clamp_max(1e8)
# Compute psi(lower_i, upper_i, mu_i, sigma_i) for i=0, ... m-2
psi_lu = self.psi(
lower=self.cell_lower_bounds, upper=cell_upper_bounds, mu=mu, sigma=sigma
)
# Compute psi(lower_m, lower_m, mu_m, sigma_m)
psi_ll = self.psi(
lower=self.cell_lower_bounds,
upper=self.cell_lower_bounds,
mu=mu,
sigma=sigma,
)
# Compute nu(lower_m, upper_m, mu_m, sigma_m)
nu = self.nu(
lower=self.cell_lower_bounds, upper=cell_upper_bounds, mu=mu, sigma=sigma
)
# compute the difference psi_ll - psi_lu
psi_diff = psi_ll - psi_lu

# this is batch_shape x num_cells x 2 x (m-1)
stacked_factors = torch.stack([psi_diff, nu], dim=-2)

# Take the cross product of psi_diff and nu across all outcomes
# e.g. for m = 2
# for each batch and cell, compute
# [psi_diff_0, psi_diff_1]
# [nu_0, psi_diff_1]
# [psi_diff_0, nu_1]
# [nu_0, nu_1]
# this tensor has shape: `batch_shape x num_cells x 2^m x m`
all_factors_up_to_last = stacked_factors.gather(
dim=-2,
index=self._cross_product_indices.expand(
stacked_factors.shape[:-2] + self._cross_product_indices.shape
),
)
# compute product for all 2^m terms,
# sum across all terms and hypercells
return all_factors_up_to_last.prod(dim=-1).sum(dim=-1).sum(dim=-1)
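For orientation, below is a minimal, self-contained sketch of how this acquisition could be driven. The stub model and partitioning classes, their predict() return convention (mean first, standard deviation second) and the dummy cell bounds are assumptions for illustration only; they are not part of this commit.

import torch

class _StubObjectiveModel:
    """Pretend single-objective surrogate exposing predict(x) -> (mean, std)."""

    def __init__(self, offset: float):
        self.offset = offset

    def predict(self, x):
        # Toy posterior: the mean depends on the candidate, the std is fixed.
        mean = torch.as_tensor(float(sum(x)) + self.offset)
        std = torch.tensor(0.1)
        return mean, std

class _StubPartitioning:
    """Pretend non-dominated partitioning over m=2 outcomes."""

    num_outcomes = 2
    pareto_Y = torch.tensor([[1.0, 2.0], [2.0, 1.0]])

    def get_hypercell_bounds(self):
        # Shape (2 x num_cells x m): lower bounds stacked on upper bounds.
        lower = torch.tensor([[1.0, 1.0], [2.0, 0.0]])
        upper = torch.tensor([[2.0, float("inf")], [float("inf"), 1.0]])
        return torch.stack([lower, upper])

models = [_StubObjectiveModel(0.0), _StubObjectiveModel(0.5)]
ehvi = ExpectedHypervolumeImprovement(
    model=models, ref_point=[0.0, 0.0], partitioning=_StubPartitioning()
)
candidates = [[0.2, 0.3], [0.7, 0.1]]   # two candidate configurations
scores = ehvi(candidates)               # one EHVI value per candidate
best = candidates[int(torch.argmax(scores))]

The point of the sketch is the shapes: forward() expects one surrogate per objective and produces a single score per candidate by taking the product over outcomes and summing the 2^m terms across all hypercells.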
17 changes: 17 additions & 0 deletions neps/optimizers/bayesian_optimization/acquisition_functions/base_acquisition.py
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod


class BaseAcquisition(ABC):
def __init__(self):
self.surrogate_model = None

@abstractmethod
def eval(self, x, asscalar: bool = False):
"""Evaluate the acquisition function at point x2."""
raise NotImplementedError

def __call__(self, *args, **kwargs):
return self.eval(*args, **kwargs)

def set_state(self, surrogate_model, **kwargs):
self.surrogate_model = surrogate_model
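To make the contract concrete, here is a hypothetical minimal subclass (not part of this commit); the predict() return convention of the surrogate is an assumption for illustration.

class MeanAcquisition(BaseAcquisition):
    """Toy acquisition that scores candidates by negated posterior mean only."""

    def eval(self, x, asscalar: bool = False):
        # set_state() must have been called with a fitted surrogate first;
        # here we assume predict(x) returns a (mean, covariance) pair.
        mu, _cov = self.surrogate_model.predict(x)
        score = -mu  # lower predicted objective -> higher acquisition value
        return score.item() if asscalar else score

The optimizer is expected to call set_state(surrogate_model) after each model fit and then invoke the acquisition via __call__, which forwards to eval().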
120 changes: 120 additions & 0 deletions neps/optimizers/bayesian_optimization/acquisition_functions/ei.py
@@ -0,0 +1,120 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import torch
from torch.distributions import Normal

from .base_acquisition import BaseAcquisition

if TYPE_CHECKING:
import numpy as np

from neps.search_spaces import SearchSpace


class ComprehensiveExpectedImprovement(BaseAcquisition):
def __init__(
self,
augmented_ei: bool = False,
xi: float = 0.0,
in_fill: str = "best",
log_ei: bool = False,
optimize_on_max_fidelity: bool = True,
):
"""This is the graph BO version of the expected improvement
key differences are:
1. The input x2 is a networkx graph instead of a vectorial input
2. The search space (a collection of x1_graphs) is discrete, so there is no
gradient-based optimisation. Instead, we compute the EI at all candidate points
and empirically select the best position during optimisation
Args:
augmented_ei: Using the Augmented EI heuristic modification to the standard
expected improvement algorithm according to Huang (2006).
xi: manual exploration-exploitation trade-off parameter.
in_fill: the criterion used to determine the incumbent mu_star.
'best' means the empirical best observation so far (but could be
susceptible to noise), 'posterior' means the best *posterior GP mean*
encountered so far, and is recommended for optimization of noisier
functions. Defaults to "best".
log_ei: use log-EI if true, otherwise the usual EI.
"""
super().__init__()

if in_fill not in ["best", "posterior"]:
raise ValueError(f"Invalid value for in_fill ({in_fill})")
self.augmented_ei = augmented_ei
self.xi = xi
self.in_fill = in_fill
self.log_ei = log_ei
self.incumbent = None
self.optimize_on_max_fidelity = optimize_on_max_fidelity

def eval(
self,
x: Sequence[SearchSpace],
asscalar: bool = False,
) -> np.ndarray | torch.Tensor | float:
"""Return the negative expected improvement at the query point x2."""
assert self.incumbent is not None, "EI function not fitted on model"

if x[0].has_fidelity and self.optimize_on_max_fidelity:
_x = [e.clone() for e in x]
for e in _x:
e.set_to_max_fidelity()
else:
_x = x

mu, cov = self.surrogate_model.predict(_x)

std = torch.sqrt(torch.diag(cov))
mu_star = self.incumbent

gauss = Normal(torch.zeros(1, device=mu.device), torch.ones(1, device=mu.device))
# u = (mu - mu_star - self.xi) / std
# ei = std * updf + (mu - mu_star - self.xi) * ucdf
if self.log_ei:
# we expect that f_min is in log-space
f_min = mu_star - self.xi
v = (f_min - mu) / std
ei = torch.exp(f_min) * gauss.cdf(v) - torch.exp(
0.5 * torch.diag(cov) + mu
) * gauss.cdf(v - std)
else:
u = (mu_star - mu - self.xi) / std
try:
ucdf = gauss.cdf(u)
except ValueError as e:
print(f"u: {u}") # noqa: T201
print(f"mu_star: {mu_star}") # noqa: T201
print(f"mu: {mu}") # noqa: T201
print(f"std: {std}") # noqa: T201
print(f"diag: {cov.diag()}") # noqa: T201
raise e
updf = torch.exp(gauss.log_prob(u))
ei = std * updf + (mu_star - mu - self.xi) * ucdf
if self.augmented_ei:
sigma_n = self.surrogate_model.likelihood
ei *= 1.0 - torch.sqrt(torch.tensor(sigma_n, device=mu.device)) / torch.sqrt(
sigma_n + torch.diag(cov)
)
if isinstance(_x, list) and asscalar:
return ei.detach().numpy()
if asscalar:
ei = ei.detach().numpy().item()
return ei

def set_state(self, surrogate_model, **kwargs):
super().set_state(surrogate_model, **kwargs)

# Compute incumbent
if self.in_fill == "best":
self.incumbent = torch.min(self.surrogate_model.y_)
else:
x = self.surrogate_model.x
mu_train, _ = self.surrogate_model.predict(x)
incumbent_idx = torch.argmin(mu_train)
self.incumbent = self.surrogate_model.y_[incumbent_idx]
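A rough, self-contained usage sketch follows; the stub surrogate and configuration classes only mimic the interface this acquisition relies on (predict() returning a (mean, covariance) pair, a y_ attribute with observed objectives, and has_fidelity on configurations) and are assumptions for illustration, not part of this commit.

import torch

class _StubGP:
    """Pretend surrogate: predict(x) -> (mean, covariance), y_ holds observations."""

    y_ = torch.tensor([0.9, 0.4, 0.7])

    def predict(self, x):
        mean = torch.linspace(0.3, 0.6, len(x))
        cov = 0.05 * torch.eye(len(x))
        return mean, cov

class _StubConfig:
    """Pretend SearchSpace configuration with no fidelity parameter."""

    has_fidelity = False

acq = ComprehensiveExpectedImprovement(in_fill="best", log_ei=False)
acq.set_state(_StubGP())                 # caches the incumbent, here min(y_) = 0.4
candidates = [_StubConfig() for _ in range(4)]
ei_values = acq(candidates)              # __call__ -> eval(); one EI value per candidate
next_idx = int(torch.argmax(ei_values))

In the real optimizer loop the candidates would be SearchSpace samples and the surrogate a fitted GP, but the call pattern the class assumes is the same: set_state after fitting, then evaluate a batch of candidates.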
