Skip to content

Commit

Permalink
Supports a prior acquisition function in GP_UCB_PE
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726507984
  • Loading branch information
vizier-team authored and copybara-github committed Feb 13, 2025
1 parent 9e4ea96 commit f30840c
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 11 deletions.
75 changes: 64 additions & 11 deletions vizier/_src/algorithms/designers/gp_ucb_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,10 @@ class UCBScoreFunction(eqx.Module):
The UCB acquisition value is the sum of the predicted mean based on completed
trials and the predicted standard deviation based on all trials, completed and
pending (scaled by the UCB coefficient). This class follows the
`acquisitions.ScoreFunction` protocol.
pending (scaled by the UCB coefficient). If `prior_acquisition` is not None,
the return value is the sum of the prior acquisition value and the UCB
acquisition value. This class follows the `acquisitions.ScoreFunction`
protocol.
Attributes:
predictive: Predictive model with cached Cholesky conditioned on completed
Expand All @@ -288,6 +290,7 @@ class UCBScoreFunction(eqx.Module):
on completed and pending trials.
ucb_coefficient: The UCB coefficient.
trust_region: Trust region.
prior_acquisition: An optional prior acquisition function.
scalarization_weights_rng: Random key for scalarization.
labels: Labels, shaped as [num_index_points, num_metrics].
num_scalarizations: Number of scalarizations.
Expand All @@ -297,6 +300,7 @@ class UCBScoreFunction(eqx.Module):
predictive_all_features: sp.UniformEnsemblePredictive
ucb_coefficient: jt.Float[jt.Array, '']
trust_region: Optional[acquisitions.TrustRegion]
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None
labels: types.PaddedArray
scalarizer: scalarization.Scalarization

