Skip to content

Commit b7e4b83

Browse files
author
ardagoreci
committed
Added weighted_rigid_align to diffusion loss
1 parent a76b820 commit b7e4b83

File tree

2 files changed

+36
-30
lines changed

2 files changed

+36
-30
lines changed

src/diffusion/loss.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from src.utils.geometry.vector import Vec3Array, square_euclidean_distance, euclidean_distance
5+
from src.utils.geometry.alignment import weighted_rigid_align
56
from torch.nn import functional as F
67
from typing import Optional
78

@@ -23,7 +24,8 @@ def smooth_lddt_loss(
2324
# Compute distance difference for all pairs of atoms
2425
delta_lm = torch.abs(delta_x_gt_lm - delta_x_lm) # (bs, n_atoms, n_atoms)
2526
epsilon_lm = torch.div((F.sigmoid(torch.sub(0.5, delta_lm)) + F.sigmoid(torch.sub(1.0, delta_lm)) +
26-
F.sigmoid(torch.sub(2.0, delta_lm)) + F.sigmoid(torch.sub(4.0, delta_lm))), 4.0)
27+
F.sigmoid(torch.sub(2.0, delta_lm)) + F.sigmoid(torch.sub(4.0, delta_lm))),
28+
4.0)
2729

2830
# Restrict to bespoke inclusion radius
2931
atom_is_nucleotide = (atom_is_dna + atom_is_rna).unsqueeze(-1).expand_as(delta_x_gt_lm)
@@ -73,8 +75,12 @@ def diffusion_loss(
7375
sd_data: float = 16.0, # Standard deviation of the data
7476
) -> torch.Tensor: # (bs,)
7577
"""Diffusion loss that scales the MSE and LDDT losses by the noise level (timestep)."""
76-
mse = mean_squared_error(pred_atoms, gt_atoms, weights, mask)
77-
lddt_loss = smooth_lddt_loss(pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)
78+
# Align the gt_atoms to pred_atoms
79+
aligned_gt_atoms = weighted_rigid_align(x=gt_atoms, x_gt=pred_atoms, weights=weights, mask=mask)
80+
81+
# MSE loss
82+
mse = mean_squared_error(pred_atoms, aligned_gt_atoms, weights, mask)
83+
lddt_loss = smooth_lddt_loss(pred_atoms, aligned_gt_atoms, atom_is_rna, atom_is_dna, mask)
7884

7985
# Scale by (t**2 + σ**2) / (t + σ)**2
8086
scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.add(timesteps, sd_data) ** 2)

src/utils/geometry/alignment.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,30 +26,30 @@ def weighted_rigid_align(
2626
) -> Vec3Array:
2727
"""Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure
2828
not being moved, not to be confused with ground truth during training."""
29-
30-
# Mean-centre positions
31-
mu = (x * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True)
32-
mu_gt = (x_gt * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True)
33-
x -= mu # Vec3Array of shape (bs, n_atoms)
34-
x_gt -= mu_gt
35-
36-
# Mask atoms before computing covariance matrix
37-
if mask is not None:
38-
x *= mask
39-
x_gt *= mask
40-
41-
# Find optimal rotation from singular value decomposition
42-
U, S, Vh = torch.linalg.svd(compute_covariance_matrix(x_gt.to_tensor(), x.to_tensor())) # shapes: (bs, 3, 3)
43-
R = U @ Vh
44-
45-
# Remove reflection
46-
if torch.linalg.det(R) < 0:
47-
reflection_matrix = torch.diag((torch.tensor([1, 1, -1], device=U.device, dtype=U.dtype)))
48-
reflection_matrix = reflection_matrix.unsqueeze(0).expand_as(R)
49-
R = U @ reflection_matrix @ Vh # (bs, 3, 3)
50-
51-
R = Rot3Array.from_array(R)
52-
53-
# Apply alignment
54-
x_aligned = R.apply_to_point(x) + mu
55-
return x_aligned
29+
with torch.no_grad():
30+
# Mean-centre positions
31+
mu = (x * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True)
32+
mu_gt = (x_gt * weights).mean(dim=1, keepdim=True) / weights.mean(dim=1, keepdim=True)
33+
x -= mu # Vec3Array of shape (bs, n_atoms)
34+
x_gt -= mu_gt
35+
36+
# Mask atoms before computing covariance matrix
37+
if mask is not None:
38+
x *= mask
39+
x_gt *= mask
40+
41+
# Find optimal rotation from singular value decomposition
42+
U, S, Vh = torch.linalg.svd(compute_covariance_matrix(x_gt.to_tensor(), x.to_tensor())) # shapes: (bs, 3, 3)
43+
R = U @ Vh
44+
45+
# Remove reflection
46+
if torch.linalg.det(R) < 0:
47+
reflection_matrix = torch.diag((torch.tensor([1, 1, -1], device=U.device, dtype=U.dtype)))
48+
reflection_matrix = reflection_matrix.unsqueeze(0).expand_as(R)
49+
R = U @ reflection_matrix @ Vh # (bs, 3, 3)
50+
51+
R = Rot3Array.from_array(R)
52+
53+
# Apply alignment
54+
x_aligned = R.apply_to_point(x) + mu
55+
return x_aligned

0 commit comments

Comments
 (0)