Skip to content

Commit

Permalink
Adds multi-metric support to GP_UCB_PE
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712716509
  • Loading branch information
vizier-team authored and copybara-github committed Jan 7, 2025
1 parent e0d923e commit e4b0842
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 61 deletions.
25 changes: 25 additions & 0 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,31 @@ def __call__(
)()


def create_hv_scalarization(
num_scalarizations: int, labels: types.PaddedArray, rng: jax.Array
):
"""Creates a HyperVolumeScalarization with random weights.
Args:
num_scalarizations: The number of scalarizations to create.
labels: The labels used to create the reference point.
rng: The random key to use for sampling the weights.
Returns:
A HyperVolumeScalarization with random weights.
"""
weights = jax.random.normal(
rng,
shape=(num_scalarizations, labels.shape[1]),
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
ref_point = (
get_reference_point(labels, scale=0.01) if labels.shape[0] > 0 else None
)
return scalarization.HyperVolumeScalarization(weights, ref_point)


# TODO: What do we end up jitting? If we end up directly jitting this call
# then we should make it `eqx.Module` and set
# `reduction_fn=eqx.field(static=True)` instead.
Expand Down
18 changes: 4 additions & 14 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.designers import scalarization
from vizier._src.algorithms.designers.gp import acquisitions as acq_lib
from vizier._src.algorithms.designers.gp import gp_models
from vizier._src.algorithms.designers.gp import output_warpers
Expand Down Expand Up @@ -202,27 +201,18 @@ def __attrs_post_init__(self):
# Multi-objective overrides.
m_info = self._problem.metric_information
if not m_info.is_single_objective:
num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE))

# Create scalarization weights.
self._rng, weights_rng = jax.random.split(self._rng)
weights = jax.random.normal(
weights_rng, shape=(self._num_scalarizations, num_obj)
)
weights = jnp.abs(weights)
weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
ref_point = (
acq_lib.get_reference_point(data.labels, self._ref_scaling)
if has_labels
else None
scalarizer = acq_lib.create_hv_scalarization(
self._num_scalarizations, data.labels, weights_rng
)
scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point)

labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
max_scalarized = None
if has_labels:
max_scalarized = jnp.max(scalarizer(labels_array), axis=-1)
Expand Down
Loading

0 comments on commit e4b0842

Please sign in to comment.