Skip to content

Commit 7e0c8f4

Browse files
authored
tests/io/aims use numpy.testing.assert_allclose and pytest.MonkeyPatch (#3575)
* tests/io/aims use numpy.testing.assert_allclose for better error msgs and pytest.MonkeyPatch for auto-reverted side-effects * remove superfluous np.all() in asserts * fix pytest.raises match regex escape * rename d->dct
1 parent f392c8d commit 7e0c8f4

38 files changed

+384
-412
lines changed

examples/aims_io/FHI-aims-example.ipynb

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
"metadata": {},
88
"outputs": [],
99
"source": [
10-
"from pymatgen.io.aims.inputs import AimsGeometryIn, AimsCube, AimsControlIn\n",
11-
"from pymatgen.io.aims.outputs import AimsOutput\n",
12-
"\n",
13-
"from pymatgen.core import Structure, Lattice\n",
10+
"from pathlib import Path\n",
11+
"from subprocess import check_call\n",
1412
"\n",
1513
"import numpy as np\n",
14+
"from numpy.testing import assert_allclose\n",
1615
"\n",
17-
"from pathlib import Path\n",
18-
"from subprocess import check_call\n"
16+
"from pymatgen.core import Lattice, Structure\n",
17+
"from pymatgen.io.aims.inputs import AimsControlIn, AimsCube, AimsGeometryIn\n",
18+
"from pymatgen.io.aims.outputs import AimsOutput"
1919
]
2020
},
2121
{
@@ -29,7 +29,7 @@
2929
"AIMS_CMD = \"aims.x\"\n",
3030
"AIMS_OUTPUT = \"aims.out\"\n",
3131
"AIMS_SD = \"species_dir\"\n",
32-
"AIMS_TEST_DIR = \"../../tests/io/aims/species_directory/light/\"\n"
32+
"AIMS_TEST_DIR = \"../../tests/io/aims/species_directory/light/\""
3333
]
3434
},
3535
{
@@ -41,12 +41,10 @@
4141
"source": [
4242
"# Create test structure\n",
4343
"structure = Structure(\n",
44-
" lattice=Lattice(\n",
45-
" np.array([[0, 2.715, 2.715],[2.715, 0, 2.715], [2.715, 2.715, 0]])\n",
46-
" ),\n",
44+
" lattice=Lattice(np.array([[0, 2.715, 2.715], [2.715, 0, 2.715], [2.715, 2.715, 0]])),\n",
4745
" species=[\"Si\", \"Si\"],\n",
48-
" coords=np.array([np.zeros(3), np.ones(3) * 0.25])\n",
49-
")\n"
46+
" coords=np.array([np.zeros(3), np.ones(3) * 0.25]),\n",
47+
")"
5048
]
5149
},
5250
{
@@ -78,9 +76,9 @@
7876
"\n",
7977
"# Cube file output controlled by the AimsCube class\n",
8078
"cont_in[\"cubes\"] = [\n",
81-
" AimsCube(\"total_density\", origin=[0,0,0], points=[11, 11, 11]),\n",
79+
" AimsCube(\"total_density\", origin=[0, 0, 0], points=[11, 11, 11]),\n",
8280
" AimsCube(\"eigenstate_density 1\", origin=[0, 0, 0], points=[11, 11, 11]),\n",
83-
"]\n"
81+
"]"
8482
]
8583
},
8684
{
@@ -91,11 +89,11 @@
9189
"outputs": [],
9290
"source": [
9391
"# Write the input files\n",
94-
"workdir = Path.cwd() / \"workdir/\"\n",
95-
"workdir.mkdir(exist_ok=True)\n",
92+
"work_dir = Path.cwd() / \"workdir/\"\n",
93+
"work_dir.mkdir(exist_ok=True)\n",
9694
"\n",
97-
"geo_in.write_file(workdir, overwrite=True)\n",
98-
"cont_in.write_file(structure, workdir, overwrite=True)\n"
95+
"geo_in.write_file(work_dir, overwrite=True)\n",
96+
"cont_in.write_file(structure, work_dir, overwrite=True)"
9997
]
10098
},
10199
{
@@ -106,8 +104,8 @@
106104
"outputs": [],
107105
"source": [
108106
"# Run the calculation\n",
109-
"with open(f\"{workdir}/{AIMS_OUTPUT}\", \"w\") as outfile:\n",
110-
" aims_run = check_call([AIMS_CMD], cwd=workdir, stdout=outfile)\n"
107+
"with open(f\"{work_dir}/{AIMS_OUTPUT}\", \"w\") as outfile:\n",
108+
" aims_run = check_call([AIMS_CMD], cwd=work_dir, stdout=outfile)"
111109
]
112110
},
113111
{
@@ -118,22 +116,14 @@
118116
"outputs": [],
119117
"source": [
120118
"# Read the aims output file and the final relaxed geometry\n",
121-
"outputs = AimsOutput.from_outfile(f\"{workdir}/{AIMS_OUTPUT}\")\n",
122-
"relaxed_structure = AimsGeometryIn.from_file(f\"{workdir}/geometry.in.next_step\")\n",
119+
"outputs = AimsOutput.from_outfile(f\"{work_dir}/{AIMS_OUTPUT}\")\n",
120+
"relaxed_structure = AimsGeometryIn.from_file(f\"{work_dir}/geometry.in.next_step\")\n",
123121
"\n",
124122
"# Check the results\n",
125123
"assert outputs.get_results_for_image(-1).lattice == relaxed_structure.structure.lattice\n",
126-
"assert np.all(outputs.get_results_for_image(-1).frac_coords == relaxed_structure.structure.frac_coords)\n",
127-
"\n",
128-
"assert np.allclose(\n",
129-
" outputs.get_results_for_image(-1).properties[\"stress\"],\n",
130-
" outputs.stress\n",
131-
")\n",
132-
"\n",
133-
"assert np.allclose(\n",
134-
" outputs.get_results_for_image(-1).site_properties[\"force\"],\n",
135-
" outputs.forces\n",
136-
")\n"
124+
"assert_allclose(outputs.get_results_for_image(-1).frac_coords, relaxed_structure.structure.frac_coords)\n",
125+
"assert_allclose(outputs.get_results_for_image(-1).properties[\"stress\"], outputs.stress)\n",
126+
"assert_allclose(outputs.get_results_for_image(-1).site_properties[\"force\"], outputs.forces)"
137127
]
138128
}
139129
],
@@ -153,7 +143,7 @@
153143
"name": "python",
154144
"nbconvert_exporter": "python",
155145
"pygments_lexer": "ipython3",
156-
"version": "3.9.16"
146+
"version": "3.11.7"
157147
}
158148
},
159149
"nbformat": 4,

