Skip to content

Commit

Permalink
Implemented SmoothLDDTLoss
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Jun 2, 2024
1 parent 06145dd commit c040980
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 4 deletions.
31 changes: 28 additions & 3 deletions src/diffusion/loss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Diffusion losses"""

import torch
from src.utils.geometry.vector import Vec3Array, square_euclidean_distance
from torch import nn
from src.utils.geometry.vector import Vec3Array, square_euclidean_distance, euclidean_distance
from torch.nn import functional as F


def smooth_lddt_loss(
        pred_atoms: "Vec3Array",  # (bs, n_atoms)
        gt_atoms: "Vec3Array",  # (bs, n_atoms)
        atom_is_rna: torch.Tensor,  # (bs, n_atoms)
        atom_is_dna: torch.Tensor,  # (bs, n_atoms)
        mask: torch.Tensor = None  # (bs, n_atoms)
) -> torch.Tensor:  # (bs,)
    """Smooth local distance difference test (LDDT) auxiliary loss.

    For every ordered pair of distinct atoms (l, m) the absolute difference
    between the predicted and the ground-truth inter-atom distance is scored
    with the average of four sigmoids centred at 0.5, 1, 2 and 4. Scores are
    averaged over the pairs whose ground-truth distance lies inside an
    inclusion radius (30 if atom l is flagged as DNA/RNA, 15 otherwise) and
    the loss is one minus that average.

    NOTE(review): presumably this mirrors the SmoothLDDTLoss algorithm of the
    AlphaFold 3 supplement, in which the inclusion radius is chosen by the
    first atom of the pair -- confirm against the paper.

    Args:
        pred_atoms: predicted atom positions.
        gt_atoms: ground-truth atom positions.
        atom_is_rna: 0/1 flags marking RNA atoms.
        atom_is_dna: 0/1 flags marking DNA atoms.
        mask: optional 0/1 flags; a pair contributes only if both of its
            atoms are unmasked.

    Returns:
        Per-sample loss ``1 - lddt``. Because the sigmoids never reach 1, a
        perfect prediction yields roughly 0.196 rather than 0. A sample with
        no included pair returns 1 (with zero gradient) instead of NaN.
    """
    # Pairwise distances between all atoms of each structure.
    delta_x_lm = euclidean_distance(pred_atoms[:, :, None], pred_atoms[:, None, :])  # (bs, n_atoms, n_atoms)
    delta_x_gt_lm = euclidean_distance(gt_atoms[:, :, None], gt_atoms[:, None, :])  # (bs, n_atoms, n_atoms)

    # Smooth score of the distance error for every pair.
    delta_lm = torch.abs(delta_x_gt_lm - delta_x_lm)  # (bs, n_atoms, n_atoms)
    epsilon_lm = sum(torch.sigmoid(threshold - delta_lm) for threshold in (0.5, 1.0, 2.0, 4.0)) / 4.0

    # Inclusion radius depends on the type of atom l, so the per-atom flag is
    # given a trailing axis and broadcast over m. Multiplying the raw
    # (bs, n_atoms) flag would align it with the last axis instead and fail
    # for batch sizes other than 1 or n_atoms.
    atom_is_nucleotide = (atom_is_dna + atom_is_rna)[:, :, None]  # (bs, n_atoms, 1)
    atom_not_nucleotide = 1.0 - atom_is_nucleotide
    c_lm = ((delta_x_gt_lm < 30.0).float() * atom_is_nucleotide
            + (delta_x_gt_lm < 15.0).float() * atom_not_nucleotide)  # (bs, n_atoms, n_atoms)

    # Keep only pairs in which both atoms are unmasked.
    if mask is not None:
        c_lm = c_lm * (mask[:, :, None] * mask[:, None, :])

    # Drop the self-term l == m; the (n_atoms, n_atoms) matrix broadcasts over the batch.
    n_atoms = c_lm.shape[-1]
    not_self = 1.0 - torch.eye(n_atoms, dtype=c_lm.dtype, device=c_lm.device)
    c_lm = c_lm * not_self

    # Ratio of sums equals the ratio of means; the clamp guards against 0 / 0
    # when no pair is included (e.g. fully masked sample or a single atom).
    included_score = torch.sum(epsilon_lm * c_lm, dim=(1, 2))
    included_count = torch.clamp(torch.sum(c_lm, dim=(1, 2)), min=1e-10)
    lddt = included_score / included_count
    return 1.0 - lddt


def mean_squared_error(
Expand Down
47 changes: 46 additions & 1 deletion tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import torch
from src.diffusion.loss import mean_squared_error
from src.diffusion.loss import mean_squared_error, smooth_lddt_loss
from src.utils.geometry.vector import Vec3Array


Expand Down Expand Up @@ -44,5 +44,50 @@ def test_mse_with_mask(self):
self.assertTrue(torch.allclose(mse, expected_mse), f"Expected MSE {expected_mse}, got {mse}")


class TestSmoothLDDTLoss(unittest.TestCase):
    """Sanity checks for ``smooth_lddt_loss``.

    The perturbations are deterministic (the chain of atoms is stretched
    uniformly) so that every pairwise distance error grows monotonically and
    the expected ordering of the losses is guaranteed rather than merely
    likely, as it would be with independent random noise.
    """

    @staticmethod
    def _atoms_on_x_axis(spacing: float = 1.0) -> torch.Tensor:
        """Return (1, 4, 3) coordinates of four atoms on the x-axis, `spacing` apart."""
        coords = torch.zeros((1, 4, 3))
        coords[0, :, 0] = spacing * torch.arange(4, dtype=torch.float32)
        return coords

    def test_basic_functionality(self):
        # Identical prediction and ground truth, no nucleotide atoms, no mask.
        bs, n_atoms = 1, 4
        pred_atoms = Vec3Array.from_array(self._atoms_on_x_axis())
        gt_atoms = Vec3Array.from_array(self._atoms_on_x_axis())
        atom_is_rna = torch.zeros((bs, n_atoms))
        atom_is_dna = torch.zeros((bs, n_atoms))
        mask = None

        # This is the minimum of the loss. It is about 0.196, not 0, because
        # the sigmoids that smooth the LDDT thresholds never reach 1.
        loss_when_identical = smooth_lddt_loss(pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)

        # Stretch the chain: distance errors are 0.1 * |l - m| and 1.0 * |l - m| respectively.
        slightly_off_atoms = Vec3Array.from_array(self._atoms_on_x_axis(spacing=1.1))
        far_off_atoms = Vec3Array.from_array(self._atoms_on_x_axis(spacing=2.0))

        loss_when_slightly_off = smooth_lddt_loss(slightly_off_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)
        loss_when_far_off = smooth_lddt_loss(far_off_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)

        self.assertTrue(torch.all((loss_when_identical < loss_when_slightly_off)
                                  & (loss_when_slightly_off < loss_when_far_off)))

    def test_mask(self):
        # Same geometry, with some atoms flagged as RNA and the last atom masked out.
        bs, n_atoms = 1, 4
        pred_atoms = Vec3Array.from_array(self._atoms_on_x_axis())
        gt_atoms = Vec3Array.from_array(self._atoms_on_x_axis())
        atom_is_rna = torch.tensor([[1., 0., 1., 0.]])
        atom_is_dna = torch.zeros((bs, n_atoms))
        mask = torch.tensor([[1., 1., 1., 0.]])

        loss_when_identical = smooth_lddt_loss(pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)

        # The ordering must still hold on the unmasked atoms.
        slightly_off_atoms = Vec3Array.from_array(self._atoms_on_x_axis(spacing=1.1))
        far_off_atoms = Vec3Array.from_array(self._atoms_on_x_axis(spacing=2.0))

        loss_when_slightly_off = smooth_lddt_loss(slightly_off_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)
        loss_when_far_off = smooth_lddt_loss(far_off_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)

        self.assertTrue(torch.all((loss_when_identical < loss_when_slightly_off)
                                  & (loss_when_slightly_off < loss_when_far_off)))

        # Moving only the masked-out atom must leave the loss unchanged.
        displaced_coords = self._atoms_on_x_axis()
        displaced_coords[0, 3] = torch.tensor([50., -20., 7.])
        displaced_atoms = Vec3Array.from_array(displaced_coords)
        loss_when_masked_atom_moves = smooth_lddt_loss(displaced_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)

        self.assertTrue(torch.allclose(loss_when_masked_atom_moves, loss_when_identical),
                        f"Expected loss {loss_when_identical}, got {loss_when_masked_atom_moves}")


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit c040980

Please sign in to comment.