Commit 9d492a3: original alphafold3 loss settings

ardagoreci committed Aug 23, 2024
1 parent fcf03e2

Showing 9 changed files with 35 additions and 75 deletions.
15 changes: 4 additions & 11 deletions configs/model/alphafold3.yaml
```diff
@@ -25,18 +25,11 @@ scheduler:
 
 # Loss configs
 loss:
-  diffusion_loss:
-    # mse:
-    #   weight_dna: 5.0
-    #   weight_rna: 5.0
-    #   weight_ligand: 10.0
-    #   weight_protein: 1.0
-    # smooth_lddt:
-    #   weight: 1.0
-    sd_data: 16.0
-    weight: 4.0
+  mse_loss:
+    sd_data: 16.0
+    weight: 4.0
   smooth_lddt_loss:
-    weight: 1.0
+    weight: 4.0
     epsilon: 1e-5
 
   distogram:
```
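The split into `mse_loss` and `smooth_lddt_loss` gives each diffusion-loss term its own weight in the config. A minimal sketch of how a training loop is assumed to consume these keys (OpenFold-style cumulative loss; the helper below is hypothetical, not code from this commit):

```python
# Hypothetical sketch: combine per-term losses using the weights from
# configs/model/alphafold3.yaml. Not part of this commit.
loss_config = {
    "mse_loss": {"sd_data": 16.0, "weight": 4.0},
    "smooth_lddt_loss": {"weight": 4.0, "epsilon": 1e-5},
}

def cumulative_loss(term_values: dict) -> float:
    # term_values maps a loss name (e.g. "mse_loss") to its unweighted scalar value
    return sum(loss_config[name]["weight"] * value
               for name, value in term_values.items())

# e.g. cumulative_loss({"mse_loss": 0.8, "smooth_lddt_loss": 0.3}) -> 4.0*0.8 + 4.0*0.3 = 4.4
```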
17 changes: 6 additions & 11 deletions configs/model/mini-af3.yaml
```diff
@@ -25,17 +25,12 @@ scheduler:
 
 # Loss configs
 loss:
-  diffusion_loss:
-    # mse:
-    #   weight_dna: 5.0
-    #   weight_rna: 5.0
-    #   weight_ligand: 10.0
-    #   weight_protein: 1.0
-    # smooth_lddt:
-    #   weight: 1.0
-    use_smooth_lddt: false
-    sd_data: 16.0
-    weight: 1.0
+  mse_loss:
+    sd_data: 16.0
+    weight: 4.0
+  smooth_lddt_loss:
+    weight: 4.0
+    epsilon: 1e-5
 
   distogram:
     min_bin: 2.3125
```
15 changes: 4 additions & 11 deletions configs/model/proteus.yaml
```diff
@@ -24,18 +24,11 @@ scheduler:
 
 # Loss configs
 loss:
-  diffusion_loss:
-    # mse:
-    #   weight_dna: 5.0
-    #   weight_rna: 5.0
-    #   weight_ligand: 10.0
-    #   weight_protein: 1.0
-    # smooth_lddt:
-    #   weight: 1.0
-    sd_data: 16.0
-    weight: 4.0
+  mse_loss:
+    sd_data: 16.0
+    weight: 4.0
   smooth_lddt_loss:
-    weight: 1.0
+    weight: 4.0
     epsilon: 1e-5
 
 model:
```
15 changes: 4 additions & 11 deletions configs/model/small-alphafold3-copy.yaml
```diff
@@ -25,18 +25,11 @@ scheduler:
 
 # Loss configs
 loss:
-  diffusion_loss:
-    # mse:
-    #   weight_dna: 5.0
-    #   weight_rna: 5.0
-    #   weight_ligand: 10.0
-    #   weight_protein: 1.0
-    # smooth_lddt:
-    #   weight: 1.0
-    sd_data: 16.0
-    weight: 4.0
+  mse_loss:
+    sd_data: 16.0
+    weight: 4.0
   smooth_lddt_loss:
-    weight: 1.0
+    weight: 4.0
     epsilon: 1e-5
 
   distogram:
```
15 changes: 4 additions & 11 deletions configs/model/small-alphafold3.yaml
```diff
@@ -25,18 +25,11 @@ scheduler:
 
 # Loss configs
 loss:
-  diffusion_loss:
-    # mse:
-    #   weight_dna: 5.0
-    #   weight_rna: 5.0
-    #   weight_ligand: 10.0
-    #   weight_protein: 1.0
-    # smooth_lddt:
-    #   weight: 1.0
-    sd_data: 16.0
-    weight: 4.0
+  mse_loss:
+    sd_data: 16.0
+    weight: 4.0
   smooth_lddt_loss:
-    weight: 1.0
+    weight: 4.0
     epsilon: 1e-5
 
   distogram:
```
1 change: 0 additions & 1 deletion src/models/diffusion_module.py
```diff
@@ -158,7 +158,6 @@ def rescale_with_updates(
         r_update_scale = torch.sqrt(noisy_pos_scale) * timesteps.unsqueeze(-1)
         return noisy_atoms * noisy_pos_scale + r_updates * r_update_scale
 
-    @torch.compile
     def forward(
             self,
             noisy_atoms: Tensor,  # (bs, S, n_atoms, 3)
```
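Removing the hard-coded `@torch.compile` decorator (together with the `compile_model=False` argument dropped from the test below) means compilation is no longer forced at definition time. One way to keep it opt-in, sketched here as a hypothetical helper that is not part of this commit, is to wrap the module at construction:

```python
# Hypothetical helper (not in this commit): opt-in compilation at construction
# time instead of a decorator baked into the class. Requires PyTorch 2.x.
import torch

def maybe_compile(module: torch.nn.Module, compile_model: bool = False) -> torch.nn.Module:
    # torch.compile returns an optimized wrapper that traces forward() on its
    # first call; returning the plain module keeps eager mode for debugging.
    return torch.compile(module) if compile_model else module
```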
23 changes: 9 additions & 14 deletions src/utils/loss.py
```diff
@@ -104,7 +104,7 @@ def bond_loss(
     raise NotImplementedError("the implementation of this function will depend on the input pipeline")
 
 
-def diffusion_loss(
+def mse_loss(
         pred_atoms: Tensor,  # (bs * samples_per_trunk, n_atoms, 3)
         gt_atoms: Tensor,  # (bs * samples_per_trunk, n_atoms, 3)
         timesteps: Tensor,  # (bs * samples_per_trunk, 1)
@@ -128,23 +128,18 @@ def diffusion_loss(
     pred_atoms = Vec3Array.from_array(pred_atoms)
     gt_atoms = Vec3Array.from_array(gt_atoms)
 
-    # Align the gt_atoms to pred_atoms, TODO: put the alignment back in
-    aligned_gt_atoms = gt_atoms  # weighted_rigid_align(x=gt_atoms, x_gt=pred_atoms, weights=weights, mask=mask)
+    # Align the gt_atoms to pred_atoms
+    aligned_gt_atoms = weighted_rigid_align(x=gt_atoms, x_gt=pred_atoms, weights=weights, mask=mask)
 
     # MSE loss
     mse = mean_squared_error(pred_atoms, aligned_gt_atoms, weights, mask)
 
     # Scale by (t**2 + σ**2) / (t*σ)**2
     scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.mul(timesteps, sd_data) ** 2 + epsilon)
-    loss_diffusion = scaling_factor.squeeze(-1) * mse  # (bs)
-
-    # Smooth LDDT Loss
-    # if use_smooth_lddt:
-    #     lddt_loss = smooth_lddt_loss(pred_atoms, gt_atoms, atom_is_rna, atom_is_dna, mask)
-    #     loss_diffusion = loss_diffusion + lddt_loss
+    scaled_mse = scaling_factor.squeeze(-1) * mse  # (bs,)
 
     # Average over batch dimension
-    return torch.mean(loss_diffusion)
+    return torch.mean(scaled_mse)
 
 
 def softmax_cross_entropy(logits, labels):
```
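For reference, the scaling factor as implemented corresponds to the EDM-style per-noise-level weighting, with `sd_data = 16.0` and `epsilon = 1e-5` taken from the configs:

```math
\lambda(t) = \frac{t^{2} + \sigma_{\mathrm{data}}^{2}}{\left(t \cdot \sigma_{\mathrm{data}}\right)^{2} + \epsilon}
```

At large `t` this weight approaches `1/σ_data²`, while as `t → 0` the `ε` term caps it at `σ_data²/ε` instead of letting the division blow up.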
```diff
@@ -285,13 +280,13 @@ def loss(self, out, batch, _return_breakdown=False):
             #     logits=out["plddt_logits"],
             #     **{**batch, **self.config.plddt_loss},
             # ),
-            "diffusion_loss": lambda: diffusion_loss(
+            "mse_loss": lambda: mse_loss(
                 pred_atoms=out["denoised_atoms"],
                 gt_atoms=out["augmented_gt_atoms"],  # rotated gt atoms from diffusion module
                 timesteps=out["timesteps"],
                 weights=batch["atom_exists"],
                 mask=batch["atom_exists"],
-                **{**self.config.diffusion_loss},
+                **{**self.config.mse_loss},
             )
         }
         cumulative_loss = 0.0
```
```diff
@@ -335,13 +330,13 @@ def loss(self, out, batch, _return_breakdown=False):
                 atom_is_dna=batch["ref_mask"].new_zeros(batch["ref_mask"].shape),  # (bs, n_atoms)
                 mask=batch["atom_exists"],
             ),
-            "diffusion_loss": lambda: diffusion_loss(
+            "diffusion_loss": lambda: mse_loss(
                 pred_atoms=out["denoised_atoms"],
                 gt_atoms=out["augmented_gt_atoms"],  # rotated gt atoms from diffusion module
                 timesteps=out["timesteps"],
                 weights=batch["atom_exists"],
                 mask=batch["atom_exists"],
-                **{**self.config.diffusion_loss},
+                **{**self.config.mse_loss},
             )
         }
         cumulative_loss = 0.0
```
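Restoring `weighted_rigid_align` means the ground truth is rigidly superposed onto the prediction before the MSE is taken (in the call shown above, `x=gt_atoms` is aligned onto `x_gt=pred_atoms`), so the loss is invariant to global rotation and translation. Below is a generic weighted-Kabsch sketch of what such an alignment computes; this is the textbook algorithm, not the repository's implementation, and it folds `mask` into `weights` rather than handling it separately:

```python
# Generic weighted Kabsch superposition (illustrative sketch only; the repo's
# weighted_rigid_align operates on Vec3Array and takes a separate mask).
import torch

def weighted_rigid_align_sketch(x: torch.Tensor, x_gt: torch.Tensor,
                                weights: torch.Tensor) -> torch.Tensor:
    # x, x_gt: (n_atoms, 3); weights: (n_atoms,) non-negative.
    # Returns x rigidly aligned (rotation + translation only) onto x_gt.
    w = (weights / weights.sum()).unsqueeze(-1)       # (n_atoms, 1)
    mu_x, mu_gt = (w * x).sum(0), (w * x_gt).sum(0)   # weighted centroids
    xc, gc = x - mu_x, x_gt - mu_gt                   # centered coordinates
    cov = (w * xc).T @ gc                             # 3x3 weighted covariance
    u, _, vT = torch.linalg.svd(cov)
    sign = torch.ones(3)
    sign[-1] = torch.sign(torch.linalg.det(u @ vT))   # avoid reflections
    rot_T = (u * sign) @ vT                           # transpose of optimal rotation
    return xc @ rot_T + mu_gt
```

With `weights = batch["atom_exists"]` as in the call above, padding atoms get zero weight and do not influence the superposition.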
1 change: 0 additions & 1 deletion tests/test_diffusion_module.py
```diff
@@ -39,7 +39,6 @@ def setUp(self):
             atom_decoder_heads=self.atom_decoder_heads,
             token_transformer_blocks=self.token_transformer_blocks,
             token_transformer_heads=self.token_transformer_heads,
-            compile_model=False
         )  # values above are default values
 
         self.optimizer = torch.optim.Adam(self.module.parameters(), lr=1e-3)
```
8 changes: 4 additions & 4 deletions tests/test_loss.py
```diff
@@ -1,6 +1,6 @@
 import unittest
 import torch
-from src.utils.loss import mean_squared_error, smooth_lddt_loss, diffusion_loss
+from src.utils.loss import mean_squared_error, smooth_lddt_loss, mse_loss
 from src.diffusion.sample import sample_noise_level, noise_positions
 from src.utils.geometry.vector import Vec3Array
@@ -102,14 +102,14 @@ def test_basic_functionality(self):
         weights = torch.ones((bs, n_atoms))
         mask = None
 
-        loss_when_identical = diffusion_loss(pred_atoms, gt_atoms, timesteps, atom_is_rna, atom_is_dna, weights, mask)
+        loss_when_identical = mse_loss(pred_atoms, gt_atoms, timesteps, atom_is_rna, atom_is_dna, weights, mask)
 
         # Add noise and compare
         noisy_pred_atoms = noise_positions(pred_atoms, timesteps)
         noisier_pred_atoms = noise_positions(noisy_pred_atoms, timesteps)
 
-        loss_when_noisy = diffusion_loss(noisy_pred_atoms, gt_atoms, timesteps, atom_is_rna, atom_is_dna, weights, mask)
-        loss_when_noisier = diffusion_loss(noisier_pred_atoms, gt_atoms, timesteps, atom_is_rna, atom_is_dna, weights, mask)
+        loss_when_noisy = mse_loss(noisy_pred_atoms, gt_atoms, timesteps, atom_is_rna, atom_is_dna, weights, mask)
+        loss_when_noisier = mse_loss(noisier_pred_atoms, gt_atoms, timesteps, atom_is_rna, atom_is_dna, weights, mask)
 
         self.assertTrue(isinstance(loss_when_identical, torch.Tensor))
         self.assertTrue(torch.all((loss_when_identical < loss_when_noisy) & (loss_when_noisy < loss_when_noisier)))
```