pymatgen/analysis/ferroelectricity/polarization.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def get_total_ionic_dipole(structure, zval_dict):
104104

105105

106106
class PolarizationLattice(Structure):
107-
"""Why is a Lattice inheriting a structure? This is ridiculous."""
107+
"""TODO Why is a Lattice inheriting a structure? This is ridiculous."""
108108

109109
def get_nearest_site(self, coords, site, r=None):
110110
"""
@@ -144,14 +144,7 @@ class Polarization:
144144
electron Angstroms along the three lattice directions (a,b,c).
145145
"""
146146

147-
def __init__(
148-
self,
149-
p_elecs,
150-
p_ions,
151-
structures,
152-
p_elecs_in_cartesian=True,
153-
p_ions_in_cartesian=False,
154-
):
147+
def __init__(self, p_elecs, p_ions, structures, p_elecs_in_cartesian=True, p_ions_in_cartesian=False):
155148
"""
156149
p_elecs: np.array of electronic contribution to the polarization with shape [N, 3]
157150
p_ions: np.array of ionic contribution to the polarization with shape [N, 3]
@@ -270,7 +263,7 @@ def get_same_branch_polarization_data(self, convert_to_muC_per_cm2=True, all_in_
270263
lattices = [s.lattice for s in self.structures]
271264
volumes = np.array([s.lattice.volume for s in self.structures])
272265

273-
L = len(p_elec)
266+
n_elecs = len(p_elec)
274267

275268
e_to_muC = -1.6021766e-13
276269
cm2_to_A2 = 1e16
@@ -282,38 +275,41 @@ def get_same_branch_polarization_data(self, convert_to_muC_per_cm2=True, all_in_
282275
# Convert the total polarization
283276
p_tot = np.multiply(units.T[:, np.newaxis], p_tot)
284277
# adjust lattices
285-
for i in range(L):
286-
lattice = lattices[i]
287-
lattices[i] = Lattice.from_parameters(*(np.array(lattice.lengths) * units.ravel()[i]), *lattice.angles)
278+
for idx in range(n_elecs):
279+
lattice = lattices[idx]
280+
lattices[idx] = Lattice.from_parameters(
281+
*(np.array(lattice.lengths) * units.ravel()[idx]), *lattice.angles
282+
)
288283
# convert polarizations to polar lattice
289284
elif convert_to_muC_per_cm2 and all_in_polar:
290285
abc = [lattice.abc for lattice in lattices]
291286
abc = np.array(abc) # [N, 3]
292287
p_tot /= abc # e * Angstroms to e
293288
p_tot *= abc[-1] / volumes[-1] * e_to_muC * cm2_to_A2 # to muC / cm^2
294-
for i in range(L):
289+
for idx in range(n_elecs):
295290
lattice = lattices[-1] # Use polar lattice
296291
# Use polar units (volume)
297-
lattices[i] = Lattice.from_parameters(*(np.array(lattice.lengths) * units.ravel()[-1]), *lattice.angles)
292+
lattices[idx] = Lattice.from_parameters(
293+
*(np.array(lattice.lengths) * units.ravel()[-1]), *lattice.angles
294+
)
298295

299296
d_structs = []
300297
sites = []
301-
for i in range(L):
302-
lattice = lattices[i]
303-
frac_coord = np.divide(np.array([p_tot[i]]), np.array(lattice.lengths))
298+
for idx in range(n_elecs):
299+
lattice = lattices[idx]
300+
frac_coord = np.divide(np.array([p_tot[idx]]), np.array(lattice.lengths))
304301
d = PolarizationLattice(lattice, ["C"], [np.array(frac_coord).ravel()])
305302
d_structs.append(d)
306303
site = d[0]
307304
# Adjust nonpolar polarization to be closest to zero.
308305
# This is compatible with both a polarization of zero or a half quantum.
309-
prev_site = [0, 0, 0] if i == 0 else sites[-1].coords
306+
prev_site = [0, 0, 0] if idx == 0 else sites[-1].coords
310307
new_site = d.get_nearest_site(prev_site, site)
311308
sites.append(new_site[0])
312309

313310
adjust_pol = []
314-
for s, d in zip(sites, d_structs):
315-
lattice = d.lattice
316-
adjust_pol.append(np.multiply(s.frac_coords, np.array(lattice.lengths)).ravel())
311+
for site, d in zip(sites, d_structs):
312+
adjust_pol.append(np.multiply(site.frac_coords, np.array(d.lattice.lengths)).ravel())
317313
return np.array(adjust_pol)
318314

319315
def get_lattice_quanta(self, convert_to_muC_per_cm2=True, all_in_polar=True):
@@ -324,7 +320,7 @@ def get_lattice_quanta(self, convert_to_muC_per_cm2=True, all_in_polar=True):
324320
lattices = [s.lattice for s in self.structures]
325321
volumes = np.array([struct.volume for struct in self.structures])
326322

327-
L = len(self.structures)
323+
n_structs = len(self.structures)
328324

329325
e_to_muC = -1.6021766e-13
330326
cm2_to_A2 = 1e16
@@ -334,13 +330,17 @@ def get_lattice_quanta(self, convert_to_muC_per_cm2=True, all_in_polar=True):
334330
# convert polarizations and lattice lengths prior to adjustment
335331
if convert_to_muC_per_cm2 and not all_in_polar:
336332
# adjust lattices
337-
for i in range(L):
338-
lattice = lattices[i]
339-
lattices[i] = Lattice.from_parameters(*(np.array(lattice.lengths) * units.ravel()[i]), *lattice.angles)
333+
for idx in range(n_structs):
334+
lattice = lattices[idx]
335+
lattices[idx] = Lattice.from_parameters(
336+
*(np.array(lattice.lengths) * units.ravel()[idx]), *lattice.angles
337+
)
340338
elif convert_to_muC_per_cm2 and all_in_polar:
341-
for i in range(L):
339+
for idx in range(n_structs):
342340
lattice = lattices[-1]
343-
lattices[i] = Lattice.from_parameters(*(np.array(lattice.lengths) * units.ravel()[-1]), *lattice.angles)
341+
lattices[idx] = Lattice.from_parameters(
342+
*(np.array(lattice.lengths) * units.ravel()[-1]), *lattice.angles
343+
)
344344

345345
return np.array([np.array(latt.lengths) for latt in lattices])
346346

pymatgen/analysis/local_env.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,12 +2286,12 @@ def __init__(self, types, parameters=None, cutoff=-10.0) -> None:
22862286

22872287
self._comp_azi = False
22882288
self._params = []
2289-
for i, typ in enumerate(self._types):
2290-
d = deepcopy(default_op_params[typ]) if default_op_params[typ] is not None else None
2291-
if parameters is None or parameters[i] is None:
2292-
self._params.append(d)
2289+
for idx, typ in enumerate(self._types):
2290+
dct = deepcopy(default_op_params[typ]) if default_op_params[typ] is not None else None
2291+
if parameters is None or parameters[idx] is None:
2292+
self._params.append(dct)
22932293
else:
2294-
self._params.append(deepcopy(parameters[i]))
2294+
self._params.append(deepcopy(parameters[idx]))
22952295

22962296
self._computerijs = self._computerjks = self._geomops = False
22972297
self._geomops2 = self._boops = False
@@ -3923,20 +3923,20 @@ def get_nn_data(self, structure: Structure, n: int, length=None):
39233923
for entry in nn:
39243924
r2 = _get_radius(entry["site"])
39253925
if r1 > 0 and r2 > 0:
3926-
d = r1 + r2
3926+
dist = r1 + r2
39273927
else:
39283928
warnings.warn(
39293929
"CrystalNN: cannot locate an appropriate radius, "
39303930
"covalent or atomic radii will be used, this can lead "
39313931
"to non-optimal results."
39323932
)
3933-
d = _get_default_radius(structure[n]) + _get_default_radius(entry["site"])
3933+
dist = _get_default_radius(structure[n]) + _get_default_radius(entry["site"])
39343934

39353935
dist = np.linalg.norm(structure[n].coords - entry["site"].coords)
39363936
dist_weight: float = 0
39373937

3938-
cutoff_low = d + self.distance_cutoffs[0]
3939-
cutoff_high = d + self.distance_cutoffs[1]
3938+
cutoff_low = dist + self.distance_cutoffs[0]
3939+
cutoff_high = dist + self.distance_cutoffs[1]
39403940

39413941
if dist <= cutoff_low:
39423942
dist_weight = 1

pymatgen/analysis/molecule_matcher.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,9 @@ def _align_heavy_atoms(mol1, mol2, vmol1, vmol2, ilabel1, ilabel2, eq_atoms):
402402
a2 = aligned_mol2.GetAtom(c2)
403403
for c1 in candidates1:
404404
a1 = canon_mol1.GetAtom(c1)
405-
d = a1.GetDistance(a2)
406-
if d < distance:
407-
distance = d
405+
dist = a1.GetDistance(a2)
406+
if dist < distance:
407+
distance = dist
408408
canon_idx = c1
409409
canon_label2[c2 - 1] = canon_idx
410410
candidates1.remove(canon_idx)
@@ -462,9 +462,9 @@ def _align_hydrogen_atoms(mol1, mol2, heavy_indices1, heavy_indices2):
462462
a2 = cmol2.GetAtom(h2)
463463
for h1 in hydrogen_label1:
464464
a1 = cmol1.GetAtom(h1)
465-
d = a1.GetDistance(a2)
466-
if d < distance:
467-
distance = d
465+
dist = a1.GetDistance(a2)
466+
if dist < distance:
467+
distance = dist
468468
idx = h1
469469
hydrogen_label2.append(idx)
470470
hydrogen_label1.remove(idx)
@@ -828,11 +828,11 @@ def kabsch(P: np.ndarray, Q: np.ndarray):
828828
V, _S, WT = np.linalg.svd(C)
829829

830830
# Getting the sign of the det(V*Wt) to decide whether
831-
d = np.linalg.det(np.dot(V, WT))
831+
det = np.linalg.det(np.dot(V, WT))
832832

833833
# And finally calculating the optimal rotation matrix R
834834
# we need to correct our rotation matrix to ensure a right-handed coordinate system.
835-
return np.dot(np.dot(V, np.diag([1, 1, d])), WT)
835+
return np.dot(np.dot(V, np.diag([1, 1, det])), WT)
836836

837837

838838
class BruteForceOrderMatcher(KabschMatcher):

pymatgen/analysis/piezo_sensitivity.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ def get_unstable_FCM(self, max_force=1):
377377
D[3 * op[1] : 3 * op[1] + 3, 3 * op[0] : 3 * op[0] + 3] = np.zeros([3, 3])
378378

379379
for symop in op[4]:
380-
tempfcm = D[3 * op[2] : 3 * op[2] + 3, 3 * op[3] : 3 * op[3] + 3]
381-
tempfcm = symop.transform_tensor(tempfcm)
382-
D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] += tempfcm
380+
temp_fcm = D[3 * op[2] : 3 * op[2] + 3, 3 * op[3] : 3 * op[3] + 3]
381+
temp_fcm = symop.transform_tensor(temp_fcm)
382+
D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] += temp_fcm
383383

384384
if len(op[4]) != 0:
385385
D[3 * op[0] : 3 * op[0] + 3, 3 * op[1] : 3 * op[1] + 3] = D[

pymatgen/io/aims/inputs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,13 @@ def __post_init__(self) -> None:
306306
ValueError: If any of the inputs is invalid
307307
"""
308308
split_type = self.type.split()
309+
cube_type = split_type[0]
309310
if split_type[0] in ALLOWED_AIMS_CUBE_TYPES:
310311
if len(split_type) > 1:
311-
msg = f"Cube of type {split_type[0]} can not have a state associated with it"
312-
raise ValueError(msg)
312+
raise ValueError(f"{cube_type=} can not have a state associated with it")
313313
elif split_type[0] in ALLOWED_AIMS_CUBE_TYPES_STATE:
314314
if len(split_type) != 2:
315-
msg = f"Cube of type {split_type[0]} must have a state associated with it"
316-
raise ValueError(msg)
315+
raise ValueError(f"{cube_type=} must have a state associated with it")
317316
else:
318317
raise ValueError("Cube type undefined")
319318

