Skip to content

Commit

Permalink
stabilize alignment, cleanup MSA module
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Aug 19, 2024
1 parent 70b4854 commit 5ef8659
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
6 changes: 2 additions & 4 deletions src/models/msa_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from src.models.components.outer_product_mean import OuterProductMean
from src.models.components.transition import Transition
from src.models.components.dropout import DropoutRowwise
from src.utils.tensor_utils import add
from src.utils.tensor_utils import add, flatten_final_dims
from functools import partial
from typing import Dict
from src.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
Expand Down Expand Up @@ -121,9 +121,7 @@ def forward(
o = g * torch.sum(v * weights, dim=-3) # (*, seq, res, heads, c_hidden)

# Output projection
output = self.output_proj(
o.reshape((m.shape[:-1] + (self.c_hidden * self.no_heads,))) # (*, seq, res, c_hidden * heads)
)
output = self.output_proj(flatten_final_dims(o, 2)) # (*, seq, res, c_hidden * heads)
return output


Expand Down
7 changes: 4 additions & 3 deletions src/utils/geometry/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def weighted_rigid_align(
x: Vec3Array, # (bs, n_atoms)
x_gt: Vec3Array, # (bs, n_atoms)
weights: torch.Tensor, # (bs, n_atoms)
mask: torch.Tensor = None # (bs, n_atoms)
mask: torch.Tensor = None, # (bs, n_atoms)
eps: float = 1e-5
) -> Vec3Array:
"""Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure
not being moved, not to be confused with ground truth during training."""
Expand All @@ -32,8 +33,8 @@ def weighted_rigid_align(
torch.set_float32_matmul_precision("highest")

# Mean-centre positions
mu = (x * weights).mean(dim=-1, keepdim=True) / weights.mean(dim=-1, keepdim=True)
mu_gt = (x_gt * weights).mean(dim=-1, keepdim=True) / weights.mean(dim=-1, keepdim=True)
mu = (x * weights).sum(dim=-1, keepdim=True) / (weights.sum(dim=-1, keepdim=True) + eps)
mu_gt = (x_gt * weights).sum(dim=-1, keepdim=True) / (weights.sum(dim=-1, keepdim=True) + eps)
x -= mu # Vec3Array of shape (*, n_atoms)
x_gt -= mu_gt

Expand Down

0 comments on commit 5ef8659

Please sign in to comment.