Skip to content

Commit f3b2b43

Browse files
committed
Added explicit input and fix testing for atoms_state.py
1 parent 908a34b commit f3b2b43

File tree

3 files changed

+61
-18
lines changed

3 files changed

+61
-18
lines changed

src/nomad_simulations/atoms_state.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from nomad.datamodel.metainfo.basesections import Entity
2929
from nomad.datamodel.metainfo.annotations import ELNAnnotation
3030

31-
from .utils import RussellSaundersState
31+
from nomad_simulations.utils import RussellSaundersState
3232

3333

3434
class OrbitalsState(Entity):
@@ -280,7 +280,7 @@ def resolve_degeneracy(self) -> Optional[int]:
280280
for jj in self.j_quantum_number:
281281
if self.mj_quantum_number is not None:
282282
mjs = RussellSaundersState.generate_MJs(
283-
self.j_quantum_number[0], rising=True
283+
J=self.j_quantum_number[0], rising=True
284284
)
285285
degeneracy += len(
286286
[mj for mj in mjs if mj in self.mj_quantum_number]
@@ -293,15 +293,15 @@ def normalize(self, archive, logger) -> None:
293293
super().normalize(archive, logger)
294294

295295
# General checks for physical quantum numbers and symbols
296-
if not self.validate_quantum_numbers(logger):
296+
if not self.validate_quantum_numbers(logger=logger):
297297
logger.error('The quantum numbers are not physical.')
298298
return
299299

300300
# Resolving the quantum numbers and symbols if not available
301301
for quantum_name in ['l', 'ml', 'ms']:
302302
for quantum_type in ['number', 'symbol']:
303303
quantity = self.resolve_number_and_symbol(
304-
quantum_name, quantum_type, logger
304+
quantum_name=quantum_name, quantum_type=quantum_type, logger=logger
305305
)
306306
if getattr(self, f'{quantum_name}_quantum_{quantum_type}') is None:
307307
setattr(self, f'{quantum_name}_quantum_{quantum_type}', quantity)
@@ -383,7 +383,7 @@ def normalize(self, archive, logger) -> None:
383383
self.n_excited_electrons = None
384384
self.orbital_ref.degeneracy = 1
385385
if self.orbital_ref.occupation is None:
386-
self.orbital_ref.occupation = self.resolve_occupation(logger)
386+
self.orbital_ref.occupation = self.resolve_occupation(logger=logger)
387387

388388

389389
class HubbardInteractions(ArchiveSection):
@@ -552,11 +552,11 @@ def normalize(self, archive, logger) -> None:
552552
self.u_interaction,
553553
self.u_interorbital_interaction,
554554
self.j_hunds_coupling,
555-
) = self.resolve_u_interactions(logger)
555+
) = self.resolve_u_interactions(logger=logger)
556556

557557
# If u_effective is not available, calculate it
558558
if self.u_effective is None:
559-
self.u_effective = self.resolve_u_effective(logger)
559+
self.u_effective = self.resolve_u_effective(logger=logger)
560560

561561
# Check if length of `orbitals_ref` is the same as the length of `umn`:
562562
if self.u_matrix is not None and self.orbitals_ref is not None:
@@ -652,6 +652,6 @@ def normalize(self, archive, logger) -> None:
652652

