diff --git a/opencomplex/loss/loss_fns_rna.py b/opencomplex/loss/loss_fns_rna.py index 3e459b5..a723417 100644 --- a/opencomplex/loss/loss_fns_rna.py +++ b/opencomplex/loss/loss_fns_rna.py @@ -1087,9 +1087,9 @@ def extreme_c4_c4_distance_violations( Fraction of consecutive CA-CA pairs with violation. """ this_c4_pos = pred_atom_positions[..., :-1, 3, :] - this_c4_mask = pred_atom_mask[..., :-1, 3] + this_c4_mask = pred_atom_mask[..., :-1] next_c4_pos = pred_atom_positions[..., 1:, 3, :] - next_c4_mask = pred_atom_mask[..., 1:, 3] + next_c4_mask = pred_atom_mask[..., 1:] has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0 c4_c4_distance = torch.sqrt( eps + torch.sum((this_c4_pos - next_c4_pos) ** 2, dim=-1)