[Bug] Multitask Approximate GP compatibility with NNVariationalStrategy #2560

ancorso opened this issue Aug 7, 2024 · 0 comments

ancorso commented Aug 7, 2024

🐛 Bug

Multitask GPs (using LMCVariationalStrategy and MeanFieldVariationalDistribution) appear to be incompatible with NNVariationalStrategy. This makes it difficult to train multitask GPs with a very large number of inducing points.

To reproduce

I modified the "Variational GPs w/ Multiple Outputs" example to use NNVariationalStrategy as follows:

Code snippet to reproduce

import math
import torch
import gpytorch
import tqdm
from matplotlib import pyplot as plt

train_x = torch.linspace(0, 1, 100)

train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.sin(train_x * (2 * math.pi)) + 2 * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    -torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1)

print(train_x.shape, train_y.shape)

num_latents = 3
num_tasks = 4

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        
        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )
        
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.NNVariationalStrategy(self, inducing_points, variational_distribution, k=8, training_batch_size=16),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


model = MultitaskGPModel(train_x.unsqueeze(-1))  # training inputs double as inducing points; they need an explicit feature dimension for .size(-2)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks)

num_epochs = 150

model.train()
likelihood.train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.1)

# Our loss object. We're using the VariationalELBO, which essentially just computes the ELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))

# Training loop
epochs_iter = tqdm.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # With NNVariationalStrategy, the model is called with x=None during training;
    # the strategy draws its own minibatch of training points internally
    optimizer.zero_grad()
    output = model(None)
    loss = -mll(output, train_y)
    epochs_iter.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()

Stack trace/error message

Traceback (most recent call last):
  File "<stdin>", line 4, in <module>
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/models/approximate_gp.py", line 114, in __call__
    return self.variational_strategy(inputs, prior=prior, **kwargs)
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/lmc_variational_strategy.py", line 197, in __call__
    latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 145, in __call__
    return self.forward(
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 192, in forward
    kl = self._kl_divergence(kl_indices)
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 369, in _kl_divergence
    kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 313, in _stochastic_kl_helper
    cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors)
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/kernels/rbf_kernel.py", line 80, in forward
    return RBFCovariance.apply(
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/Users/anthonycorso/micromamba/envs/gpytorch/lib/python3.8/site-packages/gpytorch/functions/rbf_covariance.py", line 12, in forward
    x1_ = x1.div(lengthscale)
RuntimeError: The size of tensor a (16) must match the size of tensor b (3) at non-singleton dimension 1
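
My reading of the trace is that the nearest-neighbor minibatch of inducing points never picks up the latent batch dimension, so it cannot broadcast against the RBF lengthscale, which is batched over num_latents. The snippet below reproduces the same broadcast failure in isolation; the exact shapes are my guess at what reaches rbf_covariance.py, not something I verified inside gpytorch:

import torch

# Hypothetical shapes: a nearest-neighbor inducing batch of
# (1, training_batch_size, k, d) with no latent batch dimension ...
nearest_neighbors = torch.randn(1, 16, 8, 1)
# ... versus an RBF lengthscale batched over num_latents = 3
lengthscale = torch.rand(3, 1, 1)

nearest_neighbors.div(lengthscale)
# RuntimeError: The size of tensor a (16) must match the size of tensor b (3)
# at non-singleton dimension 1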

Expected Behavior

The forward call to NNVariationalStrategy should work in a multi-task setting.
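
For reference, the same multitask model trains without error when the nearest-neighbor strategy is swapped back for the plain VariationalStrategy with latent-batched inducing points; this is essentially the original tutorial's setup (the inducing-point count of 16 below is illustrative, not from my run):

import torch
import gpytorch

num_latents, num_tasks = 3, 4

class MultitaskGPModelBaseline(gpytorch.models.ApproximateGP):
    # Same model as above, but with the standard VariationalStrategy
    # in place of NNVariationalStrategy
    def __init__(self, inducing_points):  # inducing_points: (num_latents, M, d)
        variational_distribution = gpytorch.variational.MeanFieldVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents]),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

model = MultitaskGPModelBaseline(torch.rand(num_latents, 16, 1))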

System information

  • GPyTorch Version 1.12
  • PyTorch Version 2.3.1
  • macOS 14.5
ancorso added the bug label Aug 7, 2024