pymatgen/io/vasp/outputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,7 +1931,7 @@ def __init__(self, filename):
19311931
self.read_pattern({key: rf"{key}\s+=\s+([\d\-\.]+)"})
19321932
if not self.data[key]:
19331933
continue
1934-
final_energy_contribs[key] = sum(float(f) for f in self.data[key][-1])
1934+
final_energy_contribs[key] = sum(map(float, self.data[key][-1]))
19351935
self.final_energy_contribs = final_energy_contribs
19361936

19371937
def read_pattern(self, patterns, reverse=False, terminate_on_match=False, postprocess=str):
@@ -2053,7 +2053,7 @@ def read_electrostatic_potential(self):
20532053

20542054
pattern = {"radii": r"the test charge radii are((?:\s+[\.\-\d]+)+)"}
20552055
self.read_pattern(pattern, reverse=True, terminate_on_match=True, postprocess=str)
2056-
self.sampling_radii = [float(f) for f in self.data["radii"][0][0].split()]
2056+
self.sampling_radii = [*map(float, self.data["radii"][0][0].split())]
20572057

20582058
header_pattern = r"\(the norm of the test charge is\s+[\.\-\d]+\)"
20592059
table_pattern = r"((?:\s+\d+\s*[\.\-\d]+)+)"
@@ -2064,7 +2064,7 @@ def read_electrostatic_potential(self):
20642064

20652065
pots = re.findall(r"\s+\d+\s*([\.\-\d]+)+", pots)
20662066

2067-
self.electrostatic_potential = [float(f) for f in pots]
2067+
self.electrostatic_potential = [*map(float, pots)]
20682068

20692069
@staticmethod
20702070
def _parse_sci_notation(line):

0 commit comments

Comments
 (0)