From bbb251320cd0081393a0578e287ad9b14f1b9599 Mon Sep 17 00:00:00 2001 From: ardagoreci <62720042+ardagoreci@users.noreply.github.com> Date: Sun, 2 Jun 2024 15:06:45 +0100 Subject: [PATCH] Implemented alignment methods --- src/utils/geometry/alignment.py | 53 ++++++++++++++++++++++++++++++++ tests/test_alignment.py | 54 +++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 tests/test_alignment.py diff --git a/src/utils/geometry/alignment.py b/src/utils/geometry/alignment.py index 4680328..3532ea8 100644 --- a/src/utils/geometry/alignment.py +++ b/src/utils/geometry/alignment.py @@ -1,2 +1,55 @@ """There is a weighted rigid alignment of the ground truth onto the denoised structure before the diffusion loss is applied.""" +import torch +from src.utils.geometry.vector import Vec3Array +from src.utils.geometry.rotation_matrix import Rot3Array + + +def compute_covariance_matrix(P, Q): + """Computes the covariance matrix between two sets of points P and Q. + The covariance matrix H is calculated by H = P^T*Q. This is used to + find the transformation from Q to P. + Args: + P: (bs, n_atoms, 3) tensor of points + Q: (bs, n_atoms, 3) tensor of points + Returns: + (bs, 3, 3) tensor of covariance matrices. + """ + return torch.matmul(P.transpose(-2, -1), Q) + + +def weighted_rigid_align( + x: Vec3Array, + x_gt: Vec3Array, + weights: torch.Tensor, + mask: torch.Tensor = None # (bs, n_atoms) +) -> Vec3Array: + """Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure + not being moved, not to be confused with ground truth during training.""" + + # Mean-centre positions + mu = (x * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True) + mu_gt = (x_gt * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True) + x = x - mu # Vec3Array of shape (bs, n_atoms) + x_gt = x_gt - mu_gt + + # Mask atoms before computing covariance matrix + if mask is not None: + x *= mask + x_gt *= mask + + # Find optimal rotation from singular value decomposition + U, S, Vh = torch.linalg.svd(compute_covariance_matrix(x_gt.to_tensor(), x.to_tensor())) # shapes: (bs, 3, 3) + R = U @ Vh + + # Remove reflection + if torch.linalg.det(R) < 0: + reflection_matrix = torch.diag((torch.tensor([1, 1, -1], device=U.device, dtype=U.dtype))) + reflection_matrix = reflection_matrix.unsqueeze(0).expand_as(R) + R = U @ reflection_matrix @ Vh # (bs, 3, 3) + + R = Rot3Array.from_array(R) + + # Apply alignment + x_aligned = R.apply_to_point(x) + mu + return x_aligned diff --git a/tests/test_alignment.py b/tests/test_alignment.py new file mode 100644 index 0000000..bb0a669 --- /dev/null +++ b/tests/test_alignment.py @@ -0,0 +1,54 @@ +import torch +import unittest +from src.utils.geometry.alignment import weighted_rigid_align, compute_covariance_matrix +from src.utils.geometry.vector import Vec3Array +from src.utils.geometry.rotation_matrix import Rot3Array +from src.diffusion.augmentation import centre_random_augmentation + + +class TestAlignment(unittest.TestCase): + + def test_compute_covariance_matrix(self): + # Test simple case + P = torch.tensor([[[1, 2, 3]]], dtype=torch.float32) # Shape (1, 1, 3) + Q = torch.tensor([[[4, 5, 6]]], dtype=torch.float32) # Shape (1, 1, 3) + expected_H = torch.tensor([[[4., 5., 6.], + [8., 10., 12.], + [12., 15., 18.]]]) # Shape (1, 3, 3) + H = compute_covariance_matrix(P, Q) + + self.assertTrue(torch.allclose(H, expected_H), "Covariance matrix calculation failed") + + # Test zero matrices + P_zero = torch.zeros((1, 10, 3), dtype=torch.float32) + Q_zero = torch.zeros((1, 10, 3), dtype=torch.float32) + H_zero = compute_covariance_matrix(P_zero, Q_zero) + self.assertTrue(torch.equal(H_zero, torch.zeros((1, 3, 3))), "Covariance matrix should be zero") + + def test_weighted_rigid_align(self): + # Setup + bs, n_atoms = 1, 3 + x = Vec3Array.from_array(torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])) + x_gt = Vec3Array.from_array(torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]])) + weights = torch.ones((bs, n_atoms)) + mask = torch.ones((bs, n_atoms)) + + # Identity alignment (no movement expected) + x_aligned = weighted_rigid_align(x, x_gt, weights, mask) + self.assertTrue(torch.allclose(x_aligned.to_tensor(), x.to_tensor(), atol=1e-6), "Alignment should be identity") + + # Rotation removal test + n_atoms = 100 + x = Vec3Array.from_array(torch.randn((bs, n_atoms, 3))) + x = x - x.mean(dim=1, keepdim=True) # Center x + weights = torch.ones((bs, n_atoms)) + mask = torch.ones((bs, n_atoms)) + + x_rotated = centre_random_augmentation(x, s_trans=0.0) + x_rotated_aligned = weighted_rigid_align(x_rotated, x, weights, mask) + self.assertTrue(torch.allclose(x_rotated_aligned.to_tensor(), x.to_tensor(), atol=1e-6), "Rotation should be " + "corrected") + + +if __name__ == '__main__': + unittest.main()