From 6bdcd5e67fca9d77e0bdc383e5e011f64e5bcf93 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Mon, 15 Jan 2024 14:32:30 +0000 Subject: [PATCH 1/2] add (9,) and (6,) format for stress and virials --- mace/data/atomic_data.py | 4 +++- mace/data/utils.py | 4 ++-- mace/tools/torch_tools.py | 6 ++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index 31170d3d..edb91b14 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -175,7 +175,9 @@ def from_config( else None ) virials = ( - torch.tensor(config.virials, dtype=torch.get_default_dtype()).unsqueeze(0) + voigt_to_matrix( + torch.tensor(config.virials, dtype=torch.get_default_dtype()) + ).unsqueeze(0) if config.virials is not None else None ) diff --git a/mace/data/utils.py b/mace/data/utils.py index 0069550f..908fdc17 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -17,8 +17,8 @@ Vector = np.ndarray # [3,] Positions = np.ndarray # [..., 3] Forces = np.ndarray # [..., 3] -Stress = np.ndarray # [6, ] -Virials = np.ndarray # [3,3] +Stress = np.ndarray # [6, ], [3,3], [9, ] +Virials = np.ndarray # [6, ], [3,3], [9, ] Charges = np.ndarray # [..., 1] Cell = np.ndarray # [3,3] Pbc = tuple # (3,) diff --git a/mace/tools/torch_tools.py b/mace/tools/torch_tools.py index e0c4d546..349f1e3b 100644 --- a/mace/tools/torch_tools.py +++ b/mace/tools/torch_tools.py @@ -107,7 +107,7 @@ def cartesian_to_spherical(t: torch.Tensor): def voigt_to_matrix(t: torch.Tensor): """ Convert voigt notation to matrix notation - :param t: (6,) tensor or (3, 3) tensor + :param t: (6,) tensor or (3, 3) tensor or (9,) tensor :return: (3, 3) tensor """ if t.shape == (3, 3): @@ -121,9 +121,11 @@ def voigt_to_matrix(t: torch.Tensor): ], dtype=t.dtype, ) + if t.shape == (9,): + return t.view(3, 3) raise ValueError( - f"Stress tensor must be of shape (6,) or (3, 3), but has shape {t.shape}" + f"Stress tensor must be of shape (6,) or (3, 3), or (9,) but has shape {t.shape}" ) From 9395221e3e0deca4182ba6ba17039bc79825d395 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:48:47 +0000 Subject: [PATCH 2/2] update version to 0.3.4 --- mace/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/__version__.py b/mace/__version__.py index e19434e2..334b8995 100644 --- a/mace/__version__.py +++ b/mace/__version__.py @@ -1 +1 @@ -__version__ = "0.3.3" +__version__ = "0.3.4"