diff --git a/src/diffusion/augmentation.py b/src/diffusion/augmentation.py index 69f0a31..1429fc9 100644 --- a/src/diffusion/augmentation.py +++ b/src/diffusion/augmentation.py @@ -22,21 +22,14 @@ def centre_random_augmentation( device = atom_positions.x.device # Center the atoms - center = atom_positions.mean(dim=-1, keepdim=True) + center = atom_positions.mean(dim=1, keepdim=True) atom_positions = atom_positions - center # Sample random rotation - quaternions = torch.randn(batch_size, 4, device=device) - rots = Rot3Array.from_quaternion(w=quaternions[:, 0], - x=quaternions[:, 1], - y=quaternions[:, 2], - z=quaternions[:, 3], - normalize=True) # (bs) - rots = rots.unsqueeze(-1) # (bs, 1) + rots = Rot3Array.uniform_random((batch_size, 1), device) - # Sample random translation - trans = s_trans * Vec3Array.from_array(torch.randn((batch_size, 3), device=device)) - trans = trans.unsqueeze(-1) # (bs, 1) + # Sample random translation from normal distribution + trans = s_trans * Vec3Array.randn((batch_size, 1), device) # Apply atom_positions = rots.apply_to_point(atom_positions) + trans diff --git a/src/utils/geometry/rotation_matrix.py b/src/utils/geometry/rotation_matrix.py index cb5e06a..61e1488 100644 --- a/src/utils/geometry/rotation_matrix.py +++ b/src/utils/geometry/rotation_matrix.py @@ -147,6 +147,16 @@ def from_array(cls, array: torch.Tensor) -> Rot3Array: rc = [torch.unbind(e, dim=-1) for e in rows] return cls(*[e for row in rc for e in row]) + @classmethod + def uniform_random(cls, shape, device='cpu') -> Rot3Array: + """Generates a random rotation of given shape.""" + quaternions = torch.randn((*shape, 4), device=device) + return Rot3Array.from_quaternion(w=quaternions[..., 0], + x=quaternions[..., 1], + y=quaternions[..., 2], + z=quaternions[..., 3], + normalize=True) + def to_tensor(self) -> torch.Tensor: """Convert Rot3Array to array of shape [..., 3, 3].""" return torch.stack( diff --git a/src/utils/geometry/vector.py b/src/utils/geometry/vector.py index c19d199..146e5df 100644 --- a/src/utils/geometry/vector.py +++ b/src/utils/geometry/vector.py @@ -166,6 +166,14 @@ def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array: torch.cat([v.z for v in vecs], dim=dim), ) + @classmethod + def randn(cls, shape, device="cpu"): + return cls( + torch.randn(shape, dtype=torch.float32, device=device), + torch.randn(shape, dtype=torch.float32, device=device), + torch.randn(shape, dtype=torch.float32, device=device), + ) + def square_euclidean_distance( vec1: Vec3Array,