Skip to content

Commit c1129d0

Browse files
committed
auto-calculate ADDED_MOS based on basissets
1 parent 05d8199 commit c1129d0

File tree

5 files changed

+218
-1
lines changed

5 files changed

+218
-1
lines changed

aiida_cp2k/calculations/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
validate_pseudos_namespace,
2323
write_basissets,
2424
write_pseudos,
25+
estimate_added_mos,
2526
)
2627
from ..utils import Cp2kInput
2728

@@ -139,7 +140,7 @@ def prepare_for_submission(self, folder):
139140
:return: `aiida.common.datastructures.CalcInfo` instance
140141
"""
141142

142-
# pylint: disable=too-many-statements,too-many-branches
143+
# pylint: disable=too-many-statements,too-many-branches,too-many-locals
143144

144145
# Create cp2k input file.
145146
inp = Cp2kInput(self.inputs.parameters.get_dict())
@@ -167,6 +168,21 @@ def prepare_for_submission(self, folder):
167168
self.inputs.structure if 'structure' in self.inputs else None)
168169
write_basissets(inp, self.inputs.basissets, folder)
169170

171+
# if we have both basissets and structure we can start helping the user :)
172+
if 'basissets' in self.inputs and 'structure' in self.inputs:
173+
try:
174+
scf_section = inp.get_section_dict('FORCE_EVAL/DFT/SCF')
175+
176+
if 'SMEAR' in scf_section and 'ADDED_MOS' not in scf_section:
177+
# now is our time to shine!
178+
added_mos = estimate_added_mos(self.inputs.basissets, self.inputs.structure)
179+
inp.add_keyword('FORCE_EVAL/DFT/SCF/ADDED_MOS', added_mos)
180+
self.logger.info(f'The FORCE_EVAL/DFT/SCF/ADDED_MOS was added with an automatically estimated value'
181+
f' of {added_mos}')
182+
183+
except (KeyError, TypeError): # no SCF, no smearing, or multiple FORCE_EVAL, nothing to do (yet)
184+
pass
185+
170186
if 'pseudos' in self.inputs:
171187
validate_pseudos(inp, self.inputs.pseudos, self.inputs.structure if 'structure' in self.inputs else None)
172188
write_pseudos(inp, self.inputs.pseudos, folder)

aiida_cp2k/utils/datatype_helpers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,31 @@ def validate_basissets(inp, basissets, structure):
160160
kind_sec["ELEMENT"] = bset.element
161161

162162

163+
def estimate_added_mos(basissets, structure, fraction=0.3):
    """Estimate a value for ADDED_MOS from the orbital basis sets in use.

    Every site of the structure contributes the number of orbital functions
    of each ORB basis set whose element equals the site's symbol string.
    The estimate is the given fraction of that total, but never fewer than
    one additional MO per site.

    Known overcounting (not corrected here):
      * a mix of ORB basis sets for the same chemical symbol on different sites
      * multiple basis sets for one element (merged within CP2K)

    Args:
        basissets: the basis set inputs, as accepted by `_unpack`.
        structure: object providing `sites` and `get_kind(kind_name)`.
        fraction: share of the total number of orbital functions to use.

    Returns:
        int: the estimated number of additional MOs.
    """

    site_symbols = []
    for site in structure.sites:
        site_symbols.append(structure.get_kind(site.kind_name).get_symbols_string())

    total_functions = 0

    for label, bset in _unpack(basissets):
        # the part after the first underscore names the basis set type;
        # a label without any underscore is treated as an ORB basis set
        _, separator, suffix = label.partition("_")
        basis_type = suffix if separator else "ORB"

        if basis_type == "ORB":  # non-ORB basis sets do not contribute
            total_functions += site_symbols.count(bset.element) * bset.n_orbital_functions

    fraction_estimate = int(fraction * total_functions)

    # lower bound: one additional MO for every site
    return max(len(site_symbols), fraction_estimate)
186+
187+
163188
def write_basissets(inp, basissets, folder):
164189
"""Writes the unified BASIS_SETS file with the used basissets"""
165190
_write_gdt(inp, basissets, folder, "BASIS_SET_FILE_NAME", "BASIS_SETS")

aiida_cp2k/utils/input_generator.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,71 @@ def add_keyword(self, kwpath, value, override=True, conflicting_keys=None):
5656

5757
Cp2kInput._add_keyword(kwpath, value, self._params, ovrd=override, cfct=conflicting_keys)
5858

59+
@staticmethod
def _stringify_path(kwpath):
    """Stringify a kwpath argument

    Args:

        kwpath: Either a string, which is returned unchanged, or a sequence
                of strings, which gets joined with `/` as divider.

    Raises:

        TypeError: if kwpath is neither a string nor a sequence.
    """
    if isinstance(kwpath, str):
        return kwpath

    # raise explicitly rather than assert: assertions are removed when
    # Python runs with -O, which would let invalid input slip through
    if not isinstance(kwpath, Sequence):
        raise TypeError("path is neither Sequence nor String")

    return "/".join(kwpath)
67+
68+
def get_section_dict(self, kwpath=""):
    """Return an independent copy of a section of the current input structure

    Args:

        kwpath: Can be a single keyword, a path with `/` as divider for sections & key,
                or a sequence with sections and key. The default (empty string) is
                passed on unchanged to the path lookup.

    Raises:

        TypeError: if the path resolves to something that is not a mapping.
    """

    found = self._get_section_or_kw(kwpath)

    if isinstance(found, Mapping):
        # deep copy, so changes made by the caller never reach the stored parameters
        return deepcopy(found)

    path_str = self._stringify_path(kwpath)
    raise TypeError(f"Section '{path_str}' requested, but keyword found")
83+
84+
def get_keyword_value(self, kwpath):
    """Return the value stored for a keyword in the current input structure

    The value is returned as stored, without copying it.

    Args:

        kwpath: Can be a single keyword, a path with `/` as divider for sections & key,
                or a sequence with sections and key.

    Raises:

        TypeError: if the path resolves to a mapping, i.e. a section.
    """

    found = self._get_section_or_kw(kwpath)

    if not isinstance(found, Mapping):
        return found

    path_str = self._stringify_path(kwpath)
    raise TypeError(f"Keyword '{path_str}' requested, but section found")
99+
100+
def _get_section_or_kw(self, kwpath):
    """Retrieve either a section or a keyword given a path

    Args:

        kwpath: Either a string with `/` as divider between the components,
                or a sequence of components. Components are matched in
                uppercase and empty components are ignored, so a leading
                `/`, `//`, etc. is accepted.

    Returns:

        The object stored at the given path (not a copy). With no non-empty
        component, this is the root of the parameters.

    Raises:

        KeyError: if one of the path components is not found. The original
                  lookup error is attached as the cause.

    Only KeyError is translated here; any other exception raised while
    indexing into a value along the path propagates unchanged.
    """

    if isinstance(kwpath, str):
        components = kwpath.split("/")  # convert to list of sections if string
    else:
        components = kwpath  # if not, assume some sort of sequence

    # accept any case, but internally we use uppercase
    # strip empty strings to accept leading "/", "//", etc.
    components = [k.upper() for k in components if k]

    # start with a reference to the root of the parameters
    current = self._params

    try:
        for key in components:
            current = current[key]
    except KeyError as exc:
        # chain explicitly so the traceback shows which lookup failed
        raise KeyError(f"Section '{self._stringify_path(kwpath)}' not found in parameters") from exc

    return current
123+
59124
def render(self):
60125
output = [self.DISCLAIMER]
61126
self._render_section(output, deepcopy(self._params))

test/test_gaussian_datatypes.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,86 @@ def test_without_kinds(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database):
778778

779779
_, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs)
780780
assert calc_node.exit_status == 0
781+
782+
783+
def test_added_mos(cp2k_code, cp2k_basissets, cp2k_pseudos, clear_database):  # pylint: disable=unused-argument
    """Run CP2K with a smearing section but without ADDED_MOS and check that the keyword shows up in the input"""

    # four hydrogen atoms in a periodic cell
    cell = [
        [4.00759, 0.0, 0.0],
        [-2.003795, 3.47067475, 0.0],
        [3.06349683e-16, 5.30613216e-16, 5.00307],
    ]
    hydrogen_positions = (
        (-0.00002004, 2.31379473, 0.87543719),
        (2.00381504, 1.15688001, 4.12763281),
        (2.00381504, 1.15688001, 3.37697219),
        (-0.00002004, 2.31379473, 1.62609781),
    )

    structure = StructureData(cell=cell, pbc=True)
    for position in hydrogen_positions:
        structure.append_atom(position=position, symbols="H")

    # SCF section with SMEAR, deliberately lacking ADDED_MOS
    scf_section = {
        "EPS_SCF": 1e-05,
        "MAX_SCF": 3,
        "MIXING": {
            "METHOD": "BROYDEN_MIXING",
            "ALPHA": 0.4,
        },
        "SMEAR": {
            "METHOD": "FERMI_DIRAC",
            "ELECTRONIC_TEMPERATURE": 300.0,
        },
    }

    dft_section = {
        "XC": {
            "XC_FUNCTIONAL": {
                "_": "PBE",
            },
        },
        "MGRID": {
            "CUTOFF": 100.0,
            "REL_CUTOFF": 10.0,
        },
        "QS": {
            "METHOD": "GPW",
            "EXTRAPOLATION": "USE_GUESS",
        },
        "SCF": scf_section,
        "KPOINTS": {
            "SCHEME": "MONKHORST-PACK 2 2 1",
            "FULL_GRID": False,
            "SYMMETRY": False,
            "PARALLEL_GROUP_SIZE": -1,
        },
    }

    parameters = Dict(
        dict={
            "GLOBAL": {
                "RUN_TYPE": "ENERGY",
            },
            "FORCE_EVAL": {
                "METHOD": "Quickstep",
                "DFT": dft_section,
            },
        })

    options = {
        "resources": {
            "num_machines": 1,
            "num_mpiprocs_per_machine": 1,
        },
        "max_wallclock_seconds": 1 * 3 * 60,
    }

    # only the entries labelled "H" are passed on
    hydrogen_basissets = {name: bset for name, bset in cp2k_basissets.items() if name == "H"}
    hydrogen_pseudos = {name: pseudo for name, pseudo in cp2k_pseudos.items() if name == "H"}

    inputs = {
        "structure": structure,
        "parameters": parameters,
        "code": cp2k_code,
        "metadata": {
            "options": options,
        },
        "basissets": hydrogen_basissets,
        "pseudos": hydrogen_pseudos,
    }

    _, calc_node = run_get_node(CalculationFactory("cp2k"), **inputs)

    assert calc_node.exit_status == 0

    # the generated input file must mention the ADDED_MOS keyword
    with calc_node.open("aiida.inp") as fhandle:
        keyword_found = any("ADDED_MOS" in line for line in fhandle)

    assert keyword_found, "ADDED_MOS not found in the generated CP2K input file"

test/test_input_generator.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,31 @@ def test_invalid_preprocessor():
185185
inp = Cp2kInput({"@SET": "bar"})
186186
with pytest.raises(ValueError):
187187
inp.render()
188+
189+
190+
def test_get_keyword_value():
    """Check get_keyword_value() for string paths, sequence paths and a section path"""
    inp = Cp2kInput({"FOO": "bar", "A": {"KW1": "val1"}})

    # every spelling of a path must resolve to the same stored value
    expectations = (
        ("FOO", "bar"),
        ("/FOO", "bar"),
        ("A/KW1", "val1"),
        ("/A/KW1", "val1"),
        (["A", "KW1"], "val1"),
    )
    for kwpath, expected in expectations:
        assert inp.get_keyword_value(kwpath) == expected

    # asking for a keyword where a section is stored is an error
    with pytest.raises(TypeError):
        inp.get_keyword_value("A")
200+
201+
202+
def test_get_section_dict():
    """Check get_section_dict() for the root, for a subsection and for a keyword path"""
    orig_dict = {"FOO": "bar", "A": {"KW1": "val1"}}
    inp = Cp2kInput(orig_dict)

    # all of these address the root of the parameters
    for root_path in ("/", "////", ""):
        assert inp.get_section_dict(root_path) == orig_dict
    assert inp.get_section_dict() == orig_dict

    # make sure we get a distinct object
    assert inp.get_section_dict("/") is not orig_dict

    # different spellings of the same subsection
    for section_path in ("A", "/A", ["A"]):
        assert inp.get_section_dict(section_path) == orig_dict["A"]

    # asking for a section where a keyword is stored is an error
    with pytest.raises(TypeError):
        inp.get_section_dict("FOO")

0 commit comments

Comments
 (0)