[Bug] KeOps RBF kernel not properly equipped with prediction variance? #2566

Open
matthieudelsart opened this issue Aug 13, 2024 · 0 comments

matthieudelsart commented Aug 13, 2024

🐛 Bug

Unlike the standard RBF kernel, keops.RBFKernel raises a RuntimeError when the predicted variance is requested. The failure seems similar to this one.
The same error occurs when requesting the standard deviation, confidence intervals, etc.

To reproduce

Code snippet to reproduce

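import torch
import gpytorch

# train_x and train_y are assumed to be CUDA tensors defined beforehand.

class ExactGPModel(gpytorch.models.ExactGP):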
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.LinearMean(input_size=train_x.size(-1))
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.RBFKernel(ard_num_dims=train_x.size(-1)))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()  # I am actually using FixedNoiseGaussianLikelihood, but the same issue seems to occur regardless of the likelihood
model = ExactGPModel(train_x, train_y, likelihood).cuda()

model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1) 
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)

training_iter = 10

for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f - Lengthscale_0: %.3f - Outputscale: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale[0, 0].item(),
        model.covar_module.outputscale.item(),
    ))  
    optimizer.step()

model.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    observed_pred = model.likelihood(model(train_x))
    pred_mean = observed_pred.mean
    pred_variance = observed_pred.variance

Stack trace/error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 45
     43 observed_pred = model.likelihood(model(train_x))
     44 pred_mean = observed_pred.mean
---> 45 pred_variance = observed_pred.variance

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:309, in MultivariateNormal.variance(self)
    305 @property
    306 def variance(self) -> Tensor:
    307     if self.islazy:
    308         # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
--> 309         diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
    310         diag = diag.view(diag.shape[:-1] + self._event_shape)
    311         variance = diag.expand(self._batch_shape + self._event_shape)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1411, in LinearOperator.diagonal(self, offset, dim1, dim2)
   1409 elif not self.is_square:
   1410     raise RuntimeError("LinearOperator#diagonal is only implemented for square operators.")
-> 1411 return self._diagonal()

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in SumLinearOperator._diagonal(self)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/linear_operator/operators/sum_linear_operator.py:29, in <genexpr>(.0)
     28 def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
---> 29     return sum(linear_op._diagonal().contiguous() for linear_op in self.linear_ops)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:25, in recall_grad_state.<locals>.wrapped(self, *args, **kwargs)
     22 @functools.wraps(method)
     23 def wrapped(self, *args, **kwargs):
     24     with torch.set_grad_enabled(self._is_grad_enabled):
---> 25         output = method(self, *args, **kwargs)
     26     return output

File /opt/conda/envs/pytorch/lib/python3.11/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:126, in LazyEvaluatedKernelTensor._diagonal(self)
    124     expected_shape = self.shape[:-1]
    125     if res.shape != expected_shape:
--> 126         raise RuntimeError(
    127             "The kernel {} is not equipped to handle and diag. Expected size {}. "
    128             "Got size {}".format(self.kernel.__class__.__name__, expected_shape, res.shape)
    129         )
    131 if isinstance(res, LinearOperator):
    132     res = res.to_dense()

RuntimeError: The kernel ScaleKernel is not equipped to handle and diag. Expected size torch.Size([45159]). Got size torch.Size([45159, 45159])
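
The last frame of the traceback suggests a simple shape contract: when the diagonal is requested, the kernel is expected to return one value per point (size [n]), and here it returned the full matrix (size [n, n]). The following is a toy, pure-Python sketch of that contract only. It is not GPyTorch or KeOps code, and every name in it (rbf, toy_kernel, broken_kernel, shape, checked_diagonal) is hypothetical and introduced purely for illustration.

```python
import math


def rbf(x1, x2, lengthscale=1.0):
    """Squared-exponential kernel value for two scalar inputs."""
    return math.exp(-0.5 * ((x1 - x2) / lengthscale) ** 2)


def toy_kernel(xs, diag=False):
    """Hypothetical kernel: n values when diag=True, otherwise an n x n matrix."""
    if diag:
        return [rbf(x, x) for x in xs]
    return [[rbf(a, b) for b in xs] for a in xs]


def broken_kernel(xs, diag=False):
    """Hypothetical kernel that ignores diag and always returns the full matrix."""
    return [[rbf(a, b) for b in xs] for a in xs]


def shape(value):
    """Shape of a flat list or of a list of lists, as a tuple."""
    if value and isinstance(value[0], list):
        return (len(value), len(value[0]))
    return (len(value),)


def checked_diagonal(kernel, xs):
    """Mimic the check in the traceback: the diag result must have shape (n,)."""
    res = kernel(xs, diag=True)
    expected_shape = (len(xs),)
    if shape(res) != expected_shape:
        raise RuntimeError(
            f"The kernel {kernel.__name__} is not equipped to handle diag. "
            f"Expected size {expected_shape}. Got size {shape(res)}"
        )
    return res


xs = [0.0, 0.5, 2.0]
print(checked_diagonal(toy_kernel, xs))  # an RBF kernel has k(x, x) = 1 on the diagonal
try:
    checked_diagonal(broken_kernel, xs)
except RuntimeError as err:
    print(err)  # same kind of size mismatch as reported above
```

In this toy version, the well-behaved kernel passes the check and the one that ignores diag triggers a size-mismatch error of the same form as the one reported.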

Expected Behavior

pred_variance should be a tensor holding the predicted variance for each point, as it is when using the standard RBF kernel.
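
For reference, the quantity expected here is, in textbook exact GP regression with Gaussian noise, the diagonal of the posterior covariance plus the noise variance. This identity is standard and is not taken from the issue itself; the symbols ($k$ for the kernel, $K$ for the training kernel matrix, $k_*$ for the cross-covariances with the training points, $\sigma_n^2$ for the noise variance) are notation assumed for this note.

```latex
\sigma^2(x_*) \;=\; k(x_*, x_*) \;-\; k_*^\top \left(K + \sigma_n^2 I\right)^{-1} k_* \;+\; \sigma_n^2
```

The first term only needs the kernel diagonal, which is presumably why the failing call in the traceback asks the kernel for a diagonal of size [45159] rather than the full matrix.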

System information

  • GPyTorch Version: 1.12
  • PyTorch Version: 2.3.0
  • Computer info:
    • MacBook Pro M2, macOS Sonoma 14.4.1
    • Service: Connected to Amazon EC2
    • Instance: g5.xlarge, Amazon Linux 2
matthieudelsart changed the title from "[Bug] Keops RBF kernel not equipped with covariance matrix?" to "[Bug] KeOps RBF kernel not properly equipped with prediction variance?" on Aug 13, 2024