diff --git a/src/diffusion/augmentation.py b/src/diffusion/augmentation.py index 8e9a23c..21138eb 100644 --- a/src/diffusion/augmentation.py +++ b/src/diffusion/augmentation.py @@ -1,16 +1,43 @@ """Data augmentations applied prior to sampling from the diffusion trajectory.""" - import torch +from src.utils.geometry.vector import Vec3Array +from src.utils.geometry.rotation_matrix import Rot3Array def centre_random_augmentation( - atom_positions: torch.Tensor, # (bs, n_atoms, 3) + atom_positions: Vec3Array, # (bs, n_atoms) s_trans: float = 1.0, # Translation scaling factor -) -> torch.Tensor: - """Centers the atoms and applies random rotation and translation.""" +) -> Vec3Array: # (bs, n_atoms) + """Centers the atoms and applies random rotation and translation. + Args: + atom_positions: + [*, n_atoms] vector of atom coordinates. + s_trans: + Scaling factor in Angstroms for the random translation sampled + from a normal distribution. + Returns: + [*, n_atoms] vector of atom coordinates after augmentation. + """ + batch_size, n_atoms = atom_positions.shape + device = atom_positions.x.device + # Center the atoms + center = atom_positions.mean(dim=-2, keepdim=True) + atom_positions = atom_positions - center # Sample random rotation + quaternions = torch.randn(batch_size, 4, device=device) + rots = Rot3Array.from_quaternion(w=quaternions[:, 0], + x=quaternions[:, 1], + y=quaternions[:, 2], + z=quaternions[:, 3], + normalize=True) # (bs) + rots = rots.unsqueeze(-1) # (bs, 1) + + # Sample random translation + trans = s_trans * Vec3Array.from_array(torch.randn((batch_size, 3), device=device)) + trans = trans.unsqueeze(-1) # (bs, 1) # Apply - pass + atom_positions = rots.apply_to_point(atom_positions) + trans + return atom_positions diff --git a/src/diffusion/conditioning.py b/src/diffusion/conditioning.py index 1ebfec4..78448b8 100644 --- a/src/diffusion/conditioning.py +++ b/src/diffusion/conditioning.py @@ -102,5 +102,3 @@ def encode(feature_tensor: torch.Tensor, torch.full_like(relative_dists, 2*clamp_max + 1) ) return F.one_hot(d_ij, num_classes=2 * clamp_max + 2) # (bs, n_tokens, n_tokens, 2 * clamp_max + 2) - - diff --git a/src/utils/geometry/vector.py b/src/utils/geometry/vector.py index 99bdefe..c19d199 100644 --- a/src/utils/geometry/vector.py +++ b/src/utils/geometry/vector.py @@ -120,11 +120,18 @@ def reshape(self, new_shape) -> Vec3Array: return Vec3Array(x, y, z) - def sum(self, dim: int) -> Vec3Array: + def sum(self, dim: int, keepdim=False) -> Vec3Array: return Vec3Array( - torch.sum(self.x, dim=dim), - torch.sum(self.y, dim=dim), - torch.sum(self.z, dim=dim), + torch.sum(self.x, dim=dim, keepdim=keepdim), + torch.sum(self.y, dim=dim, keepdim=keepdim), + torch.sum(self.z, dim=dim, keepdim=keepdim), + ) + + def mean(self, dim: int, keepdim=False) -> Vec3Array: + return Vec3Array( + torch.mean(self.x, dim=dim, keepdim=keepdim), + torch.mean(self.y, dim=dim, keepdim=keepdim), + torch.mean(self.z, dim=dim, keepdim=keepdim), ) def unsqueeze(self, dim: int): diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py new file mode 100644 index 0000000..5c4dbce --- /dev/null +++ b/tests/test_augmentation.py @@ -0,0 +1,33 @@ +import unittest +import torch +from src.utils.geometry.vector import Vec3Array +from src.utils.geometry.rotation_matrix import Rot3Array +from src.diffusion.augmentation import centre_random_augmentation + + +class TestCentreRandomAugmentation(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.n_atoms = 5 + atom_positions = torch.randn(self.batch_size, self.n_atoms, 3) + self.atom_positions = Vec3Array(x=atom_positions[:, :, 0], + y=atom_positions[:, :, 1], + z=atom_positions[:, :, 2]) + + def test_output_shape(self): + augmented_positions = centre_random_augmentation(self.atom_positions) + self.assertEqual(augmented_positions.to_tensor().shape, (self.batch_size, self.n_atoms, 3)) + + def test_random_translation(self): + s_trans = 1.0 + initial_positions = self.atom_positions.x.clone() + augmented_positions = centre_random_augmentation(self.atom_positions, s_trans=s_trans) + translation = augmented_positions.x - initial_positions + translation_magnitudes = translation.norm(dim=-1) + self.assertTrue(torch.all(translation_magnitudes > 0)) + + def test_random_rotation(self): + initial_positions = self.atom_positions.x.clone() + augmented_positions = centre_random_augmentation(self.atom_positions) + rotation_diff = augmented_positions.x - initial_positions + self.assertFalse(torch.allclose(rotation_diff, torch.zeros_like(rotation_diff), atol=1e-6))