generated from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
ardagoreci
committed
Jun 2, 2024
1 parent
b2c1cf0
commit bbb2513
Showing
2 changed files
with
107 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,55 @@ | ||
"""There is a weighted rigid alignment of the ground truth onto the denoised structure before | ||
the diffusion loss is applied.""" | ||
import torch | ||
from src.utils.geometry.vector import Vec3Array | ||
from src.utils.geometry.rotation_matrix import Rot3Array | ||
|
||
|
||
def compute_covariance_matrix(P, Q): | ||
"""Computes the covariance matrix between two sets of points P and Q. | ||
The covariance matrix H is calculated by H = P^T*Q. This is used to | ||
find the transformation from Q to P. | ||
Args: | ||
P: (bs, n_atoms, 3) tensor of points | ||
Q: (bs, n_atoms, 3) tensor of points | ||
Returns: | ||
(bs, 3, 3) tensor of covariance matrices. | ||
""" | ||
return torch.matmul(P.transpose(-2, -1), Q) | ||
|
||
|
||
def weighted_rigid_align( | ||
x: Vec3Array, | ||
x_gt: Vec3Array, | ||
weights: torch.Tensor, | ||
mask: torch.Tensor = None # (bs, n_atoms) | ||
) -> Vec3Array: | ||
"""Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure | ||
not being moved, not to be confused with ground truth during training.""" | ||
|
||
# Mean-centre positions | ||
mu = (x * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True) | ||
mu_gt = (x_gt * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True) | ||
x = x - mu # Vec3Array of shape (bs, n_atoms) | ||
x_gt = x_gt - mu_gt | ||
|
||
# Mask atoms before computing covariance matrix | ||
if mask is not None: | ||
x *= mask | ||
x_gt *= mask | ||
|
||
# Find optimal rotation from singular value decomposition | ||
U, S, Vh = torch.linalg.svd(compute_covariance_matrix(x_gt.to_tensor(), x.to_tensor())) # shapes: (bs, 3, 3) | ||
R = U @ Vh | ||
|
||
# Remove reflection | ||
if torch.linalg.det(R) < 0: | ||
reflection_matrix = torch.diag((torch.tensor([1, 1, -1], device=U.device, dtype=U.dtype))) | ||
reflection_matrix = reflection_matrix.unsqueeze(0).expand_as(R) | ||
R = U @ reflection_matrix @ Vh # (bs, 3, 3) | ||
|
||
R = Rot3Array.from_array(R) | ||
|
||
# Apply alignment | ||
x_aligned = R.apply_to_point(x) + mu | ||
return x_aligned |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
import unittest | ||
from src.utils.geometry.alignment import weighted_rigid_align, compute_covariance_matrix | ||
from src.utils.geometry.vector import Vec3Array | ||
from src.utils.geometry.rotation_matrix import Rot3Array | ||
from src.diffusion.augmentation import centre_random_augmentation | ||
|
||
|
||
class TestAlignment(unittest.TestCase): | ||
|
||
def test_compute_covariance_matrix(self): | ||
# Test simple case | ||
P = torch.tensor([[[1, 2, 3]]], dtype=torch.float32) # Shape (1, 1, 3) | ||
Q = torch.tensor([[[4, 5, 6]]], dtype=torch.float32) # Shape (1, 1, 3) | ||
expected_H = torch.tensor([[[4., 5., 6.], | ||
[8., 10., 12.], | ||
[12., 15., 18.]]]) # Shape (1, 3, 3) | ||
H = compute_covariance_matrix(P, Q) | ||
|
||
self.assertTrue(torch.allclose(H, expected_H), "Covariance matrix calculation failed") | ||
|
||
# Test zero matrices | ||
P_zero = torch.zeros((1, 10, 3), dtype=torch.float32) | ||
Q_zero = torch.zeros((1, 10, 3), dtype=torch.float32) | ||
H_zero = compute_covariance_matrix(P_zero, Q_zero) | ||
self.assertTrue(torch.equal(H_zero, torch.zeros((1, 3, 3))), "Covariance matrix should be zero") | ||
|
||
def test_weighted_rigid_align(self): | ||
# Setup | ||
bs, n_atoms = 1, 3 | ||
x = Vec3Array.from_array(torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])) | ||
x_gt = Vec3Array.from_array(torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])) | ||
weights = torch.ones((bs, n_atoms)) | ||
mask = torch.ones((bs, n_atoms)) | ||
|
||
# Identity alignment (no movement expected) | ||
x_aligned = weighted_rigid_align(x, x_gt, weights, mask) | ||
self.assertTrue(torch.allclose(x_aligned.to_tensor(), x.to_tensor(), atol=1e-6), "Alignment should be identity") | ||
|
||
# Rotation removal test | ||
n_atoms = 100 | ||
x = Vec3Array.from_array(torch.randn((bs, n_atoms, 3))) | ||
x = x - x.mean(dim=1, keepdim=True) # Center x | ||
weights = torch.ones((bs, n_atoms)) | ||
mask = torch.ones((bs, n_atoms)) | ||
|
||
x_rotated = centre_random_augmentation(x, s_trans=0.0) | ||
x_rotated_aligned = weighted_rigid_align(x_rotated, x, weights, mask) | ||
self.assertTrue(torch.allclose(x_rotated_aligned.to_tensor(), x.to_tensor(), atol=1e-6), "Rotation should be " | ||
"corrected") | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |