From f30840c88ad22746e5528a4dbb59c39647482229 Mon Sep 17 00:00:00 2001 From: vizier-team Date: Thu, 13 Feb 2025 08:54:59 -0800 Subject: [PATCH] Supports a prior acquisition function in GP_UCB_PE PiperOrigin-RevId: 726507984 --- vizier/_src/algorithms/designers/gp_ucb_pe.py | 75 ++++++++-- .../algorithms/designers/gp_ucb_pe_test.py | 134 ++++++++++++++++++ 2 files changed, 198 insertions(+), 11 deletions(-) diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe.py b/vizier/_src/algorithms/designers/gp_ucb_pe.py index 44f7a4322..d34b88006 100644 --- a/vizier/_src/algorithms/designers/gp_ucb_pe.py +++ b/vizier/_src/algorithms/designers/gp_ucb_pe.py @@ -278,8 +278,10 @@ class UCBScoreFunction(eqx.Module): The UCB acquisition value is the sum of the predicted mean based on completed trials and the predicted standard deviation based on all trials, completed and - pending (scaled by the UCB coefficient). This class follows the - `acquisitions.ScoreFunction` protocol. + pending (scaled by the UCB coefficient). If `prior_acquisition` is not None, + the return value is the sum of the prior acquisition value and the UCB + acquisition value. This class follows the `acquisitions.ScoreFunction` + protocol. Attributes: predictive: Predictive model with cached Cholesky conditioned on completed @@ -288,6 +290,7 @@ class UCBScoreFunction(eqx.Module): on completed and pending trials. ucb_coefficient: The UCB coefficient. trust_region: Trust region. + prior_acquisition: An optional prior acquisition function. scalarization_weights_rng: Random key for scalarization. labels: Labels, shaped as [num_index_points, num_metrics]. num_scalarizations: Number of scalarizations. @@ -297,6 +300,7 @@ class UCBScoreFunction(eqx.Module): predictive_all_features: sp.UniformEnsemblePredictive ucb_coefficient: jt.Float[jt.Array, ''] trust_region: Optional[acquisitions.TrustRegion] + prior_acquisition: Callable[[types.ModelInput], jax.Array] | None labels: types.PaddedArray scalarizer: scalarization.Scalarization @@ -306,6 +310,7 @@ def __init__( predictive_all_features: sp.UniformEnsemblePredictive, ucb_coefficient: jt.Float[jt.Array, ''], trust_region: Optional[acquisitions.TrustRegion], + prior_acquisition: Callable[[types.ModelInput], jax.Array] | None, scalarization_weights_rng: jax.Array, labels: types.PaddedArray, num_scalarizations: int = 1000, @@ -314,6 +319,7 @@ def __init__( self.predictive_all_features = predictive_all_features self.ucb_coefficient = ucb_coefficient self.trust_region = trust_region + self.prior_acquisition = prior_acquisition self.labels = labels self.scalarizer = acquisitions.create_hv_scalarization( num_scalarizations, labels, scalarization_weights_rng @@ -357,11 +363,16 @@ def score_with_aux( scalarized_acq_values = _apply_trust_region( self.trust_region, xs, scalarized_acq_values ) - return scalarized_acq_values, { + aux = { 'mean': mean, 'stddev': gprm.stddev(), 'stddev_from_all': stddev_from_all, } + if self.prior_acquisition is not None: + prior_acq_values = self.prior_acquisition(xs) + scalarized_acq_values = prior_acq_values + scalarized_acq_values + aux['prior_acq_values'] = prior_acq_values + return scalarized_acq_values, aux class PEScoreFunction(eqx.Module): @@ -370,8 +381,10 @@ class PEScoreFunction(eqx.Module): The PE acquisition value is the predicted standard deviation (eq. (9) in https://arxiv.org/pdf/1304.5350) based on all completed and active trials, plus a penalty term that grows linearly in the amount of violation of the - constraint `UCB(xs) >= threshold`. This class follows the - `acquisitions.ScoreFunction` protocol. + constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the + returned value is the sum of the prior acquisition value and the PE + acquisition value. This class follows the `acquisitions.ScoreFunction` + protocol. Attributes: predictive: Predictive model with cached Cholesky conditioned on completed @@ -383,6 +396,9 @@ class PEScoreFunction(eqx.Module): values on `xs`. penalty_coefficient: Multiplier on the constraint violation penalty. trust_region: + prior_acquisition: An optional prior acquisition function. + multimetric_promising_region_penalty_type: The type of multimetric promising + region penalty. Returns: The Pure-Exploration acquisition value. @@ -394,6 +410,7 @@ class PEScoreFunction(eqx.Module): explore_ucb_coefficient: jt.Float[jt.Array, ''] penalty_coefficient: jt.Float[jt.Array, ''] trust_region: Optional[acquisitions.TrustRegion] + prior_acquisition: Callable[[types.ModelInput], jax.Array] | None multimetric_promising_region_penalty_type: ( MultimetricPromisingRegionPenaltyType ) @@ -457,11 +474,16 @@ def score_with_aux( acq_values = stddev_from_all + penalty if self.trust_region is not None: acq_values = _apply_trust_region(self.trust_region, xs, acq_values) - return acq_values, { + aux = { 'mean': mean, 'stddev': stddev, 'stddev_from_all': stddev_from_all, } + if self.prior_acquisition is not None: + prior_acq_values = self.prior_acquisition(xs) + acq_values += prior_acq_values + aux['prior_acq_values'] = prior_acq_values + return acq_values, aux def _logdet(matrix: jax.Array): @@ -486,8 +508,10 @@ class SetPEScoreFunction(eqx.Module): predicted covariance matrix evaluated at the points (eq. (8) in https://arxiv.org/pdf/1304.5350) based on all completed and active trials, plus a penalty term that grows linearly in the amount of violation of the - constraint `UCB(xs) >= threshold`. This class follows the - `acquisitions.ScoreFunction` protocol. + constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the + returned value is the sum of the prior acquisition value and the PE + acquisition value. This class follows the `acquisitions.ScoreFunction` + protocol. Attributes: predictive: Predictive model with cached Cholesky conditioned on completed @@ -499,6 +523,7 @@ class SetPEScoreFunction(eqx.Module): values on `xs`. penalty_coefficient: Multiplier on the constraint violation penalty. trust_region: + prior_acquisition: An optional prior acquisition function. Returns: The Pure-Exploration acquisition value. @@ -510,6 +535,7 @@ class SetPEScoreFunction(eqx.Module): explore_ucb_coefficient: jt.Float[jt.Array, ''] penalty_coefficient: jt.Float[jt.Array, ''] trust_region: Optional[acquisitions.TrustRegion] + prior_acquisition: Callable[[types.ModelInput], jax.Array] | None def score( self, xs: types.ModelInput, seed: Optional[jax.Array] = None @@ -549,11 +575,16 @@ def score_with_aux( ) if self.trust_region is not None: acq_values = _apply_trust_region_to_set(self.trust_region, xs, acq_values) - return acq_values, { + aux = { 'mean': mean, 'stddev': stddev, 'stddev_from_all': jnp.sqrt(jnp.diagonal(cov, axis1=1, axis2=2)), } + if self.prior_acquisition is not None: + prior_acq_values = self.prior_acquisition(xs) + acq_values += prior_acq_values + aux['prior_acq_values'] = prior_acq_values + return acq_values, aux def default_ard_optimizer() -> optimizers.Optimizer[types.ParameterDict]: @@ -587,6 +618,14 @@ class method that takes `ModelInput` and returns a observed. rng: If not set, uses random numbers. clear_jax_cache: If True, every `suggest` call clears the Jax cache. + padding_schedule: Configures what inputs (trials, features, labels) to pad + with what schedule. Useful for reducing JIT compilation passes. (Default + implies no padding.) + prior_acquisition: An optional prior acquisition function. If provided, the + suggestions will be generated by maximizing the sum of the prior + acquisition value and the GP-based acquisition value (UCB or PE). Useful + for biasing the suggestions towards a prior, e.g., being close to some + known parameter values. """ _problem: vz.ProblemStatement = attr.field(kw_only=False) @@ -621,12 +660,13 @@ class method that takes `ModelInput` and returns a factory=lambda: jax.random.PRNGKey(random.getrandbits(32)), kw_only=True ) _clear_jax_cache: bool = attr.field(default=False, kw_only=True) - # Whether to pad all inputs, and what type of schedule to use. This is to - # ensure fewer JIT compilation passes. (Default implies no padding.) # TODO: Check padding does not affect designer behavior. _padding_schedule: padding.PaddingSchedule = attr.field( factory=padding.PaddingSchedule, kw_only=True ) + _prior_acquisition: Callable[[types.ModelInput], jax.Array] | None = ( + attr.field(factory=lambda: None, kw_only=True) + ) default_eagle_config = es.EagleStrategyConfig( visibility=3.6782451729470043, @@ -1003,6 +1043,7 @@ def _suggest_one( predictive_all_features, ucb_coefficient=self._config.ucb_coefficient, trust_region=tr if self._use_trust_region else None, + prior_acquisition=self._prior_acquisition, scalarization_weights_rng=scalarization_weights_rng, labels=data.labels, ) @@ -1014,6 +1055,7 @@ def _suggest_one( ucb_coefficient=self._config.ucb_coefficient, explore_ucb_coefficient=self._config.explore_region_ucb_coefficient, trust_region=tr if self._use_trust_region else None, + prior_acquisition=self._prior_acquisition, multimetric_promising_region_penalty_type=( self._config.multimetric_promising_region_penalty_type ), @@ -1083,6 +1125,11 @@ def _suggest_one( 'trust_radius': f'{tr.trust_radius}', 'params': f'{model.params}', }) + if 'prior_acq_values' in aux: + # Take the first element of the array because `aux` is computed only for + # the best candidate. + prior_acq_value = aux['prior_acq_values'][0] + metadata.ns('prior_acquisition').update({'value': f'{prior_acq_value}'}) metadata.ns('timing').update( {'time': f'{datetime.datetime.now() - start_time}'} ) @@ -1118,6 +1165,7 @@ def _suggest_batch_with_exploration( ucb_coefficient=self._config.ucb_coefficient, explore_ucb_coefficient=self._config.explore_region_ucb_coefficient, trust_region=tr if self._use_trust_region else None, + prior_acquisition=self._prior_acquisition, ) acquisition_optimizer = self._acquisition_optimizer_factory(self._converter) @@ -1180,6 +1228,11 @@ def _suggest_batch_with_exploration( 'trust_radius': f'{tr.trust_radius}', 'params': f'{model.params}', }) + if 'prior_acq_values' in aux: + # Take the first element of the array because `aux` is computed only for + # the best candidate. + prior_acq_value = aux['prior_acq_values'][0] + metadata.ns('prior_acquisition').update({'value': f'{prior_acq_value}'}) metadata.ns('timing').update({'time': f'{end_time - start_time}'}) suggestions.append( vz.TrialSuggestion( diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe_test.py b/vizier/_src/algorithms/designers/gp_ucb_pe_test.py index 08732161a..b79788d72 100644 --- a/vizier/_src/algorithms/designers/gp_ucb_pe_test.py +++ b/vizier/_src/algorithms/designers/gp_ucb_pe_test.py @@ -28,6 +28,7 @@ from vizier._src.algorithms.designers import quasi_random from vizier._src.algorithms.optimizers import eagle_strategy as es from vizier._src.algorithms.optimizers import vectorized_base as vb +from vizier._src.jax import types from vizier._src.jax.models import multitask_tuned_gp_models from vizier.jax import optimizers from vizier.pyvizier.converters import padding @@ -450,6 +451,139 @@ def test_ucb_overwrite(self): ) self.assertTrue(use_ucb) + @parameterized.parameters( + dict(optimize_set_acquisition_for_exploration=False), + dict(optimize_set_acquisition_for_exploration=True), + ) + def test_prior_acquisition( + self, optimize_set_acquisition_for_exploration: bool + ): + problem = vz.ProblemStatement( + test_studies.flat_continuous_space_with_scaling() + ) + problem.metric_information.append( + vz.MetricInformation( + name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE + ) + ) + vectorized_optimizer_factory = vb.VectorizedOptimizerFactory( + strategy_factory=es.VectorizedEagleStrategyFactory(), + max_evaluations=100, + ) + + def dummy_prior_acquisition(xs: types.ModelInput): + return np.ones(xs.continuous.shape[0]) * 12345.0 + + designer = gp_ucb_pe.VizierGPUCBPEBandit( + problem, + acquisition_optimizer_factory=vectorized_optimizer_factory, + metadata_ns='gp_ucb_pe_bandit_test', + num_seed_trials=1, + config=gp_ucb_pe.UCBPEConfig( + ucb_coefficient=10.0, + explore_region_ucb_coefficient=0.5, + cb_violation_penalty_coefficient=10.0, + ucb_overwrite_probability=0.0, + pe_overwrite_probability=0.0, + signal_to_noise_threshold=0.0, + optimize_set_acquisition_for_exploration=( + optimize_set_acquisition_for_exploration + ), + ), + padding_schedule=padding.PaddingSchedule( + num_trials=padding.PaddingType.MULTIPLES_OF_10 + ), + prior_acquisition=dummy_prior_acquisition, + rng=jax.random.PRNGKey(1), + ) + + trial_id = 1 + batch_size = 3 + iters = 2 + rng = jax.random.PRNGKey(1) + all_trials = [] + # Simulates a batch suggestion loop that completes a full batch of + # suggestions before asking for the next batch. + for _ in range(iters): + suggestions = designer.suggest(count=batch_size) + self.assertLen(suggestions, batch_size) + completed_trials = [] + for suggestion in suggestions: + problem.search_space.assert_contains(suggestion.parameters) + trial_id += 1 + measurement = vz.Measurement() + for mi in problem.metric_information: + measurement.metrics[mi.name] = float( + jax.random.uniform( + rng, + minval=mi.min_value_or(lambda: -10.0), + maxval=mi.max_value_or(lambda: 10.0), + ) + ) + rng, _ = jax.random.split(rng) + completed_trials.append( + suggestion.to_trial(trial_id).complete(measurement) + ) + all_trials.extend(completed_trials) + designer.update( + completed=abstractions.CompletedTrials(completed_trials), + all_active=abstractions.ActiveTrials(), + ) + + self.assertLen(all_trials, iters * batch_size) + + set_acq_value = None + stddev_from_all_list = [] + for idx, trial in enumerate(all_trials): + if idx < batch_size: + # Skips the first batch of suggestions, which are generated by the + # seeding designer, not acquisition function optimization. + continue + mean, stddev, stddev_from_all, acq, use_ucb = _extract_predictions( + trial.metadata.ns('gp_ucb_pe_bandit_test') + ) + prior_acq_value = float( + trial.metadata.ns('gp_ucb_pe_bandit_test') + .ns('prior_acquisition') + .get('value') + ) + self.assertEqual(prior_acq_value, 12345.0) + if idx % batch_size == 0: + # The first suggestion in a batch is expected to be generated by UCB, + # and the acquisition value is expected to be the sum of UCB and the + # prior acquisition value. + self.assertTrue(use_ucb) + self.assertAlmostEqual( + mean + 10.0 * stddev + prior_acq_value, + acq, + ) + else: + # Later suggestions in a batch are expected to be generated by PE, + # and the acquisition value is expected to be the sum of PE and the + # prior acquisition value. + self.assertFalse(use_ucb) + if optimize_set_acquisition_for_exploration: + # The acquisition value is expected to be the sum of the + # log-determinant of the predicted covariance matrix and the prior + # acquisition value (12345.0), so it should be greater than 10000.0. + self.assertGreater(acq, 10000.0) + stddev_from_all_list.append(stddev_from_all) + if set_acq_value is None: + set_acq_value = acq - prior_acq_value + else: + self.assertAlmostEqual(set_acq_value, acq - prior_acq_value) + else: + self.assertAlmostEqual(stddev_from_all + prior_acq_value, acq) + + if optimize_set_acquisition_for_exploration: + geometric_mean_of_pred_cov_eigs = np.exp(set_acq_value / (batch_size - 1)) + arithmetic_mean_of_pred_cov_eigs = np.mean( + np.square(stddev_from_all_list) + ) + self.assertLessEqual( + geometric_mean_of_pred_cov_eigs, arithmetic_mean_of_pred_cov_eigs + ) + if __name__ == '__main__': jax.config.update('jax_enable_x64', True)