Skip to content

Commit

Permalink
added weighted alignment for the loss
Browse files Browse the repository at this point in the history
  • Loading branch information
ardagoreci committed Aug 13, 2024
1 parent aa60790 commit 7bd3f73
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ def diffusion_loss(
pred_atoms = Vec3Array.from_array(pred_atoms)
gt_atoms = Vec3Array.from_array(gt_atoms)

# Align the gt_atoms to pred_atoms TODO: add the alignment back in
aligned_gt_atoms = gt_atoms
# weighted_rigid_align(x=gt_atoms, x_gt=pred_atoms, weights=weights, mask=mask))
# Align the gt_atoms to pred_atoms
aligned_gt_atoms = weighted_rigid_align(x=gt_atoms, x_gt=pred_atoms, weights=weights, mask=mask)

# MSE loss
mse = mean_squared_error(pred_atoms, aligned_gt_atoms, weights, mask)
Expand Down

0 comments on commit 7bd3f73

Please sign in to comment.