Commit

add epsilon to the loss division, compile diffusion forward
ardagoreci committed Aug 23, 2024
1 parent 85105c4 commit 80caef3
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/models/diffusion_module.py
@@ -158,6 +158,7 @@ def rescale_with_updates(
         r_update_scale = torch.sqrt(noisy_pos_scale) * timesteps.unsqueeze(-1)
         return noisy_atoms * noisy_pos_scale + r_updates * r_update_scale
 
+    @torch.compile
     def forward(
         self,
         noisy_atoms: Tensor,  # (bs, S, n_atoms, 3)
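For context on the second half of the commit message: a bare @torch.compile decorator hands the decorated function to the PyTorch 2.x compiler, which traces and optimizes it lazily on the first call and reuses the compiled graph on later calls that hit the same shape/dtype guards. A minimal self-contained sketch of the same pattern follows; ToyDenoiser is an illustrative stand-in, not the repository's diffusion module.

import torch
from torch import nn, Tensor


class ToyDenoiser(nn.Module):
    """Stand-in module showing @torch.compile applied to a forward method."""

    def __init__(self, dim: int = 3):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    @torch.compile
    def forward(self, noisy_atoms: Tensor) -> Tensor:  # (bs, n_atoms, 3)
        # The first call triggers compilation; later calls with matching
        # input shapes/dtypes reuse the cached compiled graph.
        return noisy_atoms + self.proj(noisy_atoms)


model = ToyDenoiser()
out = model(torch.randn(2, 5, 3))  # compilation happens on this first call
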
3 changes: 2 additions & 1 deletion src/utils/loss.py
@@ -111,6 +111,7 @@ def diffusion_loss(
         weights: Tensor,  # (bs, n_atoms)
         mask: Optional[Tensor] = None,  # (bs, n_atoms)
         sd_data: float = 16.0,  # Standard deviation of the data
+        epsilon: Optional[float] = 1e-5,
         **kwargs
 ) -> Tensor:  # (bs,)
     """Diffusion loss that scales the MSE and LDDT losses by the noise level (timestep)."""
@@ -134,7 +135,7 @@ def diffusion_loss(
     mse = mean_squared_error(pred_atoms, aligned_gt_atoms, weights, mask)
 
     # Scale by (t**2 + σ**2) / (t * σ)**2
-    scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.mul(timesteps, sd_data) ** 2)
+    scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.mul(timesteps, sd_data) ** 2 + epsilon)
     loss_diffusion = scaling_factor * mse
 
     # Smooth LDDT Loss
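For context on the first half of the commit message: with sd_data fixed at 16.0, the denominator (timesteps * sd_data) ** 2 vanishes when a sampled timestep is exactly zero, and the bare division then produces inf, which propagates into the loss and gradients. The added epsilon keeps the scaling factor finite while being dwarfed by the denominator whenever the timestep is not at or near zero. A standalone numerical sketch of just the guarded expression, using assumed example timesteps rather than the repository's sampling code:

import torch

sd_data = 16.0
epsilon = 1e-5
timesteps = torch.tensor([0.0, 0.5, 4.0])  # include t == 0 to show the failure mode

# Without the guard: division by zero at t == 0 yields inf.
unsafe = (timesteps ** 2 + sd_data ** 2) / (timesteps * sd_data) ** 2
# With the guard added in this commit: large but finite at t == 0.
safe = (timesteps ** 2 + sd_data ** 2) / ((timesteps * sd_data) ** 2 + epsilon)

print(unsafe)  # approximately tensor([inf, 4.0039, 0.0664])
print(safe)    # approximately tensor([2.5600e+07, 4.0039e+00, 6.6406e-02])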