653653
# Get chemical_symbol from atomic_number and viceversa
654654
if self.chemical_symbol is None:
655-
self.chemical_symbol = self.resolve_chemical_symbol(logger)
655+
self.chemical_symbol = self.resolve_chemical_symbol(logger=logger)
656656
if self.atomic_number is None:
657-
self.atomic_number = self.resolve_atomic_number(logger)
657+
self.atomic_number = self.resolve_atomic_number(logger=logger)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def generate_atomic_cell(
145145
for index, atom in enumerate(chemical_symbols):
146146
atom_state = AtomsState()
147147
setattr(atom_state, 'chemical_symbol', atom)
148-
atomic_number = atom_state.resolve_atomic_number(logger)
148+
atomic_number = atom_state.resolve_atomic_number(logger=logger)
149149
assert atomic_number == atomic_numbers[index]
150150
atom_state.atomic_number = atomic_number
151151
atomic_cell.atoms_state.append(atom_state)

tests/test_atoms_state.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,18 @@ def test_validate_quantum_numbers(
6969
):
7070
"""
7171
Test the `validate_quantum_numbers` method.
72+
73+
Args:
74+
number_label (str): The quantum number string to be tested.
75+
values (List[int]): The values stored in `OrbitalState`.
76+
results (List[bool]): The expected results after validation.
7277
"""
7378
orbital_state = OrbitalsState(n_quantum_number=2)
7479
for val, res in zip(values, results):
7580
if number_label == 'ml_quantum_number':
7681
orbital_state.l_quantum_number = 2
7782
setattr(orbital_state, number_label, val)
78-
assert orbital_state.validate_quantum_numbers(logger) == res
83+
assert orbital_state.validate_quantum_numbers(logger=logger) == res
7984

8085
@pytest.mark.parametrize(
8186
'quantum_name, value, expected_result',
@@ -103,6 +108,11 @@ def test_number_and_symbol(
103108
):
104109
"""
105110
Test the number and symbol resolution for each of the quantum numbers defined in the parametrization.
111+
112+
Args:
113+
quantum_name (str): The quantum number string to be tested.
114+
value (Union[int, float]): The value stored in `OrbitalState`.
115+
expected_result (Optional[str]): The expected result after resolving the counter-type.
106116
"""
107117
# Adding quantum numbers to the `OrbitalsState` section
108118
orbital_state = OrbitalsState(n_quantum_number=2)
@@ -112,13 +122,13 @@ def test_number_and_symbol(
112122

113123
# Making sure that the `'number'` is assigned
114124
resolved_type = orbital_state.resolve_number_and_symbol(
115-
quantum_name, 'number', logger
125+
quantum_name=quantum_name, quantum_type='number', logger=logger
116126
)
117127
assert resolved_type == value
118128

119129
# Resolving if the counter-type is assigned
120130
resolved_countertype = orbital_state.resolve_number_and_symbol(
121-
quantum_name, 'symbol', logger
131+
quantum_name=quantum_name, quantum_type='symbol', logger=logger
122132
)
123133
assert resolved_countertype == expected_result
124134

@@ -146,6 +156,14 @@ def test_degeneracy(
146156
):
147157
"""
148158
Test the degeneracy of each orbital states defined in the parametrization.
159+
160+
Args:
161+
l_quantum_number (int): The angular momentum quantum number.
162+
ml_quantum_number (Optional[int]): The magnetic quantum number.
163+
j_quantum_number (Optional[List[float]]): The total angular momentum quantum number.
164+
mj_quantum_number (Optional[List[float]]): The magnetic quantum number for the total angular momentum.
165+
ms_quantum_number (Optional[float]): The spin quantum number.
166+
degeneracy (int): The expected degeneracy of the orbital state.
149167
"""
150168
orbital_state = OrbitalsState(n_quantum_number=2)
151169
self.add_state(
@@ -195,13 +213,19 @@ def test_occupation(
195213
):
196214
"""
197215
Test the occupation of a core hole for a given set of orbital reference and degeneracy.
216+
217+
Args:
218+
orbital_ref (Optional[OrbitalsState]): The orbital reference of the core hole.
219+
degeneracy (Optional[int]): The degeneracy of the orbital reference.
220+
n_excited_electrons (float): The number of excited electrons.
221+
occupation (Optional[float]): The expected occupation of the core hole.
198222
"""
199223
core_hole = CoreHole(
200224
orbital_ref=orbital_ref, n_excited_electrons=n_excited_electrons
201225
)
202226
if orbital_ref is not None:
203227
assert orbital_ref.resolve_degeneracy() == degeneracy
204-
resolved_occupation = core_hole.resolve_occupation(logger)
228+
resolved_occupation = core_hole.resolve_occupation(logger=logger)
205229
if resolved_occupation is not None:
206230
assert np.isclose(resolved_occupation, occupation)
207231
else:
@@ -232,6 +256,12 @@ def test_normalize(
232256
):
233257
"""
234258
Test the normalization of the `CoreHole`. Inputs are defined as the quantities of the `CoreHole` section.
259+
260+
Args:
261+
orbital_ref (Optional[OrbitalsState]): The orbital reference of the core hole.
262+
n_excited_electrons (Optional[float]): The number of excited electrons.
263+
dscf_state (Optional[str]): The DSCF state of the core hole.
264+
results (Tuple[Optional[float], Optional[float], Optional[float]]): The expected results after normalization.
235265
"""
236266
core_hole = CoreHole(
237267
orbital_ref=orbital_ref,
@@ -265,6 +295,10 @@ def test_u_interactions(
265295
):
266296
"""
267297
Test the Hubbard interactions `U`, `U'`, and `J` for a given set of Slater integrals.
298+
299+
Args:
300+
slater_integrals (Optional[List[float]]): The Slater integrals of the Hubbard interactions.
301+
results (Tuple[Optional[float], Optional[float], Optional[float]]): The expected results of the Hubbard interactions.
268302
"""
269303
# Adding `slater_integrals` to the `HubbardInteractions` section
270304
hubbard_interactions = HubbardInteractions()
@@ -276,7 +310,7 @@ def test_u_interactions(
276310
u_interaction,
277311
u_interorbital_interaction,
278312
j_hunds_coupling,
279-
) = hubbard_interactions.resolve_u_interactions(logger)
313+
) = hubbard_interactions.resolve_u_interactions(logger=logger)
280314

281315
if None not in (u_interaction, u_interorbital_interaction, j_hunds_coupling):
282316
assert np.isclose(u_interaction.to('eV').magnitude, results[0])
@@ -306,6 +340,11 @@ def test_u_effective(
306340
):
307341
"""
308342
Test the effective Hubbard interaction `U_eff` for a given set of Hubbard interactions `U` and `J`.
343+
344+
Args:
345+
u_interaction (Optional[float]): The Hubbard interaction `U`.
346+
j_local_exchange_interaction (Optional[float]): The Hubbard interaction `J`.
347+
u_effective (Optional[float]): The expected effective Hubbard interaction `U_eff`.
309348
"""
310349
# Adding `u_interaction` and `j_local_exchange_interaction` to the `HubbardInteractions` section
311350
hubbard_interactions = HubbardInteractions()
@@ -317,7 +356,7 @@ def test_u_effective(
317356
)
318357

319358
# Resolving Ueff from class method
320-
resolved_u_effective = hubbard_interactions.resolve_u_effective(logger)
359+
resolved_u_effective = hubbard_interactions.resolve_u_effective(logger=logger)
321360
if resolved_u_effective is not None:
322361
assert np.isclose(resolved_u_effective.to('eV').magnitude, u_effective)
323362
else:
@@ -358,10 +397,14 @@ def test_chemical_symbol_and_atomic_number(
358397
):
359398
"""
360399
Test the `chemical_symbol` and `atomic_number` resolution for the `AtomsState` section.
400+
401+
Args:
402+
chemical_symbol (str): The chemical symbol of the atom.
403+
atomic_number (int): The atomic number of the atom.
361404
"""
362405
# Testing `chemical_symbol`
363406
atom_state = AtomsState(chemical_symbol=chemical_symbol)
364-
assert atom_state.resolve_atomic_number(logger) == atomic_number
407+
assert atom_state.resolve_atomic_number(logger=logger) == atomic_number
365408
# Testing `atomic_number`
366409
atom_state.atomic_number = atomic_number
367-
assert atom_state.resolve_chemical_symbol(logger) == chemical_symbol
410+
assert atom_state.resolve_chemical_symbol(logger=logger) == chemical_symbol

0 commit comments

Comments
 (0)