[Bug] get_fantasy_likelihood method broken for DirichletClassificationLikelihood #2579

Open
SaiAakash opened this issue Sep 7, 2024 · 0 comments

🐛 Bug

Conditioning a multi-class classification model that uses DirichletClassificationLikelihood on new observations with ExactGP.get_fantasy_model raises a RuntimeError.

To reproduce

Code snippet to reproduce

import torch
import gpytorch

from gpytorch.models import ExactGP
from gpytorch.likelihoods import DirichletClassificationLikelihood
from gpytorch.means import ConstantMean
from gpytorch.kernels import ScaleKernel, RBFKernel


def gen_data(num_data, seed=2019):
    torch.random.manual_seed(seed)

    x = torch.randn(num_data, 1)
    y = torch.randn(num_data, 1)

    u = torch.rand(1)
    data_fn = lambda x, y: 1 * torch.sin(0.15 * u * 3.1415 * (x + y)) + 1
    latent_fn = data_fn(x, y)
    z = torch.round(latent_fn).long().squeeze()
    return torch.cat((x, y), dim=1), z, data_fn


train_x, train_y, genfn = gen_data(500)


# We will use the simplest form of GP model, exact inference
class DirichletGPModel(ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_classes):
        super(DirichletGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean(batch_shape=torch.Size((num_classes,)))
        self.covar_module = ScaleKernel(
            RBFKernel(batch_shape=torch.Size((num_classes,))),
            batch_shape=torch.Size((num_classes,)),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


# initialize likelihood and model
# we let the DirichletClassificationLikelihood compute the targets for us
likelihood = DirichletClassificationLikelihood(train_y, learn_additional_noise=True)
model = DirichletGPModel(
    train_x,
    likelihood.transformed_targets,
    likelihood,
    num_classes=likelihood.num_classes,
)

# Training loop
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the Adam optimizer
optimizer = torch.optim.Adam(
    model.parameters(), lr=0.1
)  # model.parameters() also includes the likelihood's parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(50):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, likelihood.transformed_targets).sum()
    loss.backward()
    if i % 5 == 0:
        print(
            "Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f"
            % (
                i + 1,
                50,
                loss.item(),
                model.covar_module.base_kernel.lengthscale.mean().item(),
                model.likelihood.second_noise_covar.noise.mean().item(),
            )
        )
    optimizer.step()


model.eval()
likelihood.eval()

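# Predict once in eval mode: get_fantasy_model relies on the prediction
# caches that are built by this call
with gpytorch.settings.fast_pred_var(), torch.no_grad():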
    test_dist = model(train_x)

    pred_means = test_dist.loc

# Fantasize on new observations
new_xy, new_z, genfn = gen_data(20, seed=2000)
_, new_z, num_classes = likelihood._prepare_targets(new_z.unsqueeze(0))
updated_model = model.get_fantasy_model(new_xy, new_z.T)

Stack trace/error message

RuntimeError                              Traceback (most recent call last)
Cell In[39], line 99
     97 new_xy, new_z, genfn = gen_data(20, seed=2000)
     98 _, new_z, num_classes = likelihood._prepare_targets(new_z.unsqueeze(0))
---> 99 updated_model = model.get_fantasy_model(new_xy, new_z.T)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/models/exact_gp.py:238, in ExactGP.get_fantasy_model(self, inputs, targets, **kwargs)
    235 self.train_targets = old_train_targets
    236 self.likelihood = old_likelihood
--> 238 new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
    239 new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
    240     inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
    241 )
    243 # if the fantasies are at the same points, we need to expand the inputs for the new model

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/likelihoods/gaussian_likelihood.py:439, in DirichletClassificationLikelihood.get_fantasy_likelihood(self, **kwargs)
    435 def get_fantasy_likelihood(self, **kwargs: Any) -> "DirichletClassificationLikelihood":
    436     # we assume that the number of classes does not change.
    438     if "targets" not in kwargs:
--> 439         raise RuntimeError("FixedNoiseGaussianLikelihood.fantasize requires a `targets` kwarg")
    441     old_noise_covar = self.noise_covar
    442     self.noise_covar = None  # pyre-fixme[8]

RuntimeError: FixedNoiseGaussianLikelihood.fantasize requires a `targets` kwarg

Expected Behavior

get_fantasy_model should return the updated model with the fantasized likelihood.

System information

  • GPyTorch version: 1.12
  • PyTorch version: 2.4.0
  • OS: macOS Sonoma 14.5

Additional context

I can see that a RuntimeError is raised in DirichletClassificationLikelihood.get_fantasy_likelihood when targets is missing from kwargs. However, I can't see targets being used anywhere else in that method. It is also not possible to pass a kwarg called targets: get_fantasy_likelihood is called inside ExactGP.get_fantasy_model, which already takes a separate targets argument, and a Python function cannot take two arguments with the same name. A targets keyword given to get_fantasy_model therefore binds to that parameter and never reaches the keyword arguments forwarded to get_fantasy_likelihood.
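The name clash can be reproduced without GPyTorch. The sketch below uses two hypothetical stand-in functions, get_fantasy_model and get_fantasy_likelihood, that only mimic the call chain described above (a named targets parameter followed by **kwargs, and a callee that demands a targets kwarg). They are assumptions for illustration, not GPyTorch's actual implementation.

```python
from typing import Any, Dict


def get_fantasy_likelihood(**kwargs: Any) -> str:
    # Hypothetical stand-in for the likelihood method: like the check in the
    # traceback, it refuses to run unless a `targets` keyword is supplied.
    if "targets" not in kwargs:
        raise RuntimeError("get_fantasy_likelihood requires a `targets` kwarg")
    return "fantasy likelihood"


def get_fantasy_model(inputs: Any, targets: Any, **kwargs: Any) -> Dict[str, Any]:
    # Hypothetical stand-in for the model method. Because `targets` is a named
    # parameter, a `targets=` keyword binds to it, so **kwargs never holds a
    # "targets" key and the check in the callee can never pass.
    return {
        "inputs": inputs,
        "targets": targets,
        "likelihood": get_fantasy_likelihood(**kwargs),
    }


def try_call(*args: Any, **kwargs: Any) -> str:
    # Report which exception (if any) the call raises.
    try:
        get_fantasy_model(*args, **kwargs)
    except (RuntimeError, TypeError) as err:
        return f"{type(err).__name__}: {err}"
    return "ok"


if __name__ == "__main__":
    # Targets passed positionally: the callee's check fails.
    print(try_call([0.0], [1]))
    # Targets passed by keyword: still bound to the named parameter.
    print(try_call([0.0], targets=[1]))
    # Targets passed twice: Python rejects the call before the body runs.
    print(try_call([0.0], [1], targets=[1]))
```

The first two calls fail with the RuntimeError from the stand-in check, and the third fails with a TypeError about multiple values for targets, so no way of supplying targets gets past the check through this call path.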

SaiAakash added the bug label Sep 7, 2024