Skip to content

Commit

Permalink
Fix io.cp2k.input.DataFile (#3745)
Browse files Browse the repository at this point in the history
* fix cp2k input

* test_openff.py file replace assert np.allclose with assert_allclose for better err msg

* rename variables for readability in tests/io/cp2k/test_inputs.py

* make pymatgen.io.cp2k.inputs DataFile.from_str an abc.abstractmethod

* add TestDataFile and test PotentialFile.from_file

---------

Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
DanielYang59 and janosh authored Apr 11, 2024
1 parent 7064c43 commit 9337a4e
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 94 deletions.
27 changes: 16 additions & 11 deletions pymatgen/electronic_structure/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2972,8 +2972,10 @@ def plot_power_factor_mu(
a matplotlib object
"""
ax = pretty_plot(9, 7)
pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp]
ax.semilogy(self._bz.mu_steps, pf, linewidth=3.0)
pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output, doping_levels=False)[
temp
]
ax.semilogy(self._bz.mu_steps, pow_factor, linewidth=3.0)
self._plot_bg_limits(ax)
self._plot_doping(ax, temp)
if output == "eig":
Expand Down Expand Up @@ -3145,20 +3147,20 @@ def plot_power_factor_temp(self, doping="all", output="average", relaxation_time
a matplotlib object
"""
if output == "average":
pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average")
pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average")
elif output == "eigs":
pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs")
pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs")

ax = pretty_plot(22, 14)
tlist = sorted(pf["n"])
tlist = sorted(pow_factor["n"])
doping = self._bz.doping["n"] if doping == "all" else doping
for idx, doping_type in enumerate(["n", "p"]):
plt.subplot(121 + idx)
for dop in doping:
dop_idx = self._bz.doping[doping_type].index(dop)
pf_temp = []
for temp in tlist:
pf_temp.append(pf[doping_type][temp][dop_idx])
pf_temp.append(pow_factor[doping_type][temp][dop_idx])
if output == "average":
ax.plot(tlist, pf_temp, marker="s", label=f"{dop} $cm^{-3}$")
elif output == "eigs":
Expand Down Expand Up @@ -3387,22 +3389,25 @@ def plot_power_factor_dop(self, temps="all", output="average", relaxation_time=1
a matplotlib object
"""
if output == "average":
pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average")
pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="average")
elif output == "eigs":
pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs")
pow_factor = self._bz.get_power_factor(relaxation_time=relaxation_time, output="eigs")

tlist = sorted(pf["n"]) if temps == "all" else temps
tlist = sorted(pow_factor["n"]) if temps == "all" else temps
ax = pretty_plot(22, 14)
for i, dt in enumerate(["n", "p"]):
plt.subplot(121 + i)
for temp in tlist:
if output == "eigs":
for xyz in range(3):
ax.semilogx(
self._bz.doping[dt], list(zip(*pf[dt][temp]))[xyz], marker="s", label=f"{xyz} {temp} K"
self._bz.doping[dt],
list(zip(*pow_factor[dt][temp]))[xyz],
marker="s",
label=f"{xyz} {temp} K",
)
elif output == "average":
ax.semilogx(self._bz.doping[dt], pf[dt][temp], marker="s", label=f"{temp} K")
ax.semilogx(self._bz.doping[dt], pow_factor[dt][temp], marker="s", label=f"{temp} K")
ax.set_title(dt + "-type", fontsize=20)
if i == 0:
ax.set_ylabel("Power Factor ($\\mu$W/(mK$^2$))", fontsize=30.0)
Expand Down
19 changes: 10 additions & 9 deletions pymatgen/io/cp2k/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from __future__ import annotations

import abc
import copy
import hashlib
import itertools
Expand Down Expand Up @@ -2749,17 +2750,17 @@ class DataFile(MSONable):
objects: Sequence | None = None

@classmethod
def from_file(cls, filename) -> None:
"""Load from a file."""
raise NotImplementedError
# with open(filename, encoding="utf-8") as file:
# data = cls.from_str(file.read())
# for obj in data.objects:
# obj.filename = filename
# return data
def from_file(cls, filename) -> Self:
"""Load from a file, reserved for child classes."""
with open(filename, encoding="utf-8") as file:
data = cls.from_str(file.read()) # type: ignore[call-arg]
for obj in data.objects: # type: ignore[attr-defined]
obj.filename = filename
return data # type: ignore[return-value]

@classmethod
def from_str(cls) -> None:
@abc.abstractmethod
def from_str(cls, string: str) -> None:
"""Initialize from a string."""
raise NotImplementedError

Expand Down
148 changes: 79 additions & 69 deletions tests/io/cp2k/test_inputs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
from pytest import approx

Expand All @@ -10,6 +11,7 @@
BasisInfo,
Coord,
Cp2kInput,
DataFile,
GaussianTypeOrbitalBasisSet,
GthPotential,
Keyword,
Expand All @@ -22,7 +24,7 @@
)
from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest

Si_structure = Structure(
si_struct = Structure(
lattice=[[0, 2.734364, 2.734364], [2.734364, 0, 2.734364], [2.734364, 2.734364, 0]],
species=["Si", "Si"],
coords=[[0, 0, 0], [0.25, 0.25, 0.25]],
Expand All @@ -34,9 +36,9 @@
coords=[[-1, -1, -1]],
)

molecule = Molecule(species=["C", "H"], coords=[[0, 0, 0], [1, 1, 1]])
ch_mol = Molecule(species=["C", "H"], coords=[[0, 0, 0], [1, 1, 1]])

basis = """
BASIS_FILE_STR = """
H SZV-MOLOPT-GTH SZV-MOLOPT-GTH-q1
1
2 0 0 7 1
Expand All @@ -48,40 +50,46 @@
0.066918004004 0.037148121400
0.021708243634 -0.001125195500
"""
all_hydrogen = """
ALL_HYDROGEN_STR = """
H ALLELECTRON ALL
1 0 0
0.20000000 0
"""
pot_hydrogen = """
POT_HYDROGEN_STR = """
H GTH-PBE-q1 GTH-PBE
1
0.20000000 2 -4.17890044 0.72446331
0
"""
CP2K_INPUT_STR = """
&GLOBAL
RUN_TYPE ENERGY
PROJECT_NAME CP2K ! default name
&END
"""


class TestBasisAndPotential(PymatgenTest):
def test_basis_info(self):
# Ensure basis metadata can be read from string
b = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1-SCAN")
assert b.valence == 2
assert b.molopt
assert b.electrons == 1
assert b.polarization == 1
assert b.cc
assert b.pc
assert b.xc == "SCAN"

# Ensure one-way softmatching works
b2 = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1")
assert b2.softmatch(b)
assert not b.softmatch(b2)

b3 = BasisInfo.from_str("cpFIT3")
assert b3.valence == 3
assert b3.polarization == 1
assert b3.contracted, True
basis_info = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1-SCAN")
assert basis_info.valence == 2
assert basis_info.molopt
assert basis_info.electrons == 1
assert basis_info.polarization == 1
assert basis_info.cc
assert basis_info.pc
assert basis_info.xc == "SCAN"

# Ensure one-way soft-matching works
basis_info2 = BasisInfo.from_str("cc-pc-DZVP-MOLOPT-q1")
assert basis_info2.softmatch(basis_info)
assert not basis_info.softmatch(basis_info2)

basis_info3 = BasisInfo.from_str("cpFIT3")
assert basis_info3.valence == 3
assert basis_info3.polarization == 1
assert basis_info3.contracted, True

def test_potential_info(self):
# Ensure potential metadata can be read from string
Expand All @@ -90,17 +98,17 @@ def test_potential_info(self):
assert pot_info.xc == "PBE"
assert pot_info.nlcc

# Ensure one-way softmatching works
p2 = PotentialInfo.from_str("GTH-q1-NLCC")
assert p2.softmatch(pot_info)
assert not pot_info.softmatch(p2)
# Ensure one-way soft-matching works
pot_info2 = PotentialInfo.from_str("GTH-q1-NLCC")
assert pot_info2.softmatch(pot_info)
assert not pot_info.softmatch(pot_info2)

def test_basis(self):
# Ensure cp2k formatted string can be read for data correctly
mol_opt = GaussianTypeOrbitalBasisSet.from_str(basis)
mol_opt = GaussianTypeOrbitalBasisSet.from_str(BASIS_FILE_STR)
assert mol_opt.nexp == [7]
# Basis file can read from strings
bf = BasisFile.from_str(basis)
bf = BasisFile.from_str(BASIS_FILE_STR)
for obj in [mol_opt, bf.objects[0]]:
assert_allclose(
obj.exponents[0],
Expand All @@ -125,17 +133,22 @@ def test_basis(self):

def test_potentials(self):
# Ensure cp2k formatted string can be read for data correctly
h_all_elec = GthPotential.from_str(all_hydrogen)
h_all_elec = GthPotential.from_str(ALL_HYDROGEN_STR)
assert h_all_elec.potential == "All Electron"
pot = GthPotential.from_str(pot_hydrogen)
pot = GthPotential.from_str(POT_HYDROGEN_STR)
assert pot.potential == "Pseudopotential"
assert pot.r_loc == approx(0.2)
assert pot.nexp_ppl == approx(2)
assert_allclose(pot.c_exp_ppl, [-4.17890044, 0.72446331])

# Basis file can read from strings
pf = PotentialFile.from_str(pot_hydrogen)
assert pf.objects[0] == pot
pot_file = PotentialFile.from_str(POT_HYDROGEN_STR)
assert pot_file.objects[0] == pot

pot_file_path = self.tmp_path / "potential-file"
pot_file_path.write_text(POT_HYDROGEN_STR)
pot_from_file = PotentialFile.from_file(pot_file_path)
assert pot_file != pot_from_file # unequal because pot_from_file has filename != None

# Ensure keyword can be properly generated
kw = pot.get_keyword()
Expand All @@ -149,27 +162,21 @@ def setUp(self):
self.ci = Cp2kInput.from_file(f"{TEST_FILES_DIR}/cp2k/cp2k.inp")

def test_basic_sections(self):
cp2k_input_str = """
&GLOBAL
RUN_TYPE ENERGY
PROJECT_NAME CP2K ! default name
&END
"""
cp2k_input = Cp2kInput.from_str(cp2k_input_str)
cp2k_input = Cp2kInput.from_str(CP2K_INPUT_STR)
assert cp2k_input["GLOBAL"]["RUN_TYPE"] == Keyword("RUN_TYPE", "energy")
assert cp2k_input["GLOBAL"]["PROJECT_NAME"].description == "default name"
self.assert_msonable(cp2k_input)

def test_section_list(self):
s1 = Section("TEST")
sl = SectionList(sections=[s1, s1])
for s in sl:
sec1 = Section("TEST")
sec_list = SectionList(sections=[sec1, sec1])
for s in sec_list:
assert isinstance(s, Section)
assert sl[0].name == "TEST"
assert sl[1].name == "TEST"
assert len(sl) == 2
sl += s1
assert len(sl) == 3
assert sec_list[0].name == "TEST"
assert sec_list[1].name == "TEST"
assert len(sec_list) == 2
sec_list += sec1
assert len(sec_list) == 3

def test_basic_keywords(self):
kwd = Keyword("TEST1", 1, 2)
Expand All @@ -181,14 +188,14 @@ def test_basic_keywords(self):
assert "[Ha]" in kwd.get_str()

def test_coords(self):
for struct in [nonsense_struct, Si_structure, molecule]:
for struct in [nonsense_struct, si_struct, ch_mol]:
coords = Coord(struct)
for c in coords.keywords.values():
assert isinstance(c, (Keyword, KeywordList))
for val in coords.keywords.values():
assert isinstance(val, (Keyword, KeywordList))

def test_kind(self):
for s in [nonsense_struct, Si_structure, molecule]:
for spec in s.species:
for struct in [nonsense_struct, si_struct, ch_mol]:
for spec in struct.species:
assert spec == Kind(spec).specie

def test_ci_file(self):
Expand All @@ -205,20 +212,20 @@ def test_ci_file(self):

def test_odd_file(self):
scramble = ""
for s in self.ci.get_str():
for string in self.ci.get_str():
if np.random.rand(1) > 0.5:
if s == "\t":
if string == "\t":
scramble += " "
elif s == " ":
elif string == " ":
scramble += " "
elif s in ("&", "\n"):
scramble += s
elif s.isalpha():
scramble += s.lower()
elif string in ("&", "\n"):
scramble += string
elif string.isalpha():
scramble += string.lower()
else:
scramble += s
scramble += string
else:
scramble += s
scramble += string
# Can you initialize from jumbled input
# should be case insensitive and ignore
# excessive white space or tabs
Expand All @@ -236,13 +243,7 @@ def test_preprocessor(self):
assert self.ci["FORCE_EVAL"]["DFT"]["SCF"]["MAX_SCF"] == Keyword("MAX_SCF", 1)

def test_mongo(self):
cp2k_input_str = """
&GLOBAL
RUN_TYPE ENERGY
PROJECT_NAME CP2K ! default name
&END
"""
cp2k_input = Cp2kInput.from_str(cp2k_input_str)
cp2k_input = Cp2kInput.from_str(CP2K_INPUT_STR)
cp2k_input.inc({"GLOBAL": {"TEST": 1}})
assert cp2k_input["global"]["test"] == Keyword("TEST", 1)

Expand All @@ -252,3 +253,12 @@ def test_mongo(self):
cp2k_input.set({"GLOBAL": {"SUBSEC": {"TEST2": 2}, "SUBSEC2": {"Test2": 1}}})
assert cp2k_input.check("global/SUBSEC")
assert cp2k_input.check("global/subsec2")


class TestDataFile(PymatgenTest):
def test_data_file(self):
# make temp file with BASIS_FILE_STR
data_file = self.tmp_path / "data-file"
data_file.write_text(BASIS_FILE_STR)
with pytest.raises(NotImplementedError):
DataFile.from_file(data_file)
Loading

0 comments on commit 9337a4e

Please sign in to comment.