diff --git a/src/diffusion/augmentation.py b/src/diffusion/augmentation.py index 21138eb..69f0a31 100644 --- a/src/diffusion/augmentation.py +++ b/src/diffusion/augmentation.py @@ -22,7 +22,7 @@ def centre_random_augmentation( device = atom_positions.x.device # Center the atoms - center = atom_positions.mean(dim=-2, keepdim=True) + center = atom_positions.mean(dim=-1, keepdim=True) atom_positions = atom_positions - center # Sample random rotation