Skip to content

Commit

Permalink
Patch Atoms class to use np.allclose rather than strict equality
Browse files Browse the repository at this point in the history
  • Loading branch information
elinscott committed Oct 11, 2024
1 parent 4aa6e80 commit ae87b90
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 30 deletions.
1 change: 0 additions & 1 deletion src/koopmans/processes/_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __eq__(self, other):
if isinstance(cond, np.ndarray):
cond = cond.all()
if not cond:
raise ValueError(f'Debug: {key} differs: {getattr(self, key)} != {getattr(other, key)}')
return False
return True

Expand Down
25 changes: 0 additions & 25 deletions src/koopmans/processes/ui/_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,28 +81,3 @@ def fromatoms(cls, atoms: Atoms, supercell_matrix: Optional[np.ndarray] = None):

# Return the new UIAtoms object
return ui_atoms

def __eq__(self, other):
# Patching for test
if not isinstance(other, UIAtoms):
return False
a = self.arrays
b = other.arrays
if len(self) != len(other):
utils.warn('Atoms have different lengths')
return False
if (a['numbers'] != b['numbers']).any():
utils.warn('Atoms have different numbers')
return False
if (a['positions'] != b['positions']).any():
utils.warn('Atoms have different positions')
utils.warn(f'{a["positions"]}\n{b["positions"]}')
utils.warn(f'{a["positions"] == b["positions"]}')
return False
if not (self.cell == other.cell).all():
utils.warn('Atoms have different cells')
return False
if not (self.pbc == other.pbc).all():
utils.warn('Atoms have different pbc')
return False
return True
6 changes: 2 additions & 4 deletions src/koopmans/workflows/_unfold_and_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def _run(self) -> None:
dft_smooth_ham_file = None

process = self.new_ui_process(presets, centers=centers[mask], spreads=spreads[mask].tolist(),
dft_smooth_ham_file=dft_smooth_ham_file,
dft_ham_file=self._dft_ham_files[(filling, spin)])
dft_smooth_ham_file=dft_smooth_ham_file)

# Run the process
self.run_process(process)
Expand Down Expand Up @@ -178,8 +177,7 @@ def new_ui_process(self, presets: str, **kwargs) -> UnfoldAndInterpolateProcess:
preset_tuple = (presets, None)

kwargs['kc_ham_file'] = self._koopmans_ham_files[preset_tuple]
if self.calculator_parameters['ui'].do_smooth_interpolation:
kwargs['dft_ham_file'] = self._dft_ham_files[preset_tuple]
kwargs['dft_ham_file'] = self._dft_ham_files[preset_tuple]

parameters = self.calculator_parameters['ui']
parameters.kgrid = self.kpoints.grid
Expand Down
17 changes: 17 additions & 0 deletions tests/helpers/patches/_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Union

import numpy as np
from ase import Atoms

from koopmans.files import FilePointer
from koopmans.io import read_pkl
Expand All @@ -15,6 +16,19 @@
recursively_find_files)


def atoms_eq(self, other):
# Patching the Atoms class to compare positions and cell with np.allclose rather than strict equality
if not isinstance(other, Atoms):
return False
a = self.arrays
b = other.arrays
return (len(self) == len(other) and
np.allclose(a['positions'], b['positions']) and
(a['numbers'] == b['numbers']).all() and
np.allclose(self.cell, other.cell) and
(self.pbc == other.pbc).all())


def write_mock_file(filename: Union[Path, str], written_by: str):
filename = Path(filename)
with chdir(filename.parent):
Expand Down Expand Up @@ -173,3 +187,6 @@ def monkeypatch_mock(monkeypatch):
# Processes
for p in [ExtractCoefficientsFromXMLProcess, ComputePowerSpectrumProcess, Bin2XMLProcess, ConvertFilesFromSpin1To2, ConvertFilesFromSpin2To1, ExtendProcess, MergeProcess, UnfoldAndInterpolateProcess, MergeEVCProcess]:
monkeypatch.setattr(p, '_run', mock_process_run)

# Patch the Atoms class
monkeypatch.setattr(Atoms, '__eq__', atoms_eq)

0 comments on commit ae87b90

Please sign in to comment.