diff --git a/src/diffusion/loss.py b/src/diffusion/loss.py index d48db47..02fa261 100644 --- a/src/diffusion/loss.py +++ b/src/diffusion/loss.py @@ -1,8 +1,8 @@ """Diffusion losses""" import torch -from src.utils.geometry.vector import Vec3Array, square_euclidean_distance -from torch import nn +from src.utils.geometry.vector import Vec3Array, square_euclidean_distance, euclidean_distance +from torch.nn import functional as F def smooth_lddt_loss( @@ -13,7 +13,32 @@ def smooth_lddt_loss( mask: torch.Tensor = None # (bs, n_atoms) ) -> torch.Tensor: # (bs,) """Smooth local distance difference test (LDDT) auxiliary loss.""" - raise NotImplementedError + bs, n_atoms = pred_atoms.shape + + # Compute distances between all pairs of atoms + delta_x_lm = euclidean_distance(pred_atoms[:, :, None], pred_atoms[:, None, :]) # (bs, n_atoms, n_atoms) + delta_x_gt_lm = euclidean_distance(gt_atoms[:, :, None], gt_atoms[:, None, :]) + + # Compute distance difference for all pairs of atoms + delta_lm = torch.abs(delta_x_gt_lm - delta_x_lm) # (bs, n_atoms, n_atoms) + epsilon_lm = torch.div((F.sigmoid(torch.sub(0.5, delta_lm)) + F.sigmoid(torch.sub(1.0, delta_lm)) + + F.sigmoid(torch.sub(2.0, delta_lm)) + F.sigmoid(torch.sub(4.0, delta_lm))), 4.0) + + # Restrict to bespoke inclusion radius + atom_is_nucleotide = atom_is_dna + atom_is_rna + atom_not_nucleotide = torch.add(torch.neg(atom_is_nucleotide), 1.0) # (1 - atom_is_nucleotide) + c_lm = (delta_x_gt_lm < 30.0).float() * atom_is_nucleotide + (delta_x_gt_lm < 15.0).float() * atom_not_nucleotide + + # Mask positions + if mask is not None: + c_lm *= (mask[:, :, None] * mask[:, None, :]) + + # Compute mean, avoiding self-term + self_mask = torch.eye(n_atoms).unsqueeze(0).expand_as(c_lm).to(c_lm.device) # (bs, n_atoms, n_atoms) + self_mask = torch.add(torch.neg(self_mask), 1.0) + c_lm *= self_mask + lddt = torch.mean(epsilon_lm * c_lm, dim=(1, 2)) / torch.mean(c_lm, dim=(1, 2)) + return torch.add(torch.neg(lddt), 1.0) # (1 - lddt) def mean_squared_error( diff --git a/tests/test_loss.py b/tests/test_loss.py index 29bb43e..d124156 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -1,6 +1,6 @@ import unittest import torch -from src.diffusion.loss import mean_squared_error +from src.diffusion.loss import mean_squared_error, smooth_lddt_loss from src.utils.geometry.vector import Vec3Array @@ -44,5 +44,50 @@ def test_mse_with_mask(self): self.assertTrue(torch.allclose(mse, expected_mse), f"Expected MSE {expected_mse}, got {mse}") +class TestSmoothLDDTLoss(unittest.TestCase): + + def test_basic_functionality(self): + # Create a simple scenario where pred_atoms and gt_atoms are the same, no nucleotide-specific settings. + bs, n_atoms = 1, 4 + pred_atoms = Vec3Array.from_array(torch.tensor([[[0., 0., 0.], [1., 0., 0.], [2., 0., 0.], [3., 0., 0.]]])) + gt_atoms = Vec3Array.from_array(torch.tensor([[[0., 0., 0.], [1., 0., 0.], [2., 0., 0.], [3., 0., 0.]]])) + atom_is_rna = torch.zeros((bs, n_atoms)) + atom_is_dna = torch.zeros((bs, n_atoms)) + mask = None + + # Expected loss is 0 since pred_atoms == gt_atoms + loss_when_identical = smooth_lddt_loss(pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask) + + # Add noise and compare + noisy_pred_atoms = pred_atoms + 0.1 * Vec3Array.from_array(torch.randn((bs, n_atoms, 3))) + noisier_pred_atoms = pred_atoms + 1.0 * Vec3Array.from_array(torch.randn((bs, n_atoms, 3))) + + loss_when_noisy = smooth_lddt_loss(noisy_pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask) + loss_when_noisier = smooth_lddt_loss(noisier_pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask) + + self.assertTrue(torch.all((loss_when_identical < loss_when_noisy) & (loss_when_noisy < loss_when_noisier))) + + def test_mask(self): + # Create a simple scenario where pred_atoms and gt_atoms are the same, no nucleotide-specific settings. + bs, n_atoms = 1, 4 + pred_atoms = Vec3Array.from_array(torch.tensor([[[0., 0., 0.], [1., 0., 0.], [2., 0., 0.], [3., 0., 0.]]])) + gt_atoms = Vec3Array.from_array(torch.tensor([[[0., 0., 0.], [1., 0., 0.], [2., 0., 0.], [3., 0., 0.]]])) + atom_is_rna = torch.randint(0, 2, (bs, n_atoms)) + atom_is_dna = torch.zeros((bs, n_atoms)) + mask = torch.ones((bs, n_atoms)) + + # Expected loss is 0 since pred_atoms == gt_atoms + loss_when_identical = smooth_lddt_loss(pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask) + + # Add noise and compare + noisy_pred_atoms = pred_atoms + 5.0 * Vec3Array.from_array(torch.randn((bs, n_atoms, 3))) + noisier_pred_atoms = pred_atoms + 10.0 * Vec3Array.from_array(torch.randn((bs, n_atoms, 3))) + + loss_when_noisy = smooth_lddt_loss(noisy_pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask) + loss_when_noisier = smooth_lddt_loss(noisier_pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask) + + self.assertTrue(torch.all((loss_when_identical < loss_when_noisy) & (loss_when_noisy < loss_when_noisier))) + + if __name__ == '__main__': unittest.main()