From 80caef3f695ea4fda5f4004ee78be92aacb9f606 Mon Sep 17 00:00:00 2001 From: ardagoreci <62720042+ardagoreci@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:39:55 -0700 Subject: [PATCH] add epsilon to the loss division, compile diffusion forward --- src/models/diffusion_module.py | 1 + src/utils/loss.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/models/diffusion_module.py b/src/models/diffusion_module.py index 0484f4b..6acc182 100644 --- a/src/models/diffusion_module.py +++ b/src/models/diffusion_module.py @@ -158,6 +158,7 @@ def rescale_with_updates( r_update_scale = torch.sqrt(noisy_pos_scale) * timesteps.unsqueeze(-1) return noisy_atoms * noisy_pos_scale + r_updates * r_update_scale + @torch.compile def forward( self, noisy_atoms: Tensor, # (bs, S, n_atoms, 3) diff --git a/src/utils/loss.py b/src/utils/loss.py index 432b3ef..2a2f121 100644 --- a/src/utils/loss.py +++ b/src/utils/loss.py @@ -111,6 +111,7 @@ def diffusion_loss( weights: Tensor, # (bs, n_atoms) mask: Optional[Tensor] = None, # (bs, n_atoms) sd_data: float = 16.0, # Standard deviation of the data + epsilon: Optional[float] = 1e-5, **kwargs ) -> Tensor: # (bs,) """Diffusion loss that scales the MSE and LDDT losses by the noise level (timestep).""" @@ -134,7 +135,7 @@ def diffusion_loss( mse = mean_squared_error(pred_atoms, aligned_gt_atoms, weights, mask) # Scale by (t**2 + σ**2) / (t + σ)**2 - scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.mul(timesteps, sd_data) ** 2) + scaling_factor = torch.add(timesteps ** 2, sd_data ** 2) / (torch.mul(timesteps, sd_data) ** 2 + epsilon) loss_diffusion = scaling_factor * mse # Smooth LDDT Loss