Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 21, 2024
1 parent db64728 commit b9297ea
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
52 changes: 31 additions & 21 deletions aiida_nanotech_empa/workflows/cp2k/adsorbed_gw_ic_workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from aiida import engine, orm

from ...utils import common_utils

from .geo_opt_workchain import Cp2kGeoOptWorkChain
from .molecule_gw_workchain import Cp2kMoleculeGwWorkChain

Expand All @@ -11,14 +10,20 @@
}


def geometrical_analysis(ase_geo, substr_elem,mol_atoms=None):
def geometrical_analysis(ase_geo, substr_elem, mol_atoms=None):
"""Simple geometry analysis that returns in the case of
1) an isolated molecule -> geometry, None
2) adsorbed system -> molecular geometry, top substr. layer z
"""

if mol_atoms is not None:
s_atoms = ase_geo[[atom.index for atom in ase_geo if atom.index not in mol_atoms and atom.symbol == substr_elem]]
s_atoms = ase_geo[
[
atom.index
for atom in ase_geo
if atom.index not in mol_atoms and atom.symbol == substr_elem
]
]
else:
chem_symbols_arr = np.array(ase_geo.get_chemical_symbols())
s_atoms = ase_geo[chem_symbols_arr == substr_elem]
Expand All @@ -40,11 +45,15 @@ def geometrical_analysis(ase_geo, substr_elem,mol_atoms=None):


@engine.calcfunction
def analyze_structure(structure, substrate, mag_per_site, ads_h=None,molecule_atoms=None):
def analyze_structure(
structure, substrate, mag_per_site, ads_h=None, molecule_atoms=None
):
ase_geo = structure.get_ase()
substr_elem = substrate.value.split("(")[0]

mol_atoms, surf_z = geometrical_analysis(ase_geo, substr_elem,mol_atoms=molecule_atoms)
mol_atoms, surf_z = geometrical_analysis(
ase_geo, substr_elem, mol_atoms=molecule_atoms
)

if surf_z is None:
if ads_h is None:
Expand Down Expand Up @@ -163,12 +172,12 @@ def define(cls, spec):
)
spec.input("dft_params", valid_type=orm.Dict)
spec.input("sys_params", valid_type=orm.Dict)
#spec.input(
# spec.input(
# "molecule_atoms",
# valid_type=orm.List,
# default=lambda: orm.List(list=[]),
# required=False,
#)
# )
spec.input_namespace(
"options",
valid_type=dict,
Expand Down Expand Up @@ -239,12 +248,12 @@ def define(cls, spec):

def setup(self):
self.report("Inspecting input and setting up things.")

self.ctx.sys_params = self.inputs.sys_params.get_dict()
self.ctx.dft_params = self.inputs.dft_params.get_dict()

n_atoms = len(self.inputs.structure.get_ase())
mags = self.ctx.dft_params.get('magnetization_per_site',[])
mags = self.ctx.dft_params.get("magnetization_per_site", [])
n_mags = len(mags)
if n_mags not in (0, n_atoms):
self.report("If set, magnetization_per_site needs a value for every atom.")
Expand All @@ -253,15 +262,17 @@ def setup(self):
if self.inputs.substrate.value not in IC_PLANE_HEIGHTS:
return self.exit_codes.ERROR_SUBSTR_NOT_SUPPORTED

molecule_atoms = self.ctx.sys_params.get('molecule_atoms',[])#self.inputs.molecule_atoms
molecule_atoms = self.ctx.sys_params.get(
"molecule_atoms", []
) # self.inputs.molecule_atoms
if len(molecule_atoms) == 0:
molecule_atoms = None
an_out = analyze_structure(
self.inputs.structure,
self.inputs.substrate,
mags,
None if "ads_height" not in self.inputs else self.inputs.ads_height,
molecule_atoms=molecule_atoms,
molecule_atoms=molecule_atoms,
)

if "mol_struct" not in an_out:
Expand All @@ -272,8 +283,8 @@ def setup(self):
self.ctx.image_plane_z = an_out["image_plane_z"]
self.ctx.mol_mag_per_site = an_out["mol_mag_per_site"]
if "magnetization_per_site" in self.ctx.dft_params:
self.ctx.dft_params["magnetization_per_site"] = an_out["mol_mag_per_site"]
self.ctx.dft_params["magnetization_per_site"] = an_out["mol_mag_per_site"]

###

return engine.ExitCode(0)
Expand Down Expand Up @@ -318,7 +329,7 @@ def ic(self):
builder.protocol = self.inputs.protocol
builder.structure = self.ctx.mol_struct
builder.magnetization_per_site = self.ctx.mol_mag_per_site
builder.multiplicity = orm.Int(self.ctx.dft_params.get("multiplicity",0))
builder.multiplicity = orm.Int(self.ctx.dft_params.get("multiplicity", 0))
builder.run_image_charge = orm.Bool(True)
builder.z_ic_plane = self.ctx.image_plane_z
builder.options.scf = self.inputs.options.scf
Expand All @@ -339,7 +350,7 @@ def gw(self):
builder.protocol = self.inputs.protocol
builder.structure = self.ctx.mol_struct
builder.magnetization_per_site = self.ctx.mol_mag_per_site
builder.multiplicity = orm.Int(self.ctx.dft_params.get("multiplicity",0))
builder.multiplicity = orm.Int(self.ctx.dft_params.get("multiplicity", 0))
builder.options.scf = self.inputs.options.scf
builder.options.gw = self.inputs.options.gw
builder.metadata.description = "gw"
Expand All @@ -351,13 +362,13 @@ def finalize(self):
if not self.ctx.gw.is_finished_ok:
return self.exit_codes.ERROR_TERMINATION

if hasattr(self.ctx, 'gas_opt') and self.ctx.gas_opt is not None:
if hasattr(self.ctx, "gas_opt") and self.ctx.gas_opt is not None:
gas_geo_opt_params = self.ctx.gas_opt.outputs.output_parameters
self.out("gas_geo_opt_parameters", gas_geo_opt_params)

gw_out_params = self.ctx.gw.outputs.gw_output_parameters
ic_out_params = self.ctx.ic.outputs.gw_output_parameters

self.out("gw_output_parameters", gw_out_params)

self.out("ic_output_parameters", ic_out_params)
Expand All @@ -368,7 +379,6 @@ def finalize(self):

# Add extras.
struc = self.inputs.structure
common_utils.add_extras(struc, "surfaces", self.node.uuid)

common_utils.add_extras(struc, "surfaces", self.node.uuid)

return engine.ExitCode(0)
5 changes: 3 additions & 2 deletions examples/workflows/example_cp2k_ads_gw_ic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import ase,ase.io
import ase
import ase.io
from aiida import engine, orm, plugins

Cp2kAdsorbedGwIcWorkChain = plugins.WorkflowFactory("nanotech_empa.cp2k.ads_gw_ic")
Expand All @@ -27,7 +28,7 @@ def _example_cp2k_ads_gw_ic(cp2k_code, slab_included):
builder.ads_height = orm.Float(3.0)

builder.structure = orm.StructureData(ase=ase_geom)

dft_params = {
"uks": True,
"magnetization_per_site": mag_list,
Expand Down

0 comments on commit b9297ea

Please sign in to comment.