From 990fe82f41e3b5231c55cc39b7d95cca1dbc6bec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 11 Nov 2023 12:57:39 +0100 Subject: [PATCH] removed superfluous unused NeighborSpoof class --- apax/model/gmnn.py | 11 +++-------- apax/model/utils.py | 8 -------- tests/unit_tests/model/test_apax.py | 2 -- 3 files changed, 3 insertions(+), 18 deletions(-) delete mode 100644 apax/model/utils.py diff --git a/apax/model/gmnn.py b/apax/model/gmnn.py index 7a20fe41..c846aeae 100644 --- a/apax/model/gmnn.py +++ b/apax/model/gmnn.py @@ -16,7 +16,6 @@ from apax.layers.properties import stress_times_vol from apax.layers.readout import AtomisticReadout from apax.layers.scaling import PerElementScaleShift -from apax.model.utils import NeighborSpoof from apax.utils.math import fp64_sum DisplacementFn = Callable[[Array, Array], Array] @@ -26,11 +25,7 @@ def canonicalize_neighbors(neighbor): - return ( - neighbor.idx - if isinstance(neighbor, (partition.NeighborList, NeighborSpoof)) - else neighbor - ) + return neighbor.idx if isinstance(neighbor, partition.NeighborList) else neighbor def disp_fn(ri, rj, perturbation, box): @@ -96,7 +91,7 @@ def __call__( self, R: Array, Z: Array, - neighbor: Union[partition.NeighborList, NeighborSpoof, Array], + neighbor: Union[partition.NeighborList, Array], box, offsets, perturbation=None, @@ -143,7 +138,7 @@ def __call__( self, R: Array, Z: Array, - neighbor: Union[partition.NeighborList, NeighborSpoof, Array], + neighbor: Union[partition.NeighborList, Array], box, offsets, ): diff --git a/apax/model/utils.py b/apax/model/utils.py deleted file mode 100644 index edfb7a4e..00000000 --- a/apax/model/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -import dataclasses - -import jax.numpy as jnp - - -@dataclasses.dataclass -class NeighborSpoof: - idx: jnp.array diff --git a/tests/unit_tests/model/test_apax.py b/tests/unit_tests/model/test_apax.py index e7940047..67f28751 100644 --- a/tests/unit_tests/model/test_apax.py +++ b/tests/unit_tests/model/test_apax.py @@ -116,8 +116,6 @@ def test_energy_model(): ] ) offsets = jnp.full([6, 3], 0) - # neighbor = NeighborSpoof(idx=idx) - box = np.array([0.0, 0.0, 0.0]) model = EnergyModel()