@@ -26,30 +26,30 @@ def weighted_rigid_align(
26
26
) -> Vec3Array :
27
27
"""Performs a weighted alignment of x to x_gt. Warning: ground truth here only refers to the structure
28
28
not being moved, not to be confused with ground truth during training."""
29
-
30
- # Mean-centre positions
31
- mu = (x * weights ).mean (dim = 1 , keepdim = True ) / weights .mean (dim = 1 , keepdim = True )
32
- mu_gt = (x_gt * weights ).mean (dim = 1 , keepdim = True ) / weights .mean (dim = 1 , keepdim = True )
33
- x -= mu # Vec3Array of shape (bs, n_atoms)
34
- x_gt -= mu_gt
35
-
36
- # Mask atoms before computing covariance matrix
37
- if mask is not None :
38
- x *= mask
39
- x_gt *= mask
40
-
41
- # Find optimal rotation from singular value decomposition
42
- U , S , Vh = torch .linalg .svd (compute_covariance_matrix (x_gt .to_tensor (), x .to_tensor ())) # shapes: (bs, 3, 3)
43
- R = U @ Vh
44
-
45
- # Remove reflection
46
- if torch .linalg .det (R ) < 0 :
47
- reflection_matrix = torch .diag ((torch .tensor ([1 , 1 , - 1 ], device = U .device , dtype = U .dtype )))
48
- reflection_matrix = reflection_matrix .unsqueeze (0 ).expand_as (R )
49
- R = U @ reflection_matrix @ Vh # (bs, 3, 3)
50
-
51
- R = Rot3Array .from_array (R )
52
-
53
- # Apply alignment
54
- x_aligned = R .apply_to_point (x ) + mu
55
- return x_aligned
29
+ with torch . no_grad ():
30
+ # Mean-centre positions
31
+ mu = (x * weights ).mean (dim = 1 , keepdim = True ) / weights .mean (dim = 1 , keepdim = True )
32
+ mu_gt = (x_gt * weights ).mean (dim = 1 , keepdim = True ) / weights .mean (dim = 1 , keepdim = True )
33
+ x -= mu # Vec3Array of shape (bs, n_atoms)
34
+ x_gt -= mu_gt
35
+
36
+ # Mask atoms before computing covariance matrix
37
+ if mask is not None :
38
+ x *= mask
39
+ x_gt *= mask
40
+
41
+ # Find optimal rotation from singular value decomposition
42
+ U , S , Vh = torch .linalg .svd (compute_covariance_matrix (x_gt .to_tensor (), x .to_tensor ())) # shapes: (bs, 3, 3)
43
+ R = U @ Vh
44
+
45
+ # Remove reflection
46
+ if torch .linalg .det (R ) < 0 :
47
+ reflection_matrix = torch .diag ((torch .tensor ([1 , 1 , - 1 ], device = U .device , dtype = U .dtype )))
48
+ reflection_matrix = reflection_matrix .unsqueeze (0 ).expand_as (R )
49
+ R = U @ reflection_matrix @ Vh # (bs, 3, 3)
50
+
51
+ R = Rot3Array .from_array (R )
52
+
53
+ # Apply alignment
54
+ x_aligned = R .apply_to_point (x ) + mu
55
+ return x_aligned
0 commit comments