From 482a91c6f7a9a1e172881c37ec6e56525ae605ee Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 29 Aug 2024 17:15:44 +0200 Subject: [PATCH 1/4] typing: Use `Self` from `typing_extensions` --- botorch/models/approximate_gp.py | 10 ++++------ botorch/models/model.py | 15 +++++---------- requirements.txt | 1 + 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/botorch/models/approximate_gp.py b/botorch/models/approximate_gp.py index f9face99fb..d85404525d 100644 --- a/botorch/models/approximate_gp.py +++ b/botorch/models/approximate_gp.py @@ -32,7 +32,8 @@ import copy import warnings -from typing import Optional, TypeVar, Union +from typing import Optional, Union +from typing_extensions import Self import torch from botorch.models.gpytorch import GPyTorchModel @@ -69,9 +70,6 @@ from torch.nn import Module -TApproxModel = TypeVar("TApproxModel", bound="ApproximateGPyTorchModel") - - class ApproximateGPyTorchModel(GPyTorchModel): r""" Botorch wrapper class for various (variational) approximate GP models in @@ -123,11 +121,11 @@ def __init__( def num_outputs(self): return self._desired_num_outputs - def eval(self: TApproxModel) -> TApproxModel: + def eval(self) -> Self: r"""Puts the model in `eval` mode.""" return Module.eval(self) - def train(self: TApproxModel, mode: bool = True) -> TApproxModel: + def train(self, mode: bool = True) -> Self: r"""Put the model in `train` mode. Args: diff --git a/botorch/models/model.py b/botorch/models/model.py index f42fb46f6d..c778eecea4 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -16,7 +16,8 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Mapping -from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import Self import numpy as np import torch @@ -41,8 +42,6 @@ if TYPE_CHECKING: from botorch.acquisition.objective import PosteriorTransform # pragma: no cover -TFantasizeMixin = TypeVar("TFantasizeMixin", bound="FantasizeMixin") - class Model(Module, ABC): r"""Abstract base class for BoTorch models. @@ -289,11 +288,7 @@ def __init__(self, args): """ @abstractmethod - def condition_on_observations( - self: TFantasizeMixin, - X: Tensor, - Y: Tensor, - ) -> TFantasizeMixin: + def condition_on_observations(self, X: Tensor, Y: Tensor) -> Self: """ Classes that inherit from `FantasizeMixin` must implement a `condition_on_observations` method. @@ -326,12 +321,12 @@ def transform_inputs( # this as # 'Self', but at this point the verbose 'T...' syntax is needed. def fantasize( - self: TFantasizeMixin, + self, X: Tensor, sampler: MCSampler, observation_noise: Optional[Tensor] = None, **kwargs: Any, - ) -> TFantasizeMixin: + ) -> Self: r"""Construct a fantasy model. Constructs a fantasy model in the following fashion: diff --git a/requirements.txt b/requirements.txt index c094eede8e..d9fdda123e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ torch>=1.13.1 pyro-ppl>=1.8.4 gpytorch==1.12 linear_operator==0.5.2 +typing_extensions From d6a5c03c73bb05a4c8b410ca3ca4a191f3ba6114 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 29 Aug 2024 17:17:08 +0200 Subject: [PATCH 2/4] doc: Remove unneeded comment --- botorch/models/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/botorch/models/model.py b/botorch/models/model.py index c778eecea4..d30d333e6b 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -317,9 +317,6 @@ def transform_inputs( a `transform_inputs` method. """ - # When Python 3.11 arrives we can start annotating return types like - # this as - # 'Self', but at this point the verbose 'T...' syntax is needed. def fantasize( self, X: Tensor, From 7aa04887a20b8634afd66ae209c14829adab01bc Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 16 Sep 2024 08:51:04 +0200 Subject: [PATCH 3/4] style: fix formatting of import --- botorch/models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/botorch/models/model.py b/botorch/models/model.py index d30d333e6b..da82b65ba7 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -17,7 +17,6 @@ from collections import defaultdict from collections.abc import Mapping from typing import Any, Callable, Optional, TYPE_CHECKING, Union -from typing_extensions import Self import numpy as np import torch @@ -38,6 +37,7 @@ from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from torch import Tensor from torch.nn import Module, ModuleDict, ModuleList +from typing_extensions import Self if TYPE_CHECKING: from botorch.acquisition.objective import PosteriorTransform # pragma: no cover From be87a2973c0f1879c74463e46745d8445f69b555 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 16 Sep 2024 13:49:14 +0200 Subject: [PATCH 4/4] style: formatting imports --- botorch/models/approximate_gp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/botorch/models/approximate_gp.py b/botorch/models/approximate_gp.py index d85404525d..2e3fa37318 100644 --- a/botorch/models/approximate_gp.py +++ b/botorch/models/approximate_gp.py @@ -33,7 +33,6 @@ import warnings from typing import Optional, Union -from typing_extensions import Self import torch from botorch.models.gpytorch import GPyTorchModel @@ -68,6 +67,7 @@ ) from torch import Tensor from torch.nn import Module +from typing_extensions import Self class ApproximateGPyTorchModel(GPyTorchModel):