Skip to content

Commit

Permalink
Implemented CentreRandomAugmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed May 30, 2024
1 parent bcce958 commit c18aff0
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 11 deletions.
37 changes: 32 additions & 5 deletions src/diffusion/augmentation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,43 @@
"""Data augmentations applied prior to sampling from the diffusion trajectory."""

import torch
from src.utils.geometry.vector import Vec3Array
from src.utils.geometry.rotation_matrix import Rot3Array


def centre_random_augmentation(
atom_positions: torch.Tensor, # (bs, n_atoms, 3)
atom_positions: Vec3Array, # (bs, n_atoms)
s_trans: float = 1.0, # Translation scaling factor
) -> torch.Tensor:
"""Centers the atoms and applies random rotation and translation."""
) -> Vec3Array: # (bs, n_atoms)
"""Centers the atoms and applies random rotation and translation.
Args:
atom_positions:
[*, n_atoms] vector of atom coordinates.
s_trans:
Scaling factor in Angstroms for the random translation sampled
from a normal distribution.
Returns:
[*, n_atoms] vector of atom coordinates after augmentation.
"""
batch_size, n_atoms = atom_positions.shape
device = atom_positions.x.device

# Center the atoms
center = atom_positions.mean(dim=-2, keepdim=True)
atom_positions = atom_positions - center

# Sample random rotation
quaternions = torch.randn(batch_size, 4, device=device)
rots = Rot3Array.from_quaternion(w=quaternions[:, 0],
x=quaternions[:, 1],
y=quaternions[:, 2],
z=quaternions[:, 3],
normalize=True) # (bs)
rots = rots.unsqueeze(-1) # (bs, 1)

# Sample random translation
trans = s_trans * Vec3Array.from_array(torch.randn((batch_size, 3), device=device))
trans = trans.unsqueeze(-1) # (bs, 1)

# Apply
pass
atom_positions = rots.apply_to_point(atom_positions) + trans
return atom_positions
2 changes: 0 additions & 2 deletions src/diffusion/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,3 @@ def encode(feature_tensor: torch.Tensor,
torch.full_like(relative_dists, 2*clamp_max + 1)
)
return F.one_hot(d_ij, num_classes=2 * clamp_max + 2) # (bs, n_tokens, n_tokens, 2 * clamp_max + 2)


15 changes: 11 additions & 4 deletions src/utils/geometry/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,18 @@ def reshape(self, new_shape) -> Vec3Array:

return Vec3Array(x, y, z)

def sum(self, dim: int) -> Vec3Array:
def sum(self, dim: int, keepdim=False) -> Vec3Array:
return Vec3Array(
torch.sum(self.x, dim=dim),
torch.sum(self.y, dim=dim),
torch.sum(self.z, dim=dim),
torch.sum(self.x, dim=dim, keepdim=keepdim),
torch.sum(self.y, dim=dim, keepdim=keepdim),
torch.sum(self.z, dim=dim, keepdim=keepdim),
)

def mean(self, dim: int, keepdim=False) -> Vec3Array:
return Vec3Array(
torch.mean(self.x, dim=dim, keepdim=keepdim),
torch.mean(self.y, dim=dim, keepdim=keepdim),
torch.mean(self.z, dim=dim, keepdim=keepdim),
)

def unsqueeze(self, dim: int):
Expand Down
33 changes: 33 additions & 0 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest
import torch
from src.utils.geometry.vector import Vec3Array
from src.utils.geometry.rotation_matrix import Rot3Array
from src.diffusion.augmentation import centre_random_augmentation


class TestCentreRandomAugmentation(unittest.TestCase):
def setUp(self):
self.batch_size = 2
self.n_atoms = 5
atom_positions = torch.randn(self.batch_size, self.n_atoms, 3)
self.atom_positions = Vec3Array(x=atom_positions[:, :, 0],
y=atom_positions[:, :, 1],
z=atom_positions[:, :, 2])

def test_output_shape(self):
augmented_positions = centre_random_augmentation(self.atom_positions)
self.assertEqual(augmented_positions.to_tensor().shape, (self.batch_size, self.n_atoms, 3))

def test_random_translation(self):
s_trans = 1.0
initial_positions = self.atom_positions.x.clone()
augmented_positions = centre_random_augmentation(self.atom_positions, s_trans=s_trans)
translation = augmented_positions.x - initial_positions
translation_magnitudes = translation.norm(dim=-1)
self.assertTrue(torch.all(translation_magnitudes > 0))

def test_random_rotation(self):
initial_positions = self.atom_positions.x.clone()
augmented_positions = centre_random_augmentation(self.atom_positions)
rotation_diff = augmented_positions.x - initial_positions
self.assertFalse(torch.allclose(rotation_diff, torch.zeros_like(rotation_diff), atol=1e-6))

0 comments on commit c18aff0

Please sign in to comment.