Supports the Separable kernel in VizierGaussianProcess for multimetric problems.

PiperOrigin-RevId: 715111661
vizier-team authored and copybara-github committed Jan 13, 2025
1 parent e1857f8 commit eaff660
Showing 3 changed files with 219 additions and 117 deletions.
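
The new `multitask_type` flag is threaded from `build_model` down to the kernel construction. A minimal usage sketch, not part of this commit: the call pattern and `data` (assumed to be a `types.ModelData` whose `labels` have shape `[num_obs, num_metrics]`) are assumptions for illustration.

```python
# Hypothetical usage sketch; `data` and its construction are assumptions,
# not shown in this commit.
from vizier._src.jax.models import multitask_tuned_gp_models
from vizier._src.jax.models import tuned_gp_models

MultiTaskType = multitask_tuned_gp_models.MultiTaskType

# Builds a multi-metric GP whose task kernel uses an LKJ prior over the
# cross-metric correlation matrix; the other enum values in this commit are
# INDEPENDENT (default), SEPARABLE_DIAG_TASK_KERNEL_PRIOR, and
# SEPARABLE_NORMAL_TASK_KERNEL_PRIOR.
model = tuned_gp_models.VizierGaussianProcess.build_model(
    data,
    multitask_type=MultiTaskType.SEPARABLE_LKJ_TASK_KERNEL_PRIOR,
)
```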
227 changes: 117 additions & 110 deletions vizier/_src/jax/models/multitask_tuned_gp_models.py
@@ -59,6 +59,120 @@ class MultiTaskType(enum.Enum):
SEPARABLE_DIAG_TASK_KERNEL_PRIOR = 'separable_diag_task_kernel_prior'


def build_task_kernel_scale_linop(
num_tasks: int,
multitask_type: MultiTaskType,
) -> Generator[sp.ModelParameter, jax.Array, tfp.tf2jax.linalg.LinearOperator]:
"""Builds a Separable MultiTask GP's task kernel scale LinearOperator.
Args:
num_tasks: The number of tasks.
multitask_type: The type of MultiTask GP.
Yields:
Model parameters for the task kernel scale and a LinearOperator representing
the task kernel scale.
"""
if multitask_type == MultiTaskType.SEPARABLE_DIAG_TASK_KERNEL_PRIOR:
correlation_diag = yield sp.ModelParameter.from_prior(
tfd.Sample(
tfd.Uniform(low=jnp.float64(1e-6), high=1.0),
sample_shape=num_tasks,
name='correlation_diag',
),
constraint=sp.Constraint(
bounds=(1e-6, 1.0),
bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
),
)
task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorDiag(
correlation_diag
)
elif multitask_type == MultiTaskType.SEPARABLE_LKJ_TASK_KERNEL_PRIOR:
# Generate parameters for the Cholesky of the task kernel matrix,
# which accounts for correlations between tasks.
num_task_kernel_entries = tfb.CorrelationCholesky().inverse_event_shape(
[num_tasks, num_tasks]
)
correlation_cholesky_vec = yield sp.ModelParameter(
init_fn=lambda key: tfd.Sample( # pylint: disable=g-long-lambda
tfd.Normal(jnp.float64(0.0), 1.0), num_task_kernel_entries
).sample(seed=key),
# Use `jnp.copy` to prevent tracers leaking from bijector cache.
regularizer=lambda x: -tfd.CholeskyLKJ( # pylint: disable=g-long-lambda
dimension=num_tasks, concentration=1.0
).log_prob(tfb.CorrelationCholesky()(jnp.copy(x))),
name='task_kernel_correlation_cholesky_vec',
)

task_kernel_correlation_cholesky = tfb.CorrelationCholesky()(
jnp.copy(correlation_cholesky_vec)
)

task_kernel_scale_vec = yield sp.ModelParameter(
init_fn=functools.partial(
jax.random.uniform,
shape=(num_tasks,),
dtype=jnp.float64,
minval=1e-6,
maxval=1.0,
),
constraint=sp.Constraint(
bounds=(1e-6, 1.0),
bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
),
name='task_kernel_sqrt_diagonal',
)
task_kernel_cholesky = (
task_kernel_correlation_cholesky * task_kernel_scale_vec[:, jnp.newaxis]
)

# Build the `LinearOperator` object representing the task kernel matrix,
# to parameterize the Separable kernel.
task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
task_kernel_cholesky
)
elif multitask_type == MultiTaskType.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR:
# Generate parameters for the Cholesky of the task kernel matrix;
# accounts for correlations between tasks. The task kernel matrix must
# be positive definite, so we construct it via a Cholesky factor.
# Define the prior of the kernel task matrix to be centered at the
# identity.
prior_mean = jnp.eye(num_tasks, dtype=jnp.float64)
prior_mean_vec = tfb.FillTriangular().inverse(prior_mean)
prior_mean_batched = jnp.broadcast_to(prior_mean_vec, prior_mean_vec.shape)

task_kernel_cholesky_entries = yield sp.ModelParameter.from_prior(
tfd.Independent(
tfd.Normal(prior_mean_batched, 1.0),
reinterpreted_batch_ndims=1,
name='task_kernel_cholesky_entries',
)
)

# Apply a bijector to pack the task kernel entries into a lower
# triangular matrix and ensure the diagonal is positive.
task_kernel_bijector = tfb.Chain([
tfb.TransformDiagonal(
tfb.Chain([tfb.Shift(jnp.float64(1e-3)), tfb.Softplus()])
),
tfb.FillTriangular(),
])
task_kernel_cholesky = task_kernel_bijector(
jnp.copy(task_kernel_cholesky_entries)
)

# Build the `LinearOperator` object representing the task kernel
# matrix, to parameterize the Separable kernel.
task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
task_kernel_cholesky
)
else:
raise ValueError(f'Unsupported multitask type: {multitask_type}')

return task_kernel_scale_linop


@struct.dataclass
class VizierMultitaskGaussianProcess(
sp.ModelCoroutine[Union[tfd.GaussianProcess, tfde.MultiTaskGaussianProcess]]
@@ -101,115 +215,6 @@ def sample(key: Any) -> jnp.ndarray:

return sample

def _build_task_kernel_scale_linop(
self,
) -> Generator[
sp.ModelParameter, jax.Array, tfp.tf2jax.linalg.LinearOperator
]:
if self._multitask_type == MultiTaskType.SEPARABLE_DIAG_TASK_KERNEL_PRIOR:
correlation_diag = yield sp.ModelParameter.from_prior(
tfd.Sample(
tfd.Uniform(low=jnp.float64(1e-6), high=1.0),
sample_shape=self._num_tasks,
name='correlation_diag',
),
constraint=sp.Constraint(
bounds=(1e-6, 1.0),
bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
),
)
task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorDiag(
correlation_diag
)
elif self._multitask_type == MultiTaskType.SEPARABLE_LKJ_TASK_KERNEL_PRIOR:
# Generate parameters for the Cholesky of the task kernel matrix,
# which accounts for correlations between tasks.
num_task_kernel_entries = tfb.CorrelationCholesky().inverse_event_shape(
[self._num_tasks, self._num_tasks]
)
correlation_cholesky_vec = yield sp.ModelParameter(
init_fn=lambda key: tfd.Sample( # pylint: disable=g-long-lambda
tfd.Normal(jnp.float64(0.0), 1.0), num_task_kernel_entries
).sample(seed=key),
# Use `jnp.copy` to prevent tracers leaking from bijector cache.
regularizer=lambda x: -tfd.CholeskyLKJ( # pylint: disable=g-long-lambda
dimension=self._num_tasks, concentration=1.0
).log_prob(tfb.CorrelationCholesky()(jnp.copy(x))),
name='task_kernel_correlation_cholesky_vec',
)

task_kernel_correlation_cholesky = tfb.CorrelationCholesky()(
jnp.copy(correlation_cholesky_vec)
)

task_kernel_scale_vec = yield sp.ModelParameter(
init_fn=functools.partial(
jax.random.uniform,
shape=(self._num_tasks,),
dtype=jnp.float64,
minval=1e-6,
maxval=1.0,
),
constraint=sp.Constraint(
bounds=(1e-6, 1.0),
bijector=tfb.Sigmoid(low=jnp.float64(1e-6), high=1.0),
),
name='task_kernel_sqrt_diagonal',
)
task_kernel_cholesky = (
task_kernel_correlation_cholesky
* task_kernel_scale_vec[:, jnp.newaxis]
)

# Build the `LinearOperator` object representing the task kernel matrix,
# to parameterize the Separable kernel.
task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
task_kernel_cholesky
)
elif (
self._multitask_type == MultiTaskType.SEPARABLE_NORMAL_TASK_KERNEL_PRIOR
):
# Generate parameters for the Cholesky of the task kernel matrix;
# accounts for correlations between tasks. The task kernel matrix must
# be positive definite, so we construct it via a Cholesky factor.
# Define the prior of the kernel task matrix to be centered at the
# identity.
prior_mean = jnp.eye(self._num_tasks, dtype=jnp.float64)
prior_mean_vec = tfb.FillTriangular().inverse(prior_mean)
prior_mean_batched = jnp.broadcast_to(
prior_mean_vec, prior_mean_vec.shape
)

task_kernel_cholesky_entries = yield sp.ModelParameter.from_prior(
tfd.Independent(
tfd.Normal(prior_mean_batched, 1.0),
reinterpreted_batch_ndims=1,
name='task_kernel_cholesky_entries',
)
)

# Apply a bijector to pack the task kernel entries into a lower
# triangular matrix and ensure the diagonal is positive.
task_kernel_bijector = tfb.Chain([
tfb.TransformDiagonal(
tfb.Chain([tfb.Shift(jnp.float64(1e-6)), tfb.Softplus()])
),
tfb.FillTriangular(),
])
task_kernel_cholesky = task_kernel_bijector(
jnp.copy(task_kernel_cholesky_entries)
)

# Build the `LinearOperator` object representing the task kernel
# matrix, to parameterize the Separable kernel.
task_kernel_scale_linop = tfp.tf2jax.linalg.LinearOperatorLowerTriangular(
task_kernel_cholesky
)
else:
raise ValueError(f'Unsupported multitask type: {self._multitask_type}')

return task_kernel_scale_linop

def __call__(
self, inputs: Optional[types.ModelInput] = None
) -> Generator[
@@ -356,7 +361,9 @@ def __call__(
if self._multitask_type == MultiTaskType.INDEPENDENT:
multitask_kernel = tfpke.Independent(self._num_tasks, kernel)
else:
task_kernel_scale_linop = yield from self._build_task_kernel_scale_linop()
task_kernel_scale_linop = yield from build_task_kernel_scale_linop(
self._num_tasks, self._multitask_type
)
multitask_kernel = tfpke.Separable(
self._num_tasks,
base_kernel=kernel,
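
For readers of the diff above: every Separable branch ultimately yields a Cholesky factor L of the task matrix T = L @ L.T, so T is positive semi-definite by construction. A minimal numerical sketch of the LKJ branch's scaling step follows; the values are made up for illustration, and `corr_chol` stands in for the output of `tfb.CorrelationCholesky()`.

```python
# Illustrative sketch of the LKJ branch's parameterization (values made up).
# `corr_chol` has unit-norm rows, so corr_chol @ corr_chol.T is a valid
# correlation matrix, as tfb.CorrelationCholesky() guarantees.
import numpy as np

corr_chol = np.array([
    [1.0, 0.0, 0.0],
    [0.6, 0.8, 0.0],
    [0.3, 0.4, np.sqrt(1.0 - 0.3**2 - 0.4**2)],
])
scales = np.array([0.5, 0.9, 0.7])  # `task_kernel_sqrt_diagonal`, in (1e-6, 1].

# Row-wise scaling, matching `corr_chol * scales[:, jnp.newaxis]` in the diff.
task_chol = corr_chol * scales[:, None]
task_kernel = task_chol @ task_chol.T  # PSD task matrix T = diag(s) C diag(s).
```

The Separable kernel then models Cov[f_i(x), f_j(x')] = k(x, x') * T[i, j] across tasks i, j.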
27 changes: 26 additions & 1 deletion vizier/_src/jax/models/tuned_gp_models.py
@@ -30,6 +30,7 @@
from vizier._src.jax import types
from vizier._src.jax.models import continuous_only_kernel
from vizier._src.jax.models import mask_features
from vizier._src.jax.models import multitask_tuned_gp_models

tfb = tfp.bijectors
tfd = tfp.distributions
@@ -93,6 +94,10 @@ class VizierGaussianProcess(sp.ModelCoroutine[tfd.GaussianProcess]):
)
_boundary_epsilon: float = struct.field(default=1e-12, kw_only=True)
_linear_coef: Optional[float] = struct.field(default=None, kw_only=True)
_multitask_type: multitask_tuned_gp_models.MultiTaskType = struct.field(
default=multitask_tuned_gp_models.MultiTaskType.INDEPENDENT,
kw_only=True,
)

def __attrs_post_init__(self):
if self._num_metrics < 1:
@@ -107,6 +112,9 @@ def build_model(
*,
use_retrying_cholesky: bool = True,
linear_coef: Optional[float] = None,
multitask_type: multitask_tuned_gp_models.MultiTaskType = (
multitask_tuned_gp_models.MultiTaskType.INDEPENDENT
),
) -> sp.StochasticProcessModel:
"""Returns a StochasticProcessModel for the GP."""
gp_coroutine = VizierGaussianProcess(
@@ -117,6 +125,7 @@
_num_metrics=data.labels.shape[-1],
_use_retrying_cholesky=use_retrying_cholesky,
_linear_coef=linear_coef,
_multitask_type=multitask_type,
)
return sp.StochasticProcessModel(gp_coroutine)

@@ -271,8 +280,24 @@ def __call__(
cholesky_fn = lambda matrix: retrying_cholesky(matrix)[0]

if self._num_metrics > 1:
if (
self._multitask_type
== multitask_tuned_gp_models.MultiTaskType.INDEPENDENT
):
multitask_kernel = tfpke.Independent(self._num_metrics, kernel)
else:
task_kernel_scale_linop = (
yield from multitask_tuned_gp_models.build_task_kernel_scale_linop(
self._num_metrics, self._multitask_type
)
)
multitask_kernel = tfpke.Separable(
self._num_metrics,
base_kernel=kernel,
task_kernel_scale_linop=task_kernel_scale_linop,
)
return tfde.MultiTaskGaussianProcess(
tfpke.Independent(self._num_metrics, kernel),
multitask_kernel,
index_points=inputs,
observation_noise_variance=observation_noise_variance,
cholesky_fn=cholesky_fn,
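
The change above makes the multi-metric path of `VizierGaussianProcess` dispatch on `_multitask_type` rather than always wrapping the base kernel in `tfpke.Independent`. A sketch of the covariance structure each branch induces, with illustrative values that are not from this commit:

```python
# Independent: metrics share kernel hyperparameters but have zero cross-metric
# covariance. Separable: the base kernel is multiplied by a task matrix T.
import numpy as np

k_xx = 0.8                     # base kernel value k(x, x'); made up.
T = np.array([[1.0, 0.4],      # task matrix from the Cholesky factor built in
              [0.4, 0.9]])     # multitask_tuned_gp_models (values made up).

cov_independent = k_xx * np.eye(2)  # Cov[f_i(x), f_j(x')] = k(x, x') * 1[i==j]
cov_separable = k_xx * T            # Cov[f_i(x), f_j(x')] = k(x, x') * T[i, j]
```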