Skip to content

Commit a007e1e

Browse files
committed
merge develop
2 parents c94e995 + 20d2fb2 commit a007e1e

File tree

5 files changed

+282
-160
lines changed

5 files changed

+282
-160
lines changed

src/nomad_simulations/schema_packages/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class NOMADSimulationsEntryPoint(SchemaPackageEntryPoint):
3131
description='Limite of the number of atoms in the unit cell to be treated for the system type classification from MatID to work. This is done to avoid overhead of the package.',
3232
)
3333
equal_cell_positions_tolerance: float = Field(
34-
1e-12,
35-
description='Tolerance (in meters) for the cell positions to be considered equal.',
34+
12,
35+
description='Decimal order or tolerance (in meters) for comparing cell positions.',
3636
)
3737

3838
def load(self):

src/nomad_simulations/schema_packages/model_system.py

Lines changed: 137 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
1+
#
2+
# Copyright The NOMAD Authors.
3+
#
4+
# This file is part of NOMAD. See https://nomad-lab.eu for further info.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
119
import re
2-
from typing import TYPE_CHECKING, Optional
20+
from functools import lru_cache
21+
from hashlib import sha1
22+
from typing import TYPE_CHECKING
323

424
import ase
525
import numpy as np
@@ -22,12 +42,17 @@
2242
from nomad.units import ureg
2343

2444
if TYPE_CHECKING:
45+
from collections.abc import Generator
46+
from typing import Any, Callable, Optional
47+
48+
import pint
2549
from nomad.datamodel.datamodel import EntryArchive
2650
from nomad.metainfo import Context, Section
2751
from structlog.stdlib import BoundLogger
2852

2953
from nomad_simulations.schema_packages.atoms_state import AtomsState
3054
from nomad_simulations.schema_packages.utils import (
55+
catch_not_implemented,
3156
get_sibling_section,
3257
is_not_representative,
3358
)
@@ -200,6 +225,72 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
200225
return
201226

202227

228+
def _check_implemented(func: 'Callable'):
229+
"""
230+
Decorator to restrict the comparison functions to the same class.
231+
"""
232+
233+
def wrapper(self, other):
234+
if not isinstance(other, self.__class__):
235+
return NotImplemented
236+
return func(self, other)
237+
238+
return wrapper
239+
240+
241+
class PartialOrderElement:
242+
def __init__(self, representative_variable):
243+
self.representative_variable = representative_variable
244+
245+
def __hash__(self):
246+
return self.representative_variable.__hash__()
247+
248+
@_check_implemented
249+
def __eq__(self, other):
250+
return self.representative_variable == other.representative_variable
251+
252+
@_check_implemented
253+
def __lt__(self, other):
254+
return False
255+
256+
@_check_implemented
257+
def __gt__(self, other):
258+
return False
259+
260+
def __le__(self, other):
261+
return self.__eq__(other)
262+
263+
def __ge__(self, other):
264+
return self.__eq__(other)
265+
266+
# __ne__ assumes that usage in a finite set with its comparison definitions
267+
268+
269+
class HashedPositions(PartialOrderElement):
270+
# `representative_variable` is a `pint.Quantity` object
271+
272+
def __hash__(self):
273+
hash_str = sha1(
274+
np.ascontiguousarray(
275+
np.round(
276+
self.representative_variable.to_base_units().magnitude,
277+
decimals=configuration.equal_cell_positions_tolerance,
278+
out=None,
279+
)
280+
).tobytes()
281+
).hexdigest()
282+
return int(hash_str, 16)
283+
284+
def __eq__(self, other):
285+
"""Equality as defined between HashedPositions."""
286+
if (
287+
self.representative_variable is None
288+
or other.representative_variable is None
289+
):
290+
return NotImplemented
291+
return np.allclose(self.representative_variable, other.representative_variable)
292+
293+
203294
class Cell(GeometricSpace):
204295
"""
205296
A base section used to specify the cell quantities of a system at a given moment in time.
@@ -217,7 +308,7 @@ class Cell(GeometricSpace):
217308
type=MEnum('original', 'primitive', 'conventional'),
218309
description="""
219310
Representation type of the cell structure. It might be:
220-
- 'original' as in origanally parsed,
311+
- 'original' as in originally parsed,
221312
- 'primitive' as the primitive unit cell,
222313
- 'conventional' as the conventional cell used for referencing.
223314
""",
@@ -278,45 +369,36 @@ class Cell(GeometricSpace):
278369
""",
279370
)
280371

281-
def _check_positions(self, positions_1, positions_2) -> list:
282-
# Check that all the `positions`` of `cell_1` match with the ones in `cell_2`
283-
check_positions = []
284-
for i1, pos1 in enumerate(positions_1):
285-
for i2, pos2 in enumerate(positions_2):
286-
if np.allclose(
287-
pos1, pos2, atol=configuration.equal_cell_positions_tolerance
288-
):
289-
check_positions.append([i1, i2])
290-
break
291-
return check_positions
292-
293-
def is_equal_cell(self, other) -> bool:
294-
"""
295-
Check if the cell is equal to an`other` cell by comparing the `positions`.
296-
Args:
297-
other: The other cell to compare with.
298-
Returns:
299-
bool: True if the cells are equal, False otherwise.
300-
"""
301-
# TODO implement checks on `lattice_vectors` and other quantities to ensure the equality of primitive cells
302-
if not isinstance(other, Cell):
303-
return False
372+
@staticmethod
373+
def _generate_comparer(obj: 'Cell') -> 'Generator[Any, None, None]':
374+
try:
375+
return ((HashedPositions(pos)) for pos in obj.positions)
376+
except AttributeError:
377+
raise NotImplementedError
304378

305-
# If the `positions` are empty, return False
306-
if self.positions is None or other.positions is None:
307-
return False
379+
@catch_not_implemented
380+
def is_lt_cell(self, other) -> bool:
381+
return set(self._generate_comparer(self)) < set(self._generate_comparer(other))
308382

309-
# The `positions` should have the same length (same number of positions)
310-
if len(self.positions) != len(other.positions):
311-
return False
312-
n_positions = len(self.positions)
383+
@catch_not_implemented
384+
def is_gt_cell(self, other) -> bool:
385+
return set(self._generate_comparer(self)) > set(self._generate_comparer(other))
313386

314-
check_positions = self._check_positions(
315-
positions_1=self.positions, positions_2=other.positions
316-
)
317-
if len(check_positions) != n_positions:
318-
return False
319-
return True
387+
@catch_not_implemented
388+
def is_le_cell(self, other) -> bool:
389+
return set(self._generate_comparer(self)) <= set(self._generate_comparer(other))
390+
391+
@catch_not_implemented
392+
def is_ge_cell(self, other) -> bool:
393+
return set(self._generate_comparer(self)) >= set(self._generate_comparer(other))
394+
395+
@catch_not_implemented
396+
def is_equal_cell(self, other) -> bool: # TODO: improve naming
397+
return set(self._generate_comparer(self)) == set(self._generate_comparer(other))
398+
399+
def is_ne_cell(self, other) -> bool:
400+
# this does not hold in general, but here we use finite sets
401+
return not self.is_equal_cell(other)
320402

321403
def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
322404
super().normalize(archive, logger)
@@ -361,40 +443,20 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg
361443
# Set the name of the section
362444
self.name = self.m_def.name
363445

364-
def is_equal_cell(self, other) -> bool:
365-
"""
366-
Check if the atomic cell is equal to an`other` atomic cell by comparing the `positions` and
367-
the `AtomsState[*].chemical_symbol`.
368-
Args:
369-
other: The other atomic cell to compare with.
370-
Returns:
371-
bool: True if the atomic cells are equal, False otherwise.
372-
"""
373-
if not isinstance(other, AtomicCell):
374-
return False
375-
376-
# Compare positions using the parent sections's `__eq__` method
377-
if not super().is_equal_cell(other=other):
378-
return False
379-
380-
# Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2`
381-
check_positions = self._check_positions(
382-
positions_1=self.positions, positions_2=other.positions
383-
)
446+
@staticmethod
447+
def _generate_comparer(obj: 'AtomicCell') -> 'Generator[Any, None, None]':
448+
# presumes `atoms_state` mapping 1-to-1 with `positions` and conserves the order
384449
try:
385-
for atom in check_positions:
386-
element_1 = self.atoms_state[atom[0]].chemical_symbol
387-
element_2 = other.atoms_state[atom[1]].chemical_symbol
388-
if element_1 != element_2:
389-
return False
390-
except Exception:
391-
return False
392-
return True
450+
return (
451+
(HashedPositions(pos), PartialOrderElement(st.chemical_symbol))
452+
for pos, st in zip(obj.positions, obj.atoms_state)
453+
)
454+
except AttributeError:
455+
raise NotImplementedError
393456

394457
def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
395458
"""
396459
Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.
397-
398460
Args:
399461
logger (BoundLogger): The logger to log messages.
400462
@@ -412,7 +474,7 @@ def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
412474
chemical_symbols.append(atom_state.chemical_symbol)
413475
return chemical_symbols
414476

415-
def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]:
477+
def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]':
416478
"""
417479
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
418480
section (labels, periodic_boundary_conditions, positions, and lattice_vectors).
@@ -602,8 +664,11 @@ class Symmetry(ArchiveSection):
602664
)
603665

604666
def resolve_analyzed_atomic_cell(
605-
self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger: 'BoundLogger'
606-
) -> Optional[AtomicCell]:
667+
self,
668+
symmetry_analyzer: 'SymmetryAnalyzer',
669+
cell_type: str,
670+
logger: 'BoundLogger',
671+
) -> 'Optional[AtomicCell]':
607672
"""
608673
Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type
609674
(primitive or conventional).
@@ -647,8 +712,8 @@ def resolve_analyzed_atomic_cell(
647712
return atomic_cell
648713

649714
def resolve_bulk_symmetry(
650-
self, original_atomic_cell: AtomicCell, logger: 'BoundLogger'
651-
) -> tuple[Optional[AtomicCell], Optional[AtomicCell]]:
715+
self, original_atomic_cell: 'AtomicCell', logger: 'BoundLogger'
716+
) -> 'tuple[Optional[AtomicCell], Optional[AtomicCell]]':
652717
"""
653718
Resolves the symmetry of the material being simulated using MatID and the
654719
originally parsed data under original_atomic_cell. It generates two other

src/nomad_simulations/schema_packages/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .utils import (
22
RussellSaundersState,
3+
catch_not_implemented,
34
get_composition,
45
get_sibling_section,
56
get_variables,

src/nomad_simulations/schema_packages/utils/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from nomad.config import config
66

77
if TYPE_CHECKING:
8-
from typing import Optional
8+
from typing import Callable, Optional
99

1010
from nomad.datamodel.data import ArchiveSection
1111
from structlog.stdlib import BoundLogger
@@ -154,3 +154,19 @@ def get_composition(children_names: 'list[str]') -> str:
154154
children_count_tup = np.unique(children_names, return_counts=True)
155155
formula = ''.join([f'{name}({count})' for name, count in zip(*children_count_tup)])
156156
return formula if formula else None
157+
158+
159+
def catch_not_implemented(func: 'Callable') -> 'Callable':
160+
"""
161+
Decorator to default comparison functions outside the same class to `False`.
162+
"""
163+
164+
def wrapper(self, other) -> bool:
165+
if not isinstance(other, self.__class__):
166+
return False # ? should this throw an error instead?
167+
try:
168+
return func(self, other)
169+
except (TypeError, NotImplementedError):
170+
return False
171+
172+
return wrapper

0 commit comments

Comments
 (0)