Expand All @@ -306,6 +310,7 @@ def __init__(
predictive_all_features: sp.UniformEnsemblePredictive,
ucb_coefficient: jt.Float[jt.Array, ''],
trust_region: Optional[acquisitions.TrustRegion],
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None,
scalarization_weights_rng: jax.Array,
labels: types.PaddedArray,
num_scalarizations: int = 1000,
Expand All @@ -314,6 +319,7 @@ def __init__(
self.predictive_all_features = predictive_all_features
self.ucb_coefficient = ucb_coefficient
self.trust_region = trust_region
self.prior_acquisition = prior_acquisition
self.labels = labels
self.scalarizer = acquisitions.create_hv_scalarization(
num_scalarizations, labels, scalarization_weights_rng
Expand Down Expand Up @@ -357,11 +363,16 @@ def score_with_aux(
scalarized_acq_values = _apply_trust_region(
self.trust_region, xs, scalarized_acq_values
)
return scalarized_acq_values, {
aux = {
'mean': mean,
'stddev': gprm.stddev(),
'stddev_from_all': stddev_from_all,
}
if self.prior_acquisition is not None:
prior_acq_values = self.prior_acquisition(xs)
scalarized_acq_values = prior_acq_values + scalarized_acq_values
aux['prior_acq_values'] = prior_acq_values
return scalarized_acq_values, aux


class PEScoreFunction(eqx.Module):
Expand All @@ -370,8 +381,10 @@ class PEScoreFunction(eqx.Module):
The PE acquisition value is the predicted standard deviation (eq. (9)
in https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
plus a penalty term that grows linearly in the amount of violation of the
constraint `UCB(xs) >= threshold`. This class follows the
`acquisitions.ScoreFunction` protocol.
constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the
returned value is the sum of the prior acquisition value and the PE
acquisition value. This class follows the `acquisitions.ScoreFunction`
protocol.
Attributes:
predictive: Predictive model with cached Cholesky conditioned on completed
Expand All @@ -383,6 +396,9 @@ class PEScoreFunction(eqx.Module):
values on `xs`.
penalty_coefficient: Multiplier on the constraint violation penalty.
trust_region: Optional trust region applied to the acquisition values.
prior_acquisition: An optional prior acquisition function.
multimetric_promising_region_penalty_type: The type of multimetric promising
region penalty.
Returns:
The Pure-Exploration acquisition value.
Expand All @@ -394,6 +410,7 @@ class PEScoreFunction(eqx.Module):
explore_ucb_coefficient: jt.Float[jt.Array, '']
penalty_coefficient: jt.Float[jt.Array, '']
trust_region: Optional[acquisitions.TrustRegion]
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None
multimetric_promising_region_penalty_type: (
MultimetricPromisingRegionPenaltyType
)
Expand Down Expand Up @@ -457,11 +474,16 @@ def score_with_aux(
acq_values = stddev_from_all + penalty
if self.trust_region is not None:
acq_values = _apply_trust_region(self.trust_region, xs, acq_values)
return acq_values, {
aux = {
'mean': mean,
'stddev': stddev,
'stddev_from_all': stddev_from_all,
}
if self.prior_acquisition is not None:
prior_acq_values = self.prior_acquisition(xs)
acq_values += prior_acq_values
aux['prior_acq_values'] = prior_acq_values
return acq_values, aux


def _logdet(matrix: jax.Array):
Expand All @@ -486,8 +508,10 @@ class SetPEScoreFunction(eqx.Module):
predicted covariance matrix evaluated at the points (eq. (8) in
https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
plus a penalty term that grows linearly in the amount of violation of the
constraint `UCB(xs) >= threshold`. This class follows the
`acquisitions.ScoreFunction` protocol.
constraint `UCB(xs) >= threshold`. If `prior_acquisition` is not None, the
returned value is the sum of the prior acquisition value and the PE
acquisition value. This class follows the `acquisitions.ScoreFunction`
protocol.
Attributes:
predictive: Predictive model with cached Cholesky conditioned on completed
Expand All @@ -499,6 +523,7 @@ class SetPEScoreFunction(eqx.Module):
values on `xs`.
penalty_coefficient: Multiplier on the constraint violation penalty.
trust_region: Optional trust region applied to the acquisition values.
prior_acquisition: An optional prior acquisition function.
Returns:
The Pure-Exploration acquisition value.
Expand All @@ -510,6 +535,7 @@ class SetPEScoreFunction(eqx.Module):
explore_ucb_coefficient: jt.Float[jt.Array, '']
penalty_coefficient: jt.Float[jt.Array, '']
trust_region: Optional[acquisitions.TrustRegion]
prior_acquisition: Callable[[types.ModelInput], jax.Array] | None

def score(
self, xs: types.ModelInput, seed: Optional[jax.Array] = None
Expand Down Expand Up @@ -549,11 +575,16 @@ def score_with_aux(
)
if self.trust_region is not None:
acq_values = _apply_trust_region_to_set(self.trust_region, xs, acq_values)
return acq_values, {
aux = {
'mean': mean,
'stddev': stddev,
'stddev_from_all': jnp.sqrt(jnp.diagonal(cov, axis1=1, axis2=2)),
}
if self.prior_acquisition is not None:
prior_acq_values = self.prior_acquisition(xs)
acq_values += prior_acq_values
aux['prior_acq_values'] = prior_acq_values
return acq_values, aux


def default_ard_optimizer() -> optimizers.Optimizer[types.ParameterDict]:
Expand Down Expand Up @@ -587,6 +618,14 @@ class method that takes `ModelInput` and returns a
observed.
rng: If not set, uses random numbers.
clear_jax_cache: If True, every `suggest` call clears the Jax cache.
padding_schedule: Configures what inputs (trials, features, labels) to pad
with what schedule. Useful for reducing JIT compilation passes. (Default
implies no padding.)
prior_acquisition: An optional prior acquisition function. If provided, the
suggestions will be generated by maximizing the sum of the prior
acquisition value and the GP-based acquisition value (UCB or PE). Useful
for biasing the suggestions towards a prior, e.g., being close to some
known parameter values.
"""

_problem: vz.ProblemStatement = attr.field(kw_only=False)
Expand Down Expand Up @@ -621,12 +660,13 @@ class method that takes `ModelInput` and returns a
factory=lambda: jax.random.PRNGKey(random.getrandbits(32)), kw_only=True
)
_clear_jax_cache: bool = attr.field(default=False, kw_only=True)
# Whether to pad all inputs, and what type of schedule to use. This is to
# ensure fewer JIT compilation passes. (Default implies no padding.)
# TODO: Check padding does not affect designer behavior.
_padding_schedule: padding.PaddingSchedule = attr.field(
factory=padding.PaddingSchedule, kw_only=True
)
_prior_acquisition: Callable[[types.ModelInput], jax.Array] | None = (
attr.field(factory=lambda: None, kw_only=True)
)

default_eagle_config = es.EagleStrategyConfig(
visibility=3.6782451729470043,
Expand Down Expand Up @@ -1003,6 +1043,7 @@ def _suggest_one(
predictive_all_features,
ucb_coefficient=self._config.ucb_coefficient,
trust_region=tr if self._use_trust_region else None,
prior_acquisition=self._prior_acquisition,
scalarization_weights_rng=scalarization_weights_rng,
labels=data.labels,
)
Expand All @@ -1014,6 +1055,7 @@ def _suggest_one(
ucb_coefficient=self._config.ucb_coefficient,
explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
trust_region=tr if self._use_trust_region else None,
prior_acquisition=self._prior_acquisition,
multimetric_promising_region_penalty_type=(
self._config.multimetric_promising_region_penalty_type
),
Expand Down Expand Up @@ -1083,6 +1125,11 @@ def _suggest_one(
'trust_radius': f'{tr.trust_radius}',
'params': f'{model.params}',
})
if 'prior_acq_values' in aux:
# Take the first element of the array because `aux` is computed only for
# the best candidate.
prior_acq_value = aux['prior_acq_values'][0]
metadata.ns('prior_acquisition').update({'value': f'{prior_acq_value}'})
metadata.ns('timing').update(
{'time': f'{datetime.datetime.now() - start_time}'}
)
Expand Down Expand Up @@ -1118,6 +1165,7 @@ def _suggest_batch_with_exploration(
ucb_coefficient=self._config.ucb_coefficient,
explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
trust_region=tr if self._use_trust_region else None,
prior_acquisition=self._prior_acquisition,
)

acquisition_optimizer = self._acquisition_optimizer_factory(self._converter)
Expand Down Expand Up @@ -1180,6 +1228,11 @@ def _suggest_batch_with_exploration(
'trust_radius': f'{tr.trust_radius}',
'params': f'{model.params}',
})
if 'prior_acq_values' in aux:
# Take the first element of the array because `aux` is computed only for
# the best candidate.
prior_acq_value = aux['prior_acq_values'][0]
metadata.ns('prior_acquisition').update({'value': f'{prior_acq_value}'})
metadata.ns('timing').update({'time': f'{end_time - start_time}'})
suggestions.append(
vz.TrialSuggestion(
Expand Down
134 changes: 134 additions & 0 deletions vizier/_src/algorithms/designers/gp_ucb_pe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vizier._src.algorithms.designers import quasi_random
from vizier._src.algorithms.optimizers import eagle_strategy as es
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier._src.jax import types
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier.jax import optimizers
from vizier.pyvizier.converters import padding
Expand Down Expand Up @@ -450,6 +451,139 @@ def test_ucb_overwrite(self):
)
self.assertTrue(use_ucb)

@parameterized.parameters(
    dict(optimize_set_acquisition_for_exploration=False),
    dict(optimize_set_acquisition_for_exploration=True),
)
def test_prior_acquisition(
    self, optimize_set_acquisition_for_exploration: bool
):
  """Checks that a prior acquisition is added to the UCB/PE acquisition.

  A constant prior acquisition function is handed to the designer. Two
  batches are suggested and completed; for every trial of the second batch
  the metadata must record the prior value, and the reported acquisition
  value must equal the UCB or PE value plus that prior.

  Args:
    optimize_set_acquisition_for_exploration: Forwarded to `UCBPEConfig`;
      selects which assertions are made for the exploration suggestions.
  """
  prior_value = 12345.0
  ucb_coefficient = 10.0
  batch_size = 3
  iters = 2
  metadata_ns = 'gp_ucb_pe_bandit_test'

  problem = vz.ProblemStatement(
      test_studies.flat_continuous_space_with_scaling()
  )
  objective = vz.MetricInformation(
      name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE
  )
  problem.metric_information.append(objective)

  acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
      strategy_factory=es.VectorizedEagleStrategyFactory(),
      max_evaluations=100,
  )

  def constant_prior(xs: types.ModelInput):
    # One constant prior value per entry along the leading axis of `xs`.
    return np.ones(xs.continuous.shape[0]) * prior_value

  config = gp_ucb_pe.UCBPEConfig(
      ucb_coefficient=ucb_coefficient,
      explore_region_ucb_coefficient=0.5,
      cb_violation_penalty_coefficient=10.0,
      ucb_overwrite_probability=0.0,
      pe_overwrite_probability=0.0,
      signal_to_noise_threshold=0.0,
      optimize_set_acquisition_for_exploration=(
          optimize_set_acquisition_for_exploration
      ),
  )
  designer = gp_ucb_pe.VizierGPUCBPEBandit(
      problem,
      acquisition_optimizer_factory=acquisition_optimizer_factory,
      metadata_ns=metadata_ns,
      num_seed_trials=1,
      config=config,
      padding_schedule=padding.PaddingSchedule(
          num_trials=padding.PaddingType.MULTIPLES_OF_10
      ),
      prior_acquisition=constant_prior,
      rng=jax.random.PRNGKey(1),
  )

  next_trial_id = 1
  rng = jax.random.PRNGKey(1)
  all_trials = []
  # Batch loop: each batch of suggestions is completed in full before the
  # next batch is requested.
  for _ in range(iters):
    suggestions = designer.suggest(count=batch_size)
    self.assertLen(suggestions, batch_size)
    batch_trials = []
    for suggestion in suggestions:
      problem.search_space.assert_contains(suggestion.parameters)
      next_trial_id += 1
      measurement = vz.Measurement()
      for mi in problem.metric_information:
        sampled = jax.random.uniform(
            rng,
            minval=mi.min_value_or(lambda: -10.0),
            maxval=mi.max_value_or(lambda: 10.0),
        )
        measurement.metrics[mi.name] = float(sampled)
        # Advance the key so the next draw differs.
        rng = jax.random.split(rng)[0]
      batch_trials.append(
          suggestion.to_trial(next_trial_id).complete(measurement)
      )
    all_trials.extend(batch_trials)
    designer.update(
        completed=abstractions.CompletedTrials(batch_trials),
        all_active=abstractions.ActiveTrials(),
    )

  self.assertLen(all_trials, iters * batch_size)

  set_acq_value = None
  stddev_from_all_list = []
  # The first batch comes from the seeding designer rather than acquisition
  # optimization, so only later trials are inspected.
  for idx, trial in enumerate(all_trials[batch_size:], start=batch_size):
    trial_metadata = trial.metadata.ns(metadata_ns)
    mean, stddev, stddev_from_all, acq, use_ucb = _extract_predictions(
        trial_metadata
    )
    prior_acq_value = float(
        trial_metadata.ns('prior_acquisition').get('value')
    )
    self.assertEqual(prior_acq_value, prior_value)

    if idx % batch_size == 0:
      # First suggestion of a batch: UCB, so the acquisition value is the UCB
      # value plus the prior.
      self.assertTrue(use_ucb)
      self.assertAlmostEqual(
          mean + ucb_coefficient * stddev + prior_acq_value,
          acq,
      )
      continue

    # Remaining suggestions of a batch: PE plus the prior.
    self.assertFalse(use_ucb)
    if not optimize_set_acquisition_for_exploration:
      self.assertAlmostEqual(stddev_from_all + prior_acq_value, acq)
      continue

    # Set acquisition: log-determinant of the predicted covariance matrix plus
    # the prior (12345.0), hence it must exceed 10000.0. All members of the
    # set are expected to report the same value.
    self.assertGreater(acq, 10000.0)
    stddev_from_all_list.append(stddev_from_all)
    if set_acq_value is None:
      set_acq_value = acq - prior_acq_value
    else:
      self.assertAlmostEqual(set_acq_value, acq - prior_acq_value)

  if optimize_set_acquisition_for_exploration:
    # Geometric mean (from the log-determinant) must not exceed the
    # arithmetic mean (from the squared stddevs) of the covariance spectrum.
    geometric_mean_of_pred_cov_eigs = np.exp(
        set_acq_value / (batch_size - 1)
    )
    arithmetic_mean_of_pred_cov_eigs = np.mean(
        np.square(stddev_from_all_list)
    )
    self.assertLessEqual(
        geometric_mean_of_pred_cov_eigs, arithmetic_mean_of_pred_cov_eigs
    )


if __name__ == '__main__':
jax.config.update('jax_enable_x64', True)
Expand Down

0 comments on commit f30840c

Please sign in to comment.