Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed May 30, 2024
1 parent 5afe517 commit 1dbf1e9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
15 changes: 4 additions & 11 deletions src/diffusion/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,14 @@ def centre_random_augmentation(
device = atom_positions.x.device

# Center the atoms
center = atom_positions.mean(dim=-1, keepdim=True)
center = atom_positions.mean(dim=1, keepdim=True)
atom_positions = atom_positions - center

# Sample random rotation
quaternions = torch.randn(batch_size, 4, device=device)
rots = Rot3Array.from_quaternion(w=quaternions[:, 0],
x=quaternions[:, 1],
y=quaternions[:, 2],
z=quaternions[:, 3],
normalize=True) # (bs)
rots = rots.unsqueeze(-1) # (bs, 1)
rots = Rot3Array.uniform_random((batch_size, 1), device)

# Sample random translation
trans = s_trans * Vec3Array.from_array(torch.randn((batch_size, 3), device=device))
trans = trans.unsqueeze(-1) # (bs, 1)
# Sample random translation from normal distribution
trans = s_trans * Vec3Array.randn((batch_size, 1), device)

# Apply
atom_positions = rots.apply_to_point(atom_positions) + trans
Expand Down
10 changes: 10 additions & 0 deletions src/utils/geometry/rotation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ def from_array(cls, array: torch.Tensor) -> Rot3Array:
rc = [torch.unbind(e, dim=-1) for e in rows]
return cls(*[e for row in rc for e in row])

@classmethod
def uniform_random(cls, shape, device='cpu') -> Rot3Array:
"""Generates a random rotation of given shape."""
quaternions = torch.randn((*shape, 4), device=device)
return Rot3Array.from_quaternion(w=quaternions[..., 0],
x=quaternions[..., 1],
y=quaternions[..., 2],
z=quaternions[..., 3],
normalize=True)

def to_tensor(self) -> torch.Tensor:
"""Convert Rot3Array to array of shape [..., 3, 3]."""
return torch.stack(
Expand Down
8 changes: 8 additions & 0 deletions src/utils/geometry/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ def cat(cls, vecs: List[Vec3Array], dim: int) -> Vec3Array:
torch.cat([v.z for v in vecs], dim=dim),
)

@classmethod
def randn(cls, shape, device="cpu"):
return cls(
torch.randn(shape, dtype=torch.float32, device=device),
torch.randn(shape, dtype=torch.float32, device=device),
torch.randn(shape, dtype=torch.float32, device=device),
)


def square_euclidean_distance(
vec1: Vec3Array,
Expand Down

0 comments on commit 1dbf1e9

Please sign in to comment.