-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
933c647
commit be64ecc
Showing
29 changed files
with
2,204 additions
and
319 deletions.
There are no files selected for viewing
137 changes: 0 additions & 137 deletions
137
gauche/kernels/fingerprint_kernels/base_fingerprint_kernel.py
This file was deleted.
Oops, something went wrong.
111 changes: 111 additions & 0 deletions
111
gauche/kernels/fingerprint_kernels/braun_blanquet_kernel.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
""" | ||
Braun-Blanquet Kernel. Operates on representations including bit vectors e.g. Morgan/ECFP6 fingerprints count vectors e.g. | ||
RDKit fragment features. | ||
""" | ||
|
||
import torch | ||
from gpytorch.kernels import Kernel | ||
|
||
tkwargs = {"dtype": torch.double} | ||
|
||
def batch_braun_blanquet_sim(
    x1: torch.Tensor, x2: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Braun-Blanquet similarity between two batched tensors, across last 2 dimensions.

    <x1, x2> / max(|x1|, |x2|)

    Where |.| is the L1 norm and <.,.> is the inner product, computed for every
    pair of rows of x1 and x2. The eps argument ensures numerical stability if
    all-zero tensors are added.

    Args:
        x1: `[b x n x d]` Tensor where b is the batch dimension
        x2: `[b x m x d]` Tensor
        eps: Float for numerical stability. Default value is 1e-6

    Returns:
        `[b x n x m]` Tensor of Braun-Blanquet similarities (as torch.double).

    Raises:
        ValueError: if either input has fewer than 2 dimensions.
    """
    if x1.ndim < 2 or x2.ndim < 2:
        raise ValueError("Tensors must have a batch dimension")

    # Row-wise L1 norms, kept with a trailing singleton dim for broadcasting.
    # NOTE(review): inputs are presumed non-negative bit/count vectors, so a
    # plain sum equals the L1 norm — confirm against callers.
    x1_norm = torch.sum(x1, dim=-1, keepdim=True)  # [... x n x 1]
    x2_norm = torch.sum(x2, dim=-1, keepdim=True)  # [... x m x 1]

    # Pairwise denominator max(|x1_i|, |x2_j|) via broadcasting.
    # BUG FIX: the previous code used torch.max(x1_norm[-1], x2_norm[-1]),
    # which compared only the last rows' norms and broadcast that single
    # value across every pair, giving wrong similarities whenever row norms
    # differ.
    denom = torch.maximum(x1_norm, x2_norm.transpose(-1, -2))  # [... x n x m]

    dot_prod = torch.matmul(x1, torch.transpose(x2, -1, -2))  # [... x n x m]

    similarity = (dot_prod + eps) / (denom + eps)

    # Promote to double and zero out negative values for numerical stability.
    return similarity.to(dtype=torch.double).clamp_min_(0)
|
||
|
||
class BraunBlanquetKernel(Kernel):
    r"""
    Covariance matrix based on the Braun-Blanquet similarity between
    inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.

    .. note::
        This kernel does not have an `outputscale` parameter. To add a scaling
        parameter, decorate this kernel with a
        :class:`gpytorch.kernels.ScaleKernel`.

    Example:
        >>> x = torch.randint(0, 2, (10, 5))
        >>> # Non-batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(BraunBlanquetKernel())
        >>> covar = covar_module(x)  # Output: LazyTensor of size (10 x 10)
        >>>
        >>> batch_x = torch.randint(0, 2, (2, 10, 5))
        >>> # Batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(BraunBlanquetKernel())
        >>> covar = covar_module(batch_x)  # Output: LazyTensor of size (2 x 10 x 10)
    """

    is_stationary = False
    has_lengthscale = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x1, x2, diag=False, **params):
        if not diag:
            return self.covar_dist(x1, x2, **params)
        # Diagonal is only requested for x1 == x2, where self-similarity is 1.
        assert x1.size() == x2.size() and torch.equal(x1, x2)
        return torch.ones(
            *x1.shape[:-2], x1.shape[-2], dtype=x1.dtype, device=x1.device
        )

    def covar_dist(
        self,
        x1,
        x2,
        last_dim_is_batch=False,
        **params,
    ):
        r"""Helper computing the bit-vector similarity between all pairs of
        points in x1 and x2.

        Args:
            :attr:`x1` (Tensor `n x d` or `b1 x ... x bk x n x d`):
                First set of data.
            :attr:`x2` (Tensor `m x d` or `b1 x ... x bk x m x d`):
                Second set of data.
            :attr:`last_dim_is_batch` (bool, optional):
                Is the last dimension of the data a batch dimension or not?

        Returns:
            :class:`Tensor` of similarities between `x1` and `x2`; the shape
            depends on `last_dim_is_batch` (e.g. `b x d x n x n` when True).
        """
        if last_dim_is_batch:
            # Move the feature dimension out to act as a batch of scalar features.
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)
        return batch_braun_blanquet_sim(x1, x2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
""" | ||
Dice Kernel. Operates on representations including bit vectors e.g. Morgan/ECFP6 fingerprints count vectors e.g. | ||
RDKit fragment features. | ||
""" | ||
|
||
import torch | ||
from gpytorch.kernels import Kernel | ||
|
||
|
||
def batch_dice_sim(
    x1: torch.Tensor, x2: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Dice similarity between two batched tensors, across last 2 dimensions.

    (2 * <x1, x2>) / (|x1| + |x2|)

    Where |.| is the L1 norm and <.,.> is the inner product, computed for
    every pair of rows of x1 and x2. The eps argument ensures numerical
    stability if all-zero tensors are added.

    Args:
        x1: `[b x n x d]` Tensor where b is the batch dimension
        x2: `[b x m x d]` Tensor
        eps: Float for numerical stability. Default value is 1e-6

    Returns:
        Tensor denoting the Dice similarity.

    Raises:
        ValueError: if either input has fewer than 2 dimensions.
    """
    if x1.ndim < 2 or x2.ndim < 2:
        raise ValueError("Tensors must have a batch dimension")

    # Row-wise L1 norms, kept with a trailing singleton dim so that
    # norms_1 + norms_2^T broadcasts to the full [n x m] pairwise grid.
    # NOTE(review): inputs are presumed non-negative bit/count vectors, so a
    # plain sum equals the L1 norm — confirm against callers.
    norms_1 = x1.sum(dim=-1, keepdim=True)
    norms_2 = x2.sum(dim=-1, keepdim=True)

    cross = x1 @ x2.transpose(-1, -2)
    sim = (2 * cross + eps) / (norms_1 + norms_2.transpose(-1, -2) + eps)

    # Zero out negative values for numerical stability.
    return sim.clamp_min_(0)
|
||
|
||
class DiceKernel(Kernel):
    r"""
    Covariance matrix based on the Dice kernel between inputs
    :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`:

    .. math::
        \begin{equation*}
        k_{\text{Dice}}(\mathbf{x}, \mathbf{x'}) = \frac{2\langle\mathbf{x},
        \mathbf{x'}\rangle}{\left\lVert\mathbf{x}\right\rVert + \left\lVert\mathbf{x'}\right\rVert}
        \end{equation*}

    .. note::
        This kernel does not have an `outputscale` parameter. To add a scaling
        parameter, decorate this kernel with a
        :class:`gpytorch.kernels.ScaleKernel`.

    Example:
        >>> x = torch.randint(0, 2, (10, 5))
        >>> # Non-batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(DiceKernel())
        >>> covar = covar_module(x)  # Output: LazyTensor of size (10 x 10)
        >>>
        >>> batch_x = torch.randint(0, 2, (2, 10, 5))
        >>> # Batch: Simple option
        >>> covar_module = gpytorch.kernels.ScaleKernel(DiceKernel())
        >>> covar = covar_module(batch_x)  # Output: LazyTensor of size (2 x 10 x 10)
    """

    is_stationary = False
    has_lengthscale = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x1, x2, diag=False, **params):
        if not diag:
            return self.covar_dist(x1, x2, **params)
        # Diagonal is only requested for x1 == x2, where self-similarity is 1.
        assert x1.size() == x2.size() and torch.equal(x1, x2)
        return torch.ones(
            *x1.shape[:-2], x1.shape[-2], dtype=x1.dtype, device=x1.device
        )

    def covar_dist(
        self,
        x1,
        x2,
        last_dim_is_batch=False,
        **params,
    ):
        r"""Helper computing the bit-vector similarity between all pairs of
        points in x1 and x2.

        Args:
            :attr:`x1` (Tensor `n x d` or `b1 x ... x bk x n x d`):
                First set of data.
            :attr:`x2` (Tensor `m x d` or `b1 x ... x bk x m x d`):
                Second set of data.
            :attr:`last_dim_is_batch` (bool, optional):
                Is the last dimension of the data a batch dimension or not?

        Returns:
            :class:`Tensor` of similarities between `x1` and `x2`; the shape
            depends on `last_dim_is_batch` (e.g. `b x d x n x n` when True).
        """
        if last_dim_is_batch:
            # Move the feature dimension out to act as a batch of scalar features.
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)
        return batch_dice_sim(x1, x2)
Oops, something went wrong.