Skip to content

Commit

Permalink
Implemented alignment methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Jun 2, 2024
1 parent b2c1cf0 commit bbb2513
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/utils/geometry/alignment.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,55 @@
"""There is a weighted rigid alignment of the ground truth onto the denoised structure before
the diffusion loss is applied."""
import torch
from src.utils.geometry.vector import Vec3Array
from src.utils.geometry.rotation_matrix import Rot3Array


def compute_covariance_matrix(P, Q):
"""Computes the covariance matrix between two sets of points P and Q.
The covariance matrix H is calculated by H = P^T*Q. This is used to
find the transformation from Q to P.
Args:
P: (bs, n_atoms, 3) tensor of points
Q: (bs, n_atoms, 3) tensor of points
Returns:
(bs, 3, 3) tensor of covariance matrices.
"""
return torch.matmul(P.transpose(-2, -1), Q)


def weighted_rigid_align(
x: Vec3Array,
x_gt: Vec3Array,
weights: torch.Tensor,
mask: torch.Tensor = None # (bs, n_atoms)
) -> Vec3Array:
"""Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure
not being moved, not to be confused with ground truth during training."""

# Mean-centre positions
mu = (x * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True)
mu_gt = (x_gt * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True)
x = x - mu # Vec3Array of shape (bs, n_atoms)
x_gt = x_gt - mu_gt

# Mask atoms before computing covariance matrix
if mask is not None:
x *= mask
x_gt *= mask

# Find optimal rotation from singular value decomposition
U, S, Vh = torch.linalg.svd(compute_covariance_matrix(x_gt.to_tensor(), x.to_tensor())) # shapes: (bs, 3, 3)
R = U @ Vh

# Remove reflection
if torch.linalg.det(R) < 0:
reflection_matrix = torch.diag((torch.tensor([1, 1, -1], device=U.device, dtype=U.dtype)))
reflection_matrix = reflection_matrix.unsqueeze(0).expand_as(R)
R = U @ reflection_matrix @ Vh # (bs, 3, 3)

R = Rot3Array.from_array(R)

# Apply alignment
x_aligned = R.apply_to_point(x) + mu
return x_aligned
54 changes: 54 additions & 0 deletions tests/test_alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
import unittest
from src.utils.geometry.alignment import weighted_rigid_align, compute_covariance_matrix
from src.utils.geometry.vector import Vec3Array
from src.utils.geometry.rotation_matrix import Rot3Array
from src.diffusion.augmentation import centre_random_augmentation


class TestAlignment(unittest.TestCase):

def test_compute_covariance_matrix(self):
# Test simple case
P = torch.tensor([[[1, 2, 3]]], dtype=torch.float32) # Shape (1, 1, 3)
Q = torch.tensor([[[4, 5, 6]]], dtype=torch.float32) # Shape (1, 1, 3)
expected_H = torch.tensor([[[4., 5., 6.],
[8., 10., 12.],
[12., 15., 18.]]]) # Shape (1, 3, 3)
H = compute_covariance_matrix(P, Q)

self.assertTrue(torch.allclose(H, expected_H), "Covariance matrix calculation failed")

# Test zero matrices
P_zero = torch.zeros((1, 10, 3), dtype=torch.float32)
Q_zero = torch.zeros((1, 10, 3), dtype=torch.float32)
H_zero = compute_covariance_matrix(P_zero, Q_zero)
self.assertTrue(torch.equal(H_zero, torch.zeros((1, 3, 3))), "Covariance matrix should be zero")

def test_weighted_rigid_align(self):
# Setup
bs, n_atoms = 1, 3
x = Vec3Array.from_array(torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]))
x_gt = Vec3Array.from_array(torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]]))
weights = torch.ones((bs, n_atoms))
mask = torch.ones((bs, n_atoms))

# Identity alignment (no movement expected)
x_aligned = weighted_rigid_align(x, x_gt, weights, mask)
self.assertTrue(torch.allclose(x_aligned.to_tensor(), x.to_tensor(), atol=1e-6), "Alignment should be identity")

# Rotation removal test
n_atoms = 100
x = Vec3Array.from_array(torch.randn((bs, n_atoms, 3)))
x = x - x.mean(dim=1, keepdim=True) # Center x
weights = torch.ones((bs, n_atoms))
mask = torch.ones((bs, n_atoms))

x_rotated = centre_random_augmentation(x, s_trans=0.0)
x_rotated_aligned = weighted_rigid_align(x_rotated, x, weights, mask)
self.assertTrue(torch.allclose(x_rotated_aligned.to_tensor(), x.to_tensor(), atol=1e-6), "Rotation should be "
"corrected")


if __name__ == '__main__':
unittest.main()

0 comments on commit bbb2513

Please sign in to comment.