diff --git a/src/models/msa_module.py b/src/models/msa_module.py index 50ece8b..a847eb2 100644 --- a/src/models/msa_module.py +++ b/src/models/msa_module.py @@ -26,7 +26,7 @@ from src.models.components.outer_product_mean import OuterProductMean from src.models.components.transition import Transition from src.models.components.dropout import DropoutRowwise -from src.utils.tensor_utils import add +from src.utils.tensor_utils import add, flatten_final_dims from functools import partial from typing import Dict from src.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn @@ -121,9 +121,7 @@ def forward( o = g * torch.sum(v * weights, dim=-3) # (*, seq, res, heads, c_hidden) # Output projection - output = self.output_proj( - o.reshape((m.shape[:-1] + (self.c_hidden * self.no_heads,))) # (*, seq, res, c_hidden * heads) - ) + output = self.output_proj(flatten_final_dims(o, 2)) # (*, seq, res, c_hidden * heads) return output diff --git a/src/utils/geometry/alignment.py b/src/utils/geometry/alignment.py index aa2d731..6e29d82 100644 --- a/src/utils/geometry/alignment.py +++ b/src/utils/geometry/alignment.py @@ -22,7 +22,8 @@ def weighted_rigid_align( x: Vec3Array, # (bs, n_atoms) x_gt: Vec3Array, # (bs, n_atoms) weights: torch.Tensor, # (bs, n_atoms) - mask: torch.Tensor = None # (bs, n_atoms) + mask: torch.Tensor = None, # (bs, n_atoms) + eps: float = 1e-5 ) -> Vec3Array: """Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure not being moved, not to be confused with ground truth during training.""" @@ -32,8 +33,8 @@ def weighted_rigid_align( torch.set_float32_matmul_precision("highest") # Mean-centre positions - mu = (x * weights).mean(dim=-1, keepdim=True) / weights.mean(dim=-1, keepdim=True) - mu_gt = (x_gt * weights).mean(dim=-1, keepdim=True) / weights.mean(dim=-1, keepdim=True) + mu = (x * weights).sum(dim=-1, keepdim=True) / (weights.sum(dim=-1, keepdim=True) + eps) + mu_gt = (x_gt * weights).sum(dim=-1, keepdim=True) / (weights.sum(dim=-1, keepdim=True) + eps) x -= mu # Vec3Array of shape (*, n_atoms) x_gt -= mu_gt