1
+ #
2
+ # Copyright The NOMAD Authors.
3
+ #
4
+ # This file is part of NOMAD. See https://nomad-lab.eu for further info.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ #
18
+
1
19
import re
2
- from typing import TYPE_CHECKING , Optional
20
+ from functools import lru_cache
21
+ from hashlib import sha1
22
+ from typing import TYPE_CHECKING
3
23
4
24
import ase
5
25
import numpy as np
22
42
from nomad .units import ureg
23
43
24
44
if TYPE_CHECKING :
45
+ from collections .abc import Generator
46
+ from typing import Any , Callable , Optional
47
+
48
+ import pint
25
49
from nomad .datamodel .datamodel import EntryArchive
26
50
from nomad .metainfo import Context , Section
27
51
from structlog .stdlib import BoundLogger
28
52
29
53
from nomad_simulations .schema_packages .atoms_state import AtomsState
30
54
from nomad_simulations .schema_packages .utils import (
55
+ catch_not_implemented ,
31
56
get_sibling_section ,
32
57
is_not_representative ,
33
58
)
@@ -200,6 +225,72 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
200
225
return
201
226
202
227
228
+ def _check_implemented (func : 'Callable' ):
229
+ """
230
+ Decorator to restrict the comparison functions to the same class.
231
+ """
232
+
233
+ def wrapper (self , other ):
234
+ if not isinstance (other , self .__class__ ):
235
+ return NotImplemented
236
+ return func (self , other )
237
+
238
+ return wrapper
239
+
240
+
241
+ class PartialOrderElement :
242
+ def __init__ (self , representative_variable ):
243
+ self .representative_variable = representative_variable
244
+
245
+ def __hash__ (self ):
246
+ return self .representative_variable .__hash__ ()
247
+
248
+ @_check_implemented
249
+ def __eq__ (self , other ):
250
+ return self .representative_variable == other .representative_variable
251
+
252
+ @_check_implemented
253
+ def __lt__ (self , other ):
254
+ return False
255
+
256
+ @_check_implemented
257
+ def __gt__ (self , other ):
258
+ return False
259
+
260
+ def __le__ (self , other ):
261
+ return self .__eq__ (other )
262
+
263
+ def __ge__ (self , other ):
264
+ return self .__eq__ (other )
265
+
266
+ # __ne__ assumes that usage in a finite set with its comparison definitions
267
+
268
+
269
+ class HashedPositions (PartialOrderElement ):
270
+ # `representative_variable` is a `pint.Quantity` object
271
+
272
+ def __hash__ (self ):
273
+ hash_str = sha1 (
274
+ np .ascontiguousarray (
275
+ np .round (
276
+ self .representative_variable .to_base_units ().magnitude ,
277
+ decimals = configuration .equal_cell_positions_tolerance ,
278
+ out = None ,
279
+ )
280
+ ).tobytes ()
281
+ ).hexdigest ()
282
+ return int (hash_str , 16 )
283
+
284
+ def __eq__ (self , other ):
285
+ """Equality as defined between HashedPositions."""
286
+ if (
287
+ self .representative_variable is None
288
+ or other .representative_variable is None
289
+ ):
290
+ return NotImplemented
291
+ return np .allclose (self .representative_variable , other .representative_variable )
292
+
293
+
203
294
class Cell (GeometricSpace ):
204
295
"""
205
296
A base section used to specify the cell quantities of a system at a given moment in time.
@@ -217,7 +308,7 @@ class Cell(GeometricSpace):
217
308
type = MEnum ('original' , 'primitive' , 'conventional' ),
218
309
description = """
219
310
Representation type of the cell structure. It might be:
220
- - 'original' as in origanally parsed,
311
+ - 'original' as in originally parsed,
221
312
- 'primitive' as the primitive unit cell,
222
313
- 'conventional' as the conventional cell used for referencing.
223
314
""" ,
@@ -278,45 +369,36 @@ class Cell(GeometricSpace):
278
369
""" ,
279
370
)
280
371
281
- def _check_positions (self , positions_1 , positions_2 ) -> list :
282
- # Check that all the `positions`` of `cell_1` match with the ones in `cell_2`
283
- check_positions = []
284
- for i1 , pos1 in enumerate (positions_1 ):
285
- for i2 , pos2 in enumerate (positions_2 ):
286
- if np .allclose (
287
- pos1 , pos2 , atol = configuration .equal_cell_positions_tolerance
288
- ):
289
- check_positions .append ([i1 , i2 ])
290
- break
291
- return check_positions
292
-
293
- def is_equal_cell (self , other ) -> bool :
294
- """
295
- Check if the cell is equal to an`other` cell by comparing the `positions`.
296
- Args:
297
- other: The other cell to compare with.
298
- Returns:
299
- bool: True if the cells are equal, False otherwise.
300
- """
301
- # TODO implement checks on `lattice_vectors` and other quantities to ensure the equality of primitive cells
302
- if not isinstance (other , Cell ):
303
- return False
372
+ @staticmethod
373
+ def _generate_comparer (obj : 'Cell' ) -> 'Generator[Any, None, None]' :
374
+ try :
375
+ return ((HashedPositions (pos )) for pos in obj .positions )
376
+ except AttributeError :
377
+ raise NotImplementedError
304
378
305
- # If the `positions` are empty, return False
306
- if self . positions is None or other . positions is None :
307
- return False
379
+ @ catch_not_implemented
380
+ def is_lt_cell ( self , other ) -> bool :
381
+ return set ( self . _generate_comparer ( self )) < set ( self . _generate_comparer ( other ))
308
382
309
- # The `positions` should have the same length (same number of positions)
310
- if len (self .positions ) != len (other .positions ):
311
- return False
312
- n_positions = len (self .positions )
383
+ @catch_not_implemented
384
+ def is_gt_cell (self , other ) -> bool :
385
+ return set (self ._generate_comparer (self )) > set (self ._generate_comparer (other ))
313
386
314
- check_positions = self ._check_positions (
315
- positions_1 = self .positions , positions_2 = other .positions
316
- )
317
- if len (check_positions ) != n_positions :
318
- return False
319
- return True
387
+ @catch_not_implemented
388
+ def is_le_cell (self , other ) -> bool :
389
+ return set (self ._generate_comparer (self )) <= set (self ._generate_comparer (other ))
390
+
391
+ @catch_not_implemented
392
+ def is_ge_cell (self , other ) -> bool :
393
+ return set (self ._generate_comparer (self )) >= set (self ._generate_comparer (other ))
394
+
395
+ @catch_not_implemented
396
+ def is_equal_cell (self , other ) -> bool : # TODO: improve naming
397
+ return set (self ._generate_comparer (self )) == set (self ._generate_comparer (other ))
398
+
399
+ def is_ne_cell (self , other ) -> bool :
400
+ # this does not hold in general, but here we use finite sets
401
+ return not self .is_equal_cell (other )
320
402
321
403
def normalize (self , archive : 'EntryArchive' , logger : 'BoundLogger' ) -> None :
322
404
super ().normalize (archive , logger )
@@ -361,40 +443,20 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg
361
443
# Set the name of the section
362
444
self .name = self .m_def .name
363
445
364
- def is_equal_cell (self , other ) -> bool :
365
- """
366
- Check if the atomic cell is equal to an`other` atomic cell by comparing the `positions` and
367
- the `AtomsState[*].chemical_symbol`.
368
- Args:
369
- other: The other atomic cell to compare with.
370
- Returns:
371
- bool: True if the atomic cells are equal, False otherwise.
372
- """
373
- if not isinstance (other , AtomicCell ):
374
- return False
375
-
376
- # Compare positions using the parent sections's `__eq__` method
377
- if not super ().is_equal_cell (other = other ):
378
- return False
379
-
380
- # Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2`
381
- check_positions = self ._check_positions (
382
- positions_1 = self .positions , positions_2 = other .positions
383
- )
446
+ @staticmethod
447
+ def _generate_comparer (obj : 'AtomicCell' ) -> 'Generator[Any, None, None]' :
448
+ # presumes `atoms_state` mapping 1-to-1 with `positions` and conserves the order
384
449
try :
385
- for atom in check_positions :
386
- element_1 = self .atoms_state [atom [0 ]].chemical_symbol
387
- element_2 = other .atoms_state [atom [1 ]].chemical_symbol
388
- if element_1 != element_2 :
389
- return False
390
- except Exception :
391
- return False
392
- return True
450
+ return (
451
+ (HashedPositions (pos ), PartialOrderElement (st .chemical_symbol ))
452
+ for pos , st in zip (obj .positions , obj .atoms_state )
453
+ )
454
+ except AttributeError :
455
+ raise NotImplementedError
393
456
394
457
def get_chemical_symbols (self , logger : 'BoundLogger' ) -> list [str ]:
395
458
"""
396
459
Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`.
397
-
398
460
Args:
399
461
logger (BoundLogger): The logger to log messages.
400
462
@@ -412,7 +474,7 @@ def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]:
412
474
chemical_symbols .append (atom_state .chemical_symbol )
413
475
return chemical_symbols
414
476
415
- def to_ase_atoms (self , logger : 'BoundLogger' ) -> Optional [ase .Atoms ]:
477
+ def to_ase_atoms (self , logger : 'BoundLogger' ) -> ' Optional[ase.Atoms]' :
416
478
"""
417
479
Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell`
418
480
section (labels, periodic_boundary_conditions, positions, and lattice_vectors).
@@ -602,8 +664,11 @@ class Symmetry(ArchiveSection):
602
664
)
603
665
604
666
def resolve_analyzed_atomic_cell (
605
- self , symmetry_analyzer : SymmetryAnalyzer , cell_type : str , logger : 'BoundLogger'
606
- ) -> Optional [AtomicCell ]:
667
+ self ,
668
+ symmetry_analyzer : 'SymmetryAnalyzer' ,
669
+ cell_type : str ,
670
+ logger : 'BoundLogger' ,
671
+ ) -> 'Optional[AtomicCell]' :
607
672
"""
608
673
Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type
609
674
(primitive or conventional).
@@ -647,8 +712,8 @@ def resolve_analyzed_atomic_cell(
647
712
return atomic_cell
648
713
649
714
def resolve_bulk_symmetry (
650
- self , original_atomic_cell : AtomicCell , logger : 'BoundLogger'
651
- ) -> tuple [Optional [AtomicCell ], Optional [AtomicCell ]]:
715
+ self , original_atomic_cell : ' AtomicCell' , logger : 'BoundLogger'
716
+ ) -> ' tuple[Optional[AtomicCell], Optional[AtomicCell]]' :
652
717
"""
653
718
Resolves the symmetry of the material being simulated using MatID and the
654
719
originally parsed data under original_atomic_cell. It generates two other
0 commit comments