Skip to content

Commit

Permalink
Fix arg passing in inverse property of SymmOp (#4113)
Browse files Browse the repository at this point in the history
* fix arg passing

* add more unit test

* perhaps overwrite? not sure which is better

* what about making a copy?

* make sure copy is correctly done

* use np.asarray to avoid copying

* remove unnecessary id check now that we're deep copy

* try to bump networkx to 3.0+

* Revert "try to bump networkx to 3.0+"

This reverts commit b27adba.
  • Loading branch information
DanielYang59 authored Oct 21, 2024
1 parent 65e21cc commit f4e2838
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
33 changes: 19 additions & 14 deletions src/pymatgen/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import copy
import re
import string
import warnings
Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(
Raises:
ValueError: if matrix is not 4x4.
"""
affine_transformation_matrix = np.array(affine_transformation_matrix)
affine_transformation_matrix = np.asarray(affine_transformation_matrix)
shape = affine_transformation_matrix.shape
if shape != (4, 4):
raise ValueError(f"Affine Matrix must be a 4x4 numpy array, got {shape=}")
Expand Down Expand Up @@ -96,8 +97,8 @@ def from_rotation_and_translation(
Returns:
SymmOp object
"""
rotation_matrix = np.array(rotation_matrix)
translation_vec = np.array(translation_vec)
rotation_matrix = np.asarray(rotation_matrix)
translation_vec = np.asarray(translation_vec)
if rotation_matrix.shape != (3, 3):
raise ValueError("Rotation Matrix must be a 3x3 numpy array.")
if translation_vec.shape != (3,):
Expand All @@ -117,7 +118,7 @@ def operate(self, point: ArrayLike) -> np.ndarray:
Returns:
Coordinates of point after operation.
"""
affine_point = np.array([*point, 1])
affine_point = np.asarray([*point, 1])
return np.dot(self.affine_matrix, affine_point)[:3]

def operate_multi(self, points: ArrayLike) -> np.ndarray:
Expand All @@ -129,7 +130,7 @@ def operate_multi(self, points: ArrayLike) -> np.ndarray:
Returns:
Numpy array of coordinates after operation
"""
points = np.array(points)
points = np.asarray(points)
affine_points = np.concatenate([points, np.ones(points.shape[:-1] + (1,))], axis=-1)
return np.inner(affine_points, self.affine_matrix)[..., :-1]

Expand Down Expand Up @@ -243,12 +244,16 @@ def translation_vector(self) -> np.ndarray:
@property
def inverse(self) -> Self:
"""Inverse of transformation."""
inverse = np.linalg.inv(self.affine_matrix)
return type(self)(inverse)
new_instance = copy.deepcopy(self)
new_instance.affine_matrix = np.linalg.inv(self.affine_matrix)
return new_instance

@staticmethod
def from_axis_angle_and_translation(
axis: ArrayLike, angle: float, angle_in_radians: bool = False, translation_vec: ArrayLike = (0, 0, 0)
axis: ArrayLike,
angle: float,
angle_in_radians: bool = False,
translation_vec: ArrayLike = (0, 0, 0),
) -> SymmOp:
"""Generate a SymmOp for a rotation about a given axis plus translation.
Expand All @@ -266,7 +271,7 @@ def from_axis_angle_and_translation(
if isinstance(axis, tuple | list):
axis = np.array(axis)

vec = np.array(translation_vec)
vec = np.asarray(translation_vec)

ang = angle if angle_in_radians else angle * pi / 180
cos_a = cos(ang)
Expand Down Expand Up @@ -368,7 +373,7 @@ def reflection(normal: ArrayLike, origin: ArrayLike = (0, 0, 0)) -> SymmOp:
u, v, w = normal

translation = np.eye(4)
translation[:3, 3] = -np.array(origin)
translation[:3, 3] = -np.asarray(origin)

xx = 1 - 2 * u**2
yy = 1 - 2 * v**2
Expand All @@ -395,7 +400,7 @@ def inversion(origin: ArrayLike = (0, 0, 0)) -> SymmOp:
"""
mat = -np.eye(4)
mat[3, 3] = 1
mat[:3, 3] = 2 * np.array(origin)
mat[:3, 3] = 2 * np.asarray(origin)
return SymmOp(mat)

@staticmethod
Expand Down Expand Up @@ -505,11 +510,11 @@ def __init__(
tol (float): Tolerance for determining if matrices are equal.
"""
super().__init__(affine_transformation_matrix, tol=tol)
if time_reversal in {-1, 1}:
self.time_reversal = time_reversal
else:
if time_reversal not in {-1, 1}:
raise RuntimeError(f"Invalid {time_reversal=}, must be 1 or -1")

self.time_reversal = time_reversal

def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
Expand Down
16 changes: 14 additions & 2 deletions tests/core/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def test_operate_multi(self):
def test_inverse(self):
point = np.random.default_rng().random(3)
new_coord = self.op.operate(point)
assert_allclose(self.op.inverse.operate(new_coord), point, 2)

# Make sure tol is passed correctly
self.op.tol = 0.02 # non-default
inverse_op = self.op.inverse
assert_allclose(inverse_op.operate(new_coord), point, 2)
assert_allclose(self.op.tol, inverse_op.tol)

def test_reflection(self):
rng = np.random.default_rng()
Expand Down Expand Up @@ -237,7 +242,6 @@ def test_xyzt_string(self):

magop = MagSymmOp.from_symmop(op, -1)
magop_str = magop.as_xyzt_str()
assert magop.time_reversal == -1
assert magop_str == "3x-2y-z+1/2, -x+12/13, z+1/2, -1"

def test_as_from_dict(self):
Expand All @@ -263,3 +267,11 @@ def test_operate_magmom(self):
for magmom in magmoms:
op = MagSymmOp.from_xyzt_str(xyzt_string)
assert_allclose(transformed_magmom, op.operate_magmom(magmom).global_moment)

def test_inverse(self):
op = SymmOp([[3, -2, -1, 0.5], [-1, 0, 0, 12.0 / 13], [0, 0, 1, 0.5 + 1e-7], [0, 0, 0, 1]], tol=0.02)

magop = MagSymmOp.from_symmop(op, -1)
assert magop.time_reversal == -1
assert magop.tol == 0.02
assert_allclose(magop.inverse.affine_matrix, np.linalg.inv(magop.affine_matrix))

0 comments on commit f4e2838

Please sign in to comment.