generated from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented CentreRandomAugmentation
- Loading branch information
ardagoreci
committed
May 30, 2024
1 parent
bcce958
commit c18aff0
Showing
4 changed files
with
76 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,43 @@ | ||
"""Data augmentations applied prior to sampling from the diffusion trajectory.""" | ||
|
||
import torch | ||
from src.utils.geometry.vector import Vec3Array | ||
from src.utils.geometry.rotation_matrix import Rot3Array | ||
|
||
|
||
def centre_random_augmentation( | ||
atom_positions: torch.Tensor, # (bs, n_atoms, 3) | ||
atom_positions: Vec3Array, # (bs, n_atoms) | ||
s_trans: float = 1.0, # Translation scaling factor | ||
) -> torch.Tensor: | ||
"""Centers the atoms and applies random rotation and translation.""" | ||
) -> Vec3Array: # (bs, n_atoms) | ||
"""Centers the atoms and applies random rotation and translation. | ||
Args: | ||
atom_positions: | ||
[*, n_atoms] vector of atom coordinates. | ||
s_trans: | ||
Scaling factor in Angstroms for the random translation sampled | ||
from a normal distribution. | ||
Returns: | ||
[*, n_atoms] vector of atom coordinates after augmentation. | ||
""" | ||
batch_size, n_atoms = atom_positions.shape | ||
device = atom_positions.x.device | ||
|
||
# Center the atoms | ||
center = atom_positions.mean(dim=-2, keepdim=True) | ||
atom_positions = atom_positions - center | ||
|
||
# Sample random rotation | ||
quaternions = torch.randn(batch_size, 4, device=device) | ||
rots = Rot3Array.from_quaternion(w=quaternions[:, 0], | ||
x=quaternions[:, 1], | ||
y=quaternions[:, 2], | ||
z=quaternions[:, 3], | ||
normalize=True) # (bs) | ||
rots = rots.unsqueeze(-1) # (bs, 1) | ||
|
||
# Sample random translation | ||
trans = s_trans * Vec3Array.from_array(torch.randn((batch_size, 3), device=device)) | ||
trans = trans.unsqueeze(-1) # (bs, 1) | ||
|
||
# Apply | ||
pass | ||
atom_positions = rots.apply_to_point(atom_positions) + trans | ||
return atom_positions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import unittest | ||
import torch | ||
from src.utils.geometry.vector import Vec3Array | ||
from src.utils.geometry.rotation_matrix import Rot3Array | ||
from src.diffusion.augmentation import centre_random_augmentation | ||
|
||
|
||
class TestCentreRandomAugmentation(unittest.TestCase): | ||
def setUp(self): | ||
self.batch_size = 2 | ||
self.n_atoms = 5 | ||
atom_positions = torch.randn(self.batch_size, self.n_atoms, 3) | ||
self.atom_positions = Vec3Array(x=atom_positions[:, :, 0], | ||
y=atom_positions[:, :, 1], | ||
z=atom_positions[:, :, 2]) | ||
|
||
def test_output_shape(self): | ||
augmented_positions = centre_random_augmentation(self.atom_positions) | ||
self.assertEqual(augmented_positions.to_tensor().shape, (self.batch_size, self.n_atoms, 3)) | ||
|
||
def test_random_translation(self): | ||
s_trans = 1.0 | ||
initial_positions = self.atom_positions.x.clone() | ||
augmented_positions = centre_random_augmentation(self.atom_positions, s_trans=s_trans) | ||
translation = augmented_positions.x - initial_positions | ||
translation_magnitudes = translation.norm(dim=-1) | ||
self.assertTrue(torch.all(translation_magnitudes > 0)) | ||
|
||
def test_random_rotation(self): | ||
initial_positions = self.atom_positions.x.clone() | ||
augmented_positions = centre_random_augmentation(self.atom_positions) | ||
rotation_diff = augmented_positions.x - initial_positions | ||
self.assertFalse(torch.allclose(rotation_diff, torch.zeros_like(rotation_diff), atol=1e-6)) |