[Bug] Some priors do not respect selected device #2581

Open
slishak-PX opened this issue Sep 9, 2024 · 2 comments

slishak-PX commented Sep 9, 2024

🐛 Bug

When sampling from a prior that has been moved to the GPU, only some priors return samples on the correct device, even though the state_dict is updated correctly as of #2550. This issue seems related to that one, although as far as I can tell no regression was introduced:

from gpytorch import priors

for prior in (
    priors.NormalPrior(1.0, 1.0),
    priors.GammaPrior(1.0, 1.0),
    priors.HalfCauchyPrior(1.0, 1.0),
    priors.HalfNormalPrior(1.0, 1.0),
    priors.LogNormalPrior(1.0, 1.0),
    priors.UniformPrior(1.0, 2.0),
):
    prior.to("cuda:0")
    samples = prior.rsample()
    print(f"{str(prior):<35} {str(samples.device):<8} {dict(prior.state_dict())}")

Output:

NormalPrior()                       cuda:0   {'loc': tensor(1., device='cuda:0'), 'scale': tensor(1., device='cuda:0')}
GammaPrior()                        cuda:0   {'concentration': tensor(1., device='cuda:0'), 'rate': tensor(1., device='cuda:0')}
HalfCauchyPrior()                   cpu      {'_transformed_scale': tensor(1., device='cuda:0')}
HalfNormalPrior()                   cpu      {'_transformed_scale': tensor(1., device='cuda:0')}
LogNormalPrior()                    cpu      {'_transformed_loc': tensor(1., device='cuda:0'), '_transformed_scale': tensor(1., device='cuda:0')}
UniformPrior(low: 1.0, high: 2.0)   cpu      {}
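
The pattern above (state_dict on `cuda:0`, samples on `cpu`) is what you would see if a prior's registered buffers are moved while the distribution it samples from still holds its original tensors. The sketch below illustrates that general PyTorch behavior and is not GPyTorch's implementation. `ToyPrior` is a hypothetical class, and a dtype conversion stands in for the device move so that it runs without a GPU.

```python
import torch
from torch import nn
from torch.distributions import Normal


class ToyPrior(nn.Module):
    """Hypothetical stand-in for a prior; not GPyTorch code."""

    def __init__(self, loc: float, scale: float) -> None:
        super().__init__()
        # Registered buffers appear in state_dict() and are converted by .to().
        self.register_buffer("loc", torch.tensor(loc))
        self.register_buffer("scale", torch.tensor(scale))
        # A distribution stored as a plain attribute keeps its own tensors,
        # which .to() does not touch.
        self.base_dist = Normal(torch.tensor(loc), torch.tensor(scale))

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)


prior = ToyPrior(1.0, 1.0)
prior.to(torch.float64)  # dtype stands in for device here

state_dtype = prior.state_dict()["loc"].dtype  # follows the .to() call
sample_dtype = prior.rsample().dtype           # still the original dtype
print(state_dtype, sample_dtype)
```

The state_dict reflects the conversion while the sample does not, mirroring the `cuda:0` versus `cpu` split in the table above.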

This manifests itself in BoTorch when a LogNormal prior is in use. If the fit fails the first time, new initial hyperparameter values are sampled from the prior, which results in a device mismatch. In the reproducible example below, I'm triggering this manually with optimizer_kwargs set such that a warning is raised, and warning_handler set to trigger a retry for any warning.

To reproduce

**Code snippet to reproduce**

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch import kernels, priors
from gpytorch.mlls import ExactMarginalLogLikelihood

n_inputs = 4
n_outputs = 2
n_train = 256
device = torch.device("cuda:0")

train_x = torch.rand(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_outputs, dtype=torch.float64, device=device)

model = SingleTaskGP(
    train_x, 
    train_y, 
    input_transform=Normalize(n_inputs),
    outcome_transform=Standardize(m=n_outputs),
    covar_module=kernels.ScaleKernel(
        base_kernel=kernels.MaternKernel(
            nu=2.5,
            ard_num_dims=n_inputs,
            batch_shape=torch.Size([n_outputs]),
            lengthscale_prior=priors.LogNormalPrior(0.5, 0.5),
        ),
        outputscale_prior=priors.GammaPrior(2.0, 0.15),
        batch_shape=torch.Size([n_outputs]),
    )
)

mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(
    mll, 
    optimizer_kwargs={"timeout_sec": 1e-3}, 
    warning_handler=lambda _: False,
)

**Stack trace/error message**

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 34
     16 model = SingleTaskGP(
     17     train_x, 
     18     train_y, 
   (...)
     30     )
     31 )
     33 mll = ExactMarginalLogLikelihood(model.likelihood, model)
---> 34 fit_gpytorch_mll(
     35     mll, 
     36     optimizer_kwargs={"timeout_sec": 1e-3}, 
     37     warning_handler=lambda _: False,
     38 )

File .../python3.10/site-packages/botorch/fit.py:104, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    101 if optimizer is not None:  # defer to per-method defaults
    102     kwargs["optimizer"] = optimizer
--> 104 return FitGPyTorchMLL(
    105     mll,
    106     type(mll.likelihood),
    107     type(mll.model),
    108     closure=closure,
    109     closure_kwargs=closure_kwargs,
    110     optimizer_kwargs=optimizer_kwargs,
    111     **kwargs,
    112 )

File .../python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File .../python3.10/site-packages/botorch/fit.py:198, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
    195         ckpt_nograd = {name: ckpt[name] for name in params_nograd}
    197     with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd):
--> 198         sample_all_priors(mll.model)
    200 try:
    201     # Fit the model
    202     with catch_warnings(record=True) as warning_list, debug(True):

File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:191, in sample_all_priors(model, max_retries)
    186         raise RuntimeError(
    187             "Failed to sample a feasible parameter value "
    188             f"from the prior after {max_retries} attempts."
    189         )
    190 else:
--> 191     raise e

File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:171, in sample_all_priors(model, max_retries)
    166 prior_shape = prior._extended_shape()
    167 if prior_shape.numel() == 1:
    168     # For a univariate prior we can sample the size of the closure.
    169     # Otherwise we will sample exactly the same value for all
    170     # lengthscales where we commonly specify a univariate prior.
--> 171     setting_closure(module, prior.sample(closure(module).shape))
    172 else:
    173     closure_shape = closure(module).shape

File .../python3.10/site-packages/gpytorch/kernels/kernel.py:221, in Kernel._lengthscale_closure(self, m, v)
    219 def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
    220     # Used by the lengthscale_prior
--> 221     return m._set_lengthscale(v)

File .../python3.10/site-packages/gpytorch/kernels/kernel.py:231, in Kernel._set_lengthscale(self, value)
    228 if not torch.is_tensor(value):
    229     value = torch.as_tensor(value).to(self.raw_lengthscale)
--> 231 self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))

File .../python3.10/site-packages/gpytorch/module.py:103, in Module.initialize(self, **kwargs)
    101 elif torch.is_tensor(val):
    102     constraint = self.constraint_for_parameter_name(name)
--> 103     if constraint is not None and constraint.enforced and not constraint.check_raw(val):
    104         raise RuntimeError(
    105             "Attempting to manually set a parameter value that is out of bounds of "
    106             f"its current constraints, {constraint}. "
    107             "Most likely, you want to do the following:\n likelihood = GaussianLikelihood"
    108             "(noise_constraint=gpytorch.constraints.GreaterThan(better_lower_bound))"
    109         )
    110     try:

File .../python3.10/site-packages/gpytorch/constraints/constraints.py:90, in Interval.check_raw(self, tensor)
     88 def check_raw(self, tensor) -> bool:
     89     return bool(
---> 90         torch.all((self.transform(tensor) <= self.upper_bound))
     91         and torch.all(self.transform(tensor) >= self.lower_bound)
     92     )

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
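
Because the error above only surfaces when a fit is retried, it may help to check for affected priors up front. The helper below is a rough diagnostic sketch, not part of GPyTorch or BoTorch. `find_prior_device_mismatches` is a hypothetical name, and it assumes `model.named_priors()` yields `(name, module, prior, closure, setting_closure)` tuples, as the `sample_all_priors` frames in the traceback suggest.

```python
import torch


def find_prior_device_mismatches(model):
    """Return (name, sample_device, param_device) for priors that sample off-device.

    Assumption: `model.named_priors()` yields 5-tuples of
    (name, module, prior, closure, setting_closure).
    """
    mismatches = []
    for name, module, prior, closure, _ in model.named_priors():
        # Device of the parameter that the prior is attached to.
        param_device = closure(module).device
        # Device on which the prior actually produces samples.
        sample_device = prior.sample().device
        if sample_device != param_device:
            mismatches.append((name, sample_device, param_device))
    return mismatches
```

Under those assumptions, calling `find_prior_device_mismatches(mll.model)` before `fit_gpytorch_mll` would list the priors whose samples land on a different device from their parameters.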

Expected Behavior

System information

Please complete the following information:

  • GPyTorch Version: 1.14.dev2+g83332c2c (latest main)
  • PyTorch Version: '2.0.1+cu117'
  • Computer OS: Linux
slishak-PX added the bug label Sep 9, 2024
@Balandat
Collaborator

Thanks for raising this. Yes, this has the same cause as #2550: the .to() doesn't move the attributes over to the GPU. @hvarfner, in case you have any immediate ideas on this: it basically looks like we just need to override the .to() method in the same way.

@hvarfner
Contributor

@Balandat Interesting. I'm not sure whether we need to override .to() or just modify save/load_state_dict(), but I'll have a look.
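
For illustration only, here is one possible reading of the `.to()` option, sketched on a hypothetical `RebuildingPrior` class rather than on GPyTorch's priors. It assumes the distribution parameters are kept as registered buffers and that the underlying distribution can be rebuilt from them after the move. A dtype conversion stands in for the device move so the sketch runs on CPU.

```python
import torch
from torch import nn
from torch.distributions import Normal


class RebuildingPrior(nn.Module):
    """Hypothetical sketch; not the fix adopted in GPyTorch."""

    def __init__(self, loc: float, scale: float) -> None:
        super().__init__()
        self.register_buffer("loc", torch.tensor(loc))
        self.register_buffer("scale", torch.tensor(scale))
        self.base_dist = Normal(self.loc, self.scale)

    def to(self, *args, **kwargs):
        # Let nn.Module convert the registered buffers first.
        module = super().to(*args, **kwargs)
        # Rebuild the distribution so it holds the converted tensors.
        module.base_dist = Normal(module.loc, module.scale)
        return module

    def rsample(self, sample_shape=torch.Size()):
        return self.base_dist.rsample(sample_shape)


prior = RebuildingPrior(1.0, 1.0)
prior.to(torch.float64)  # dtype stands in for device here
sample = prior.rsample()
print(sample.dtype)
```

One caveat with this design: `nn.Module.cuda()`, `.cpu()` and `.double()` call `_apply` directly rather than going through `.to()`, so an override of `.to()` alone would not cover those entry points.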
