From e4b0842bdea8fb2f36d283c40d2d71fd1087b7b3 Mon Sep 17 00:00:00 2001 From: vizier-team Date: Mon, 6 Jan 2025 18:21:56 -0800 Subject: [PATCH] Adds multi-metric support to `GP_UCB_PE` PiperOrigin-RevId: 712716509 --- .../algorithms/designers/gp/acquisitions.py | 25 +++ vizier/_src/algorithms/designers/gp_bandit.py | 18 +- vizier/_src/algorithms/designers/gp_ucb_pe.py | 202 +++++++++++++++--- .../algorithms/designers/gp_ucb_pe_test.py | 88 ++++++-- 4 files changed, 272 insertions(+), 61 deletions(-) diff --git a/vizier/_src/algorithms/designers/gp/acquisitions.py b/vizier/_src/algorithms/designers/gp/acquisitions.py index 4e9b37951..0ca518c6e 100644 --- a/vizier/_src/algorithms/designers/gp/acquisitions.py +++ b/vizier/_src/algorithms/designers/gp/acquisitions.py @@ -557,6 +557,31 @@ def __call__( )() +def create_hv_scalarization( + num_scalarizations: int, labels: types.PaddedArray, rng: jax.Array +): + """Creates a HyperVolumeScalarization with random weights. + + Args: + num_scalarizations: The number of scalarizations to create. + labels: The labels used to create the reference point. + rng: The random key to use for sampling the weights. + + Returns: + A HyperVolumeScalarization with random weights. + """ + weights = jax.random.normal( + rng, + shape=(num_scalarizations, labels.shape[1]), + ) + weights = jnp.abs(weights) + weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True) + ref_point = ( + get_reference_point(labels, scale=0.01) if labels.shape[0] > 0 else None + ) + return scalarization.HyperVolumeScalarization(weights, ref_point) + + # TODO: What do we end up jitting? If we end up directly jitting this call # then we should make it `eqx.Module` and set # `reduction_fn=eqx.field(static=True)` instead. diff --git a/vizier/_src/algorithms/designers/gp_bandit.py b/vizier/_src/algorithms/designers/gp_bandit.py index 01605484f..389fe5b1c 100644 --- a/vizier/_src/algorithms/designers/gp_bandit.py +++ b/vizier/_src/algorithms/designers/gp_bandit.py @@ -36,7 +36,6 @@ from vizier import algorithms as vza from vizier import pyvizier as vz from vizier._src.algorithms.designers import quasi_random -from vizier._src.algorithms.designers import scalarization from vizier._src.algorithms.designers.gp import acquisitions as acq_lib from vizier._src.algorithms.designers.gp import gp_models from vizier._src.algorithms.designers.gp import output_warpers @@ -202,27 +201,18 @@ def __attrs_post_init__(self): # Multi-objective overrides. m_info = self._problem.metric_information if not m_info.is_single_objective: - num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE)) # Create scalarization weights. self._rng, weights_rng = jax.random.split(self._rng) - weights = jax.random.normal( - weights_rng, shape=(self._num_scalarizations, num_obj) - ) - weights = jnp.abs(weights) - weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True) def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction: # Scalarized UCB. 
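Note on the new `create_hv_scalarization` helper above: the weights are drawn from a standard normal, made non-negative, and L2-normalized, which yields scalarization directions spread over the positive orthant of the unit sphere. A minimal standalone sketch of just that sampling step, using plain `jax.numpy` and a hypothetical `sample_scalarization_weights` helper instead of Vizier's types:

    import jax
    import jax.numpy as jnp

    def sample_scalarization_weights(
        rng: jax.Array, num_scalarizations: int, num_metrics: int
    ) -> jax.Array:
      """Draws non-negative, unit-norm weight vectors, one per scalarization."""
      weights = jax.random.normal(rng, shape=(num_scalarizations, num_metrics))
      weights = jnp.abs(weights)
      return weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

    weights = sample_scalarization_weights(
        jax.random.PRNGKey(0), num_scalarizations=4, num_metrics=2
    )
    assert weights.shape == (4, 2)
    assert jnp.allclose(jnp.linalg.norm(weights, axis=-1), 1.0)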
- labels_array = data.labels.padded_array - has_labels = labels_array.shape[0] > 0 - ref_point = ( - acq_lib.get_reference_point(data.labels, self._ref_scaling) - if has_labels - else None + scalarizer = acq_lib.create_hv_scalarization( + self._num_scalarizations, data.labels, weights_rng ) - scalarizer = scalarization.HyperVolumeScalarization(weights, ref_point) + labels_array = data.labels.padded_array + has_labels = labels_array.shape[0] > 0 max_scalarized = None if has_labels: max_scalarized = jnp.max(scalarizer(labels_array), axis=-1) diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe.py b/vizier/_src/algorithms/designers/gp_ucb_pe.py index be2b2dc73..13f23e289 100644 --- a/vizier/_src/algorithms/designers/gp_ucb_pe.py +++ b/vizier/_src/algorithms/designers/gp_ucb_pe.py @@ -20,6 +20,7 @@ import copy import datetime +import enum import random from typing import Any, Callable, Mapping, Optional, Sequence, Union @@ -35,6 +36,7 @@ from vizier import algorithms as vza from vizier import pyvizier as vz from vizier._src.algorithms.designers import quasi_random +from vizier._src.algorithms.designers import scalarization from vizier._src.algorithms.designers.gp import acquisitions from vizier._src.algorithms.designers.gp import output_warpers from vizier._src.algorithms.optimizers import eagle_strategy as es @@ -51,6 +53,23 @@ tfd = tfp.distributions +class MultimetricPromisingRegionPenaltyType(enum.Enum): + """The type of penalty to apply to the points outside the promising region. + + Configures the penalty term in `PEScoreFunction` for multimetric problems. + """ + + # The penalty is applied to the points outside the union of the promising + # regions of all metrics. + UNION = 'union' + # The penalty is applied to the points outside the intersection of the + # promising regions of all metrics. + INTERSECTION = 'intersection' + # The penalty applied to a point in the search space is the average of + # the penalties with respect to the promising regions of all metrics. + AVERAGE = 'average' + + class UCBPEConfig(eqx.Module): """UCB-PE config parameters.""" @@ -92,6 +111,13 @@ class UCBPEConfig(eqx.Module): optimize_set_acquisition_for_exploration: bool = eqx.field( default=False, static=True ) + # The type of penalty to apply to the points outside the promising region for + # multimetric problems. + multimetric_promising_region_penalty_type: ( + MultimetricPromisingRegionPenaltyType + ) = eqx.field( + default=MultimetricPromisingRegionPenaltyType.AVERAGE, static=True + ) def __repr__(self): return eqx.tree_pformat(self, short_arrays=False) @@ -155,10 +181,28 @@ def _compute_ucb_threshold( The predicted mean of the feature array with the maximum UCB among `xs`. """ pred_mean = gprm.mean() - ucb_values = jnp.where( - is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev() - ) - return pred_mean[jnp.argmax(ucb_values)] + if pred_mean.ndim > 1: + # In the multimetric case, the predicted mean and stddev are of shape + # [num_points, num_metrics]. + ucb_values = jnp.where( + jnp.tile(is_missing[:, jnp.newaxis], (1, pred_mean.shape[-1])), + -jnp.inf, + pred_mean + ucb_coefficient * gprm.stddev(), + ) + # The indices of the points with the maximum UCB values for each metric. + best_ucb_indices = jnp.argmax(ucb_values, axis=0) + return jax.vmap( + lambda pred_mean, best_ucb_idx: pred_mean[best_ucb_idx], + in_axes=-1, + out_axes=-1, + )(pred_mean, best_ucb_indices) + else: + # In the single metric case, the predicted mean and stddev are of shape + # [num_points]. 
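The multimetric branch of `_compute_ucb_threshold` above selects, independently for each metric, the predicted mean of the point whose UCB value is highest for that metric. A small standalone sketch of the same per-column selection on made-up arrays; the `vmap` over the metric axis is equivalent to a `take_along_axis` gather:

    import jax
    import jax.numpy as jnp

    pred_mean = jnp.array([[0.1, 0.9], [0.5, 0.2], [0.3, 0.4]])   # [num_points, num_metrics]
    ucb_values = jnp.array([[1.0, 2.0], [3.0, 0.5], [2.0, 1.0]])  # [num_points, num_metrics]

    # Index of the best-UCB point, separately per metric.
    best_ucb_indices = jnp.argmax(ucb_values, axis=0)

    # Per-metric threshold, gathered as in the patch via vmap over the metric axis.
    thresholds = jax.vmap(
        lambda mean_col, idx: mean_col[idx], in_axes=-1, out_axes=-1
    )(pred_mean, best_ucb_indices)

    # The same gather written with take_along_axis.
    same = jnp.take_along_axis(pred_mean, best_ucb_indices[jnp.newaxis, :], axis=0)[0]
    assert jnp.allclose(thresholds, same)  # both have shape [num_metrics]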
+ ucb_values = jnp.where( + is_missing, -jnp.inf, pred_mean + ucb_coefficient * gprm.stddev() + ) + return pred_mean[jnp.argmax(ucb_values)] # TODO: Use acquisitions.TrustRegion instead. @@ -238,12 +282,36 @@ class UCBScoreFunction(eqx.Module): on completed and pending trials. ucb_coefficient: The UCB coefficient. trust_region: Trust region. + scalarization_weights_rng: Random key for scalarization. + labels: Labels, shaped as [num_index_points, num_metrics]. + num_scalarizations: Number of scalarizations. """ predictive: sp.UniformEnsemblePredictive predictive_all_features: sp.UniformEnsemblePredictive ucb_coefficient: jt.Float[jt.Array, ''] trust_region: Optional[acquisitions.TrustRegion] + labels: types.PaddedArray + scalarizer: scalarization.Scalarization + + def __init__( + self, + predictive: sp.UniformEnsemblePredictive, + predictive_all_features: sp.UniformEnsemblePredictive, + ucb_coefficient: jt.Float[jt.Array, ''], + trust_region: Optional[acquisitions.TrustRegion], + scalarization_weights_rng: jax.Array, + labels: types.PaddedArray, + num_scalarizations: int = 1000, + ): + self.predictive = predictive + self.predictive_all_features = predictive_all_features + self.ucb_coefficient = ucb_coefficient + self.trust_region = trust_region + self.labels = labels + self.scalarizer = acquisitions.create_hv_scalarization( + num_scalarizations, labels, scalarization_weights_rng + ) def score( self, xs: types.ModelInput, seed: Optional[jax.Array] = None @@ -264,9 +332,26 @@ def score_with_aux( mean = gprm.mean() stddev_from_all = gprm_all_features.stddev() acq_values = mean + self.ucb_coefficient * stddev_from_all + # `self.labels` is of shape [num_index_points, num_metrics]. + if self.labels.shape[1] > 1: + scalarized = self.scalarizer(acq_values) + padded_labels = self.labels.replace_fill_value(-np.inf).padded_array + if padded_labels.shape[0] > 0: + # Broadcast max_scalarized to the same shape as scalarized and take max. + max_scalarized = jnp.max(self.scalarizer(padded_labels), axis=-1) + shape_mismatch = len(scalarized.shape) - len(max_scalarized.shape) + expand_max = jnp.expand_dims( + max_scalarized, axis=range(-shape_mismatch, 0) + ) + scalarized = jnp.maximum(scalarized, expand_max) + scalarized_acq_values = jnp.mean(scalarized, axis=0) + else: + scalarized_acq_values = acq_values if self.trust_region is not None: - acq_values = _apply_trust_region(self.trust_region, xs, acq_values) - return acq_values, { + scalarized_acq_values = _apply_trust_region( + self.trust_region, xs, scalarized_acq_values + ) + return scalarized_acq_values, { 'mean': mean, 'stddev': gprm.stddev(), 'stddev_from_all': stddev_from_all, @@ -303,6 +388,9 @@ class PEScoreFunction(eqx.Module): explore_ucb_coefficient: jt.Float[jt.Array, ''] penalty_coefficient: jt.Float[jt.Array, ''] trust_region: Optional[acquisitions.TrustRegion] + multimetric_promising_region_penalty_type: ( + MultimetricPromisingRegionPenaltyType + ) def score( self, xs: types.ModelInput, seed: Optional[jax.Array] = None @@ -333,10 +421,34 @@ def score_with_aux( gprm_all = self.predictive_all_features.predict(xs) stddev_from_all = gprm_all.stddev() - acq_values = stddev_from_all + self.penalty_coefficient * jnp.minimum( + penalty = self.penalty_coefficient * jnp.minimum( explore_ucb - threshold, 0.0, ) + # `stddev_from_all` and `penalty` are of shape + # [num_index_points, num_metrics] for multi-metric problems or + # [num_index_points] for single-metric problems. 
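In the multi-objective branch of `UCBScoreFunction.score_with_aux`, the per-metric UCB values are scalarized, each scalarization is floored at the best scalarized value already observed among the labels, and the result is averaged over scalarizations. A toy sketch of that clamp-then-average step; the weighted-min scalarizer below is only a stand-in for Vizier's `HyperVolumeScalarization`, and all values are made up:

    import jax.numpy as jnp

    weights = jnp.array([[0.8, 0.2], [0.5, 0.5], [0.2, 0.8]])     # [num_scalarizations, num_metrics]
    acq_values = jnp.array([[1.0, 2.0], [2.0, 1.0], [0.5, 0.5]])  # [num_candidates, num_metrics]
    observed_labels = jnp.array([[0.9, 0.8], [0.2, 1.5]])         # [num_observed, num_metrics]

    def scalarize(values):
      # Stand-in scalarizer: weighted min over metrics, one row per weight vector.
      # Returns shape [num_scalarizations, num_points].
      return jnp.min(values[jnp.newaxis, ...] / weights[:, jnp.newaxis, :], axis=-1)

    scalarized = scalarize(acq_values)                             # [num_scalarizations, num_candidates]
    max_scalarized = jnp.max(scalarize(observed_labels), axis=-1)  # [num_scalarizations]
    # Floor each scalarization at the best observed value, then average over
    # scalarizations to get one acquisition value per candidate.
    clamped = jnp.maximum(scalarized, max_scalarized[:, jnp.newaxis])
    scalarized_acq_values = jnp.mean(clamped, axis=0)              # [num_candidates]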
+ if stddev_from_all.ndim > 1: + if self.multimetric_promising_region_penalty_type == ( + MultimetricPromisingRegionPenaltyType.UNION + ): + scalarized_penalty = jnp.max(penalty, axis=-1) + elif self.multimetric_promising_region_penalty_type == ( + MultimetricPromisingRegionPenaltyType.INTERSECTION + ): + scalarized_penalty = jnp.min(penalty, axis=-1) + elif self.multimetric_promising_region_penalty_type == ( + MultimetricPromisingRegionPenaltyType.AVERAGE + ): + scalarized_penalty = jnp.mean(penalty, axis=-1) + else: + raise ValueError( + 'Unsupported multimetric promising region penalty type:' + f' {self.multimetric_promising_region_penalty_type}' + ) + acq_values = jnp.mean(stddev_from_all, axis=-1) + scalarized_penalty + else: + acq_values = stddev_from_all + penalty if self.trust_region is not None: acq_values = _apply_trust_region(self.trust_region, xs, acq_values) return acq_values, { @@ -537,8 +649,14 @@ def __attrs_post_init__(self): # Extra validations if self._problem.search_space.is_conditional: raise ValueError(f'{type(self)} does not support conditional search.') - elif len(self._problem.metric_information) != 1: - raise ValueError(f'{type(self)} works with exactly one metric.') + elif ( + len(self._problem.metric_information) != 1 + and self._config.optimize_set_acquisition_for_exploration + ): + raise ValueError( + f'{type(self)} works with exactly one metric when' + ' `optimize_set_acquisition_for_exploration` is enabled.' + ) # Extra initializations. # Discrete parameters are continuified to account for their actual values. @@ -554,7 +672,7 @@ def __attrs_post_init__(self): self._problem.search_space, seed=int(jax.random.randint(qrs_seed, [], 0, 2**16)), ) - self._output_warper = None + self._output_warpers: list[output_warpers.OutputWarper] = [] def update( self, completed: vza.CompletedTrials, all_active: vza.ActiveTrials @@ -717,10 +835,15 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData: data.labels.shape, _get_features_shape(data.features), ) - self._output_warper = output_warpers.create_default_warper() - warped_labels = self._output_warper.warp(np.array(data.labels.unpad())) + unpadded_labels = np.asarray(data.labels.unpad()) + warped_labels = [] + self._output_warpers = [] + for i in range(data.labels.shape[1]): + output_warper = output_warpers.create_default_warper() + warped_labels.append(output_warper.warp(unpadded_labels[:, i : i + 1])) + self._output_warpers.append(output_warper) labels = types.PaddedArray.from_array( - warped_labels, + np.concatenate(warped_labels, axis=-1), data.labels.padded_array.shape, fill_value=data.labels.fill_value, ) @@ -773,7 +896,10 @@ def _get_predictive_all_features( # Pending features are only used to predict standard deviation, so their # labels do not matter, and we simply set them to 0. dummy_labels = jnp.zeros( - shape=(pending_features.continuous.unpad().shape[0], 1), + shape=( + pending_features.continuous.unpad().shape[0], + data.labels.shape[-1], + ), dtype=data.labels.padded_array.dtype, ) all_labels = jnp.concatenate([data.labels.unpad(), dummy_labels], axis=0) @@ -840,11 +966,14 @@ def _suggest_one( # When `use_ucb` is true, the acquisition function computes the UCB # values. Otherwise, it computes the Pure-Exploration acquisition values. 
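The dispatch above reduces the per-metric penalty, which is zero inside a metric's promising region and negative outside it, to a single value per point: a max across metrics leaves a point unpenalized if it lies in at least one region (union), a min requires it to lie in all regions (intersection), and a mean interpolates between the two. A toy illustration with made-up values:

    import jax.numpy as jnp

    # Per-metric penalties for three points and two metrics; 0.0 means the point
    # lies inside that metric's promising region, negative means it lies outside.
    penalty = jnp.array([
        [0.0, -2.0],   # inside metric 0's region only
        [0.0, 0.0],    # inside both regions
        [-1.0, -3.0],  # outside both regions
    ])

    union = jnp.max(penalty, axis=-1)         # [ 0.,  0., -1.]
    intersection = jnp.min(penalty, axis=-1)  # [-2.,  0., -3.]
    average = jnp.mean(penalty, axis=-1)      # [-1.,  0., -2.]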
if use_ucb: + scalarization_weights_rng, self._rng = jax.random.split(self._rng) scoring_fn = UCBScoreFunction( predictive, predictive_all_features, ucb_coefficient=self._config.ucb_coefficient, trust_region=tr if self._use_trust_region else None, + scalarization_weights_rng=scalarization_weights_rng, + labels=data.labels, ) else: scoring_fn = PEScoreFunction( @@ -854,6 +983,9 @@ def _suggest_one( ucb_coefficient=self._config.ucb_coefficient, explore_ucb_coefficient=self._config.explore_region_ucb_coefficient, trust_region=tr if self._use_trust_region else None, + multimetric_promising_region_penalty_type=( + self._config.multimetric_promising_region_penalty_type + ), ) if isinstance(acquisition_optimizer, vb.VectorizedOptimizer): @@ -910,9 +1042,11 @@ def _suggest_one( # debugging needs. metadata = best_candidate.metadata.ns(self._metadata_ns) metadata.ns('prediction_in_warped_y_space').update({ - 'mean': f'{predict_mean[0]}', - 'stddev': f'{predict_stddev[0]}', - 'stddev_from_all': f'{predict_stddev_from_all[0]}', + 'mean': np.array2string(np.asarray(predict_mean[0]), separator=','), + 'stddev': np.array2string(np.asarray(predict_stddev[0]), separator=','), + 'stddev_from_all': np.array2string( + np.asarray(predict_stddev_from_all[0]), separator=',' + ), 'acquisition': f'{acquisition}', 'use_ucb': f'{use_ucb}', 'trust_radius': f'{tr.trust_radius}', @@ -1060,20 +1194,36 @@ def sample( ) samples = eqx.filter_jit(acquisitions.sample_from_predictive)( predictive, xs, num_samples, key=rng - ) # (num_samples, num_trials) - # Scope the samples to non-padded only (there's a single padded dimension). + ) + # Scope `samples` to non-padded only (there's a single padded dimension). + # `samples` has shape: [num_samples, num_trials] for single metric or + # [num_samples, num_trials, num_metrics] for multi-metric problems. + if samples.ndim == 2: + samples = jnp.expand_dims(samples, axis=-1) samples = samples[ - :, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0]) + :, ~(xs.continuous.is_missing[0] | xs.categorical.is_missing[0]), : ] # TODO: vectorize output warping. - if self._output_warper is not None: - return np.vstack([ - self._output_warper.unwarp(samples[i][..., np.newaxis]).reshape(-1) - for i in range(samples.shape[0]) - ]) + if self._output_warpers: + unwarped_samples = [] + for metric_idx, output_warper in enumerate(self._output_warpers): + unwarped_samples.append( + np.vstack([ + output_warper.unwarp( + samples[i][:, metric_idx : metric_idx + 1] + ).reshape(-1) + for i in range(samples.shape[0]) + ]) + ) + unwarped_samples = np.stack(unwarped_samples, axis=-1) + if unwarped_samples.shape[-1] > 1: + return unwarped_samples + else: + return np.squeeze(unwarped_samples, axis=-1) else: raise TypeError( - 'Output warper is expected to be set, but found to be None.' + 'Output warpers are expected to be set, but found to be' + f' {self._output_warpers}.' 
) @profiler.record_runtime diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe_test.py b/vizier/_src/algorithms/designers/gp_ucb_pe_test.py index a8f890492..0dee4b37b 100644 --- a/vizier/_src/algorithms/designers/gp_ucb_pe_test.py +++ b/vizier/_src/algorithms/designers/gp_ucb_pe_test.py @@ -16,6 +16,7 @@ """Tests for gp_ucb_pe.""" +import ast import copy from typing import Any, Tuple @@ -39,12 +40,12 @@ def _extract_predictions( metadata: Any, -) -> Tuple[float, float, float, float, bool]: +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, bool]: pred = metadata.ns('prediction_in_warped_y_space') return ( - float(pred['mean']), - float(pred['stddev']), - float(pred['stddev_from_all']), + np.asarray(ast.literal_eval(pred['mean'])), + np.asarray(ast.literal_eval(pred['stddev'])), + np.asarray(ast.literal_eval(pred['stddev_from_all'])), float(pred['acquisition']), bool(pred['use_ucb'] == 'True'), ) @@ -81,6 +82,26 @@ class GpUcbPeTest(parameterized.TestCase): ensemble_size=3, turns_on_high_noise_mode=True, ), + dict(iters=3, batch_size=5, num_seed_trials=5, num_metrics=2), + dict( + iters=3, + batch_size=3, + num_metrics=2, + applies_padding=True, + multimetric_promising_region_penalty_type=( + gp_ucb_pe.MultimetricPromisingRegionPenaltyType.UNION + ), + ), + dict( + iters=3, + batch_size=3, + num_metrics=2, + applies_padding=True, + ensemble_size=4, + multimetric_promising_region_penalty_type=( + gp_ucb_pe.MultimetricPromisingRegionPenaltyType.INTERSECTION + ), + ), ) def test_on_flat_space( self, @@ -96,17 +117,25 @@ def test_on_flat_space( test_studies.flat_continuous_space_with_scaling() ), turns_on_high_noise_mode: bool = False, + num_metrics: int = 1, + multimetric_promising_region_penalty_type: ( + gp_ucb_pe.MultimetricPromisingRegionPenaltyType + ) = gp_ucb_pe.MultimetricPromisingRegionPenaltyType.AVERAGE, ): # We use string names so that test case names are readable. Convert them # to objects. if ard_optimizer == 'default': ard_optimizer = optimizers.default_optimizer() problem = vz.ProblemStatement(search_space) - problem.metric_information.append( - vz.MetricInformation( - name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE - ) - ) + for metric_idx in range(num_metrics): + problem.metric_information.append( + vz.MetricInformation( + name=f'metric{metric_idx}', + goal=vz.ObjectiveMetricGoal.MAXIMIZE + if metric_idx % 2 == 0 + else vz.ObjectiveMetricGoal.MINIMIZE, + ) + ) vectorized_optimizer_factory = vb.VectorizedOptimizerFactory( strategy_factory=es.VectorizedEagleStrategyFactory(), max_evaluations=100, @@ -134,6 +163,9 @@ def test_on_flat_space( signal_to_noise_threshold=np.inf if turns_on_high_noise_mode else 0.0, + multimetric_promising_region_penalty_type=( + multimetric_promising_region_penalty_type + ), ), ensemble_size=ensemble_size, padding_schedule=padding.PaddingSchedule( @@ -197,7 +229,9 @@ def test_on_flat_space( if len(completed_trials) > 1: # test the sample method. samples = designer.sample(test_trials, num_samples=5) - self.assertSequenceEqual(samples.shape, (5, 3)) + self.assertSequenceEqual( + samples.shape, (5, 3) if num_metrics == 1 else (5, 3, num_metrics) + ) self.assertFalse(np.isnan(samples).any()) # test the sample method with a different rng. samples_rng = designer.sample( @@ -207,8 +241,14 @@ def test_on_flat_space( self.assertFalse((np.abs(samples - samples_rng) <= 1e-6).all()) # test the predict method. 
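As background for the `_extract_predictions` helper above: the designer now writes possibly vector-valued predictions into trial metadata with `np.array2string(..., separator=',')`, and the test parses them back with `ast.literal_eval`. A minimal sketch of that round-trip on a made-up two-metric prediction:

    import ast
    import numpy as np

    predict_mean = np.array([0.12345, -1.5])  # e.g. one entry per metric

    # Serialization, as in `_suggest_one` when the metadata is written.
    serialized = np.array2string(predict_mean, separator=',')

    # Parsing, as in `_extract_predictions`.
    recovered = np.asarray(ast.literal_eval(serialized))
    assert np.allclose(recovered, predict_mean)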
prediction = designer.predict(test_trials) - self.assertLen(prediction.mean, 3) - self.assertLen(prediction.stddev, 3) + self.assertSequenceEqual( + prediction.mean.shape, + (3,) if num_metrics == 1 else (3, num_metrics), + ) + self.assertSequenceEqual( + prediction.stddev.shape, + (3,) if num_metrics == 1 else (3, num_metrics), + ) self.assertFalse(np.isnan(prediction.mean).any()) self.assertFalse(np.isnan(prediction.stddev).any()) if last_prediction is None: @@ -260,8 +300,11 @@ def test_on_flat_space( # Except for the last batch of suggestions, the acquisition value of # the first suggestion in a batch is expected to be UCB, which # combines the predicted mean based only on completed trials and the - # predicted standard deviation based on all trials. - self.assertAlmostEqual(mean + 10.0 * stddev_from_all, acq) + # predicted standard deviation based on all trials. Only checks the + # single-metric case because the acquisition value in the multi-metric + # case is randomly scalarized. + if num_metrics == 1: + self.assertAlmostEqual(mean + 10.0 * stddev_from_all, acq) self.assertTrue(use_ucb) continue @@ -280,7 +323,9 @@ def test_on_flat_space( # in every batch. The Pure-Exploration acquisition values are standard # deviation predictions based on all trials (completed and pending). self.assertAlmostEqual( - acq, stddev_from_all, msg=f'batch: {idx}, suggestion: {jdx}' + acq, + np.mean(stddev_from_all), + msg=f'batch: {idx}, suggestion: {jdx}', ) if optimize_set_acquisition_for_exploration: geometric_mean_of_pred_cov_eigs = np.exp( @@ -316,6 +361,7 @@ def test_ucb_overwrite(self): explore_region_ucb_coefficient=0.5, cb_violation_penalty_coefficient=10.0, ucb_overwrite_probability=1.0, + pe_overwrite_probability=0.0, signal_to_noise_threshold=0.0, ), padding_schedule=padding.PaddingSchedule( @@ -364,12 +410,12 @@ def test_ucb_overwrite(self): # Skips the first batch of suggestions, which are generated by the # seeding designer, not acquisition function optimization. continue - # Because `ucb_overwrite_probability` is 1, all suggestions after the - # first batch are expected to be generated by UCB. Within a batch, the - # first suggestion's UCB value is expected to use predicted standard - # deviation based only on completed trials, while the UCB values of - # the second to the last suggestions are expected to use the predicted - # standard deviations based on completed and active trials. + # Because `ucb_overwrite_probability` is 1 and `pe_overwrite_probability` + # is 0, all suggestions after the first batch are expected to be generated + # by UCB. Within a batch, the first suggestion's UCB value is expected to + # use predicted standard deviation based only on completed trials, while + # the UCB values of the second to the last suggestions are expected to use + # the predicted standard deviations based on completed and active trials. mean, stddev, stddev_from_all, acq, use_ucb = _extract_predictions( trial.metadata.ns('gp_ucb_pe_bandit_test') )
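As additional context for the `_trials_to_data` and `sample` changes in the designer, each metric column now gets its own output warper, fit on that column alone and applied column-by-column when unwarping samples. A standalone sketch of that per-metric round-trip on made-up labels:

    import numpy as np
    from vizier._src.algorithms.designers.gp import output_warpers

    labels = np.array(
        [[0.3, -1.0], [1.2, 0.5], [-0.7, 2.0], [0.0, 0.1], [2.5, -0.4]]
    )  # [num_trials, num_metrics], made-up values

    # Warp each metric column with its own warper, as in `_trials_to_data`.
    warpers = []
    warped_cols = []
    for i in range(labels.shape[1]):
      warper = output_warpers.create_default_warper()
      warped_cols.append(warper.warp(labels[:, i : i + 1]))
      warpers.append(warper)
    warped = np.concatenate(warped_cols, axis=-1)  # [num_trials, num_metrics]

    # Unwarp column-by-column with the matching warper, as in `sample`.
    unwarped = np.stack(
        [
            warpers[i].unwarp(warped[:, i : i + 1]).reshape(-1)
            for i in range(warped.shape[1])
        ],
        axis=-1,
    )  # [num_trials, num_metrics]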