diff --git a/src/aiida_quantumespresso/workflows/functions/get_marked_structures.py b/src/aiida_quantumespresso/workflows/functions/get_marked_structures.py new file mode 100644 index 000000000..024b4076a --- /dev/null +++ b/src/aiida_quantumespresso/workflows/functions/get_marked_structures.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +"""CalcFunction to create structures with a marked atom for each site in a list.""" +from aiida import orm +from aiida.common import ValidationError +from aiida.engine import calcfunction +from aiida.orm.nodes.data.structure import Kind, Site, StructureData + + +@calcfunction +def get_marked_structures(structure, atoms_list, marker='X'): + """Read a StructureData object and return structures for XPS calculations. + + :param atoms_list: the atoms_list of atoms to be marked. + :param marker: a Str node defining the name of the marked atom Kind. Default is 'X'. + :returns: StructureData objects for the generated structure. + """ + marker = marker.value + elements_present = [kind.symbol for kind in structure.kinds] + if marker in elements_present: + raise ValidationError( + f'The marker ("{marker}") should not match an existing Kind in ' + f'the input structure ({elements_present}.' + ) + + output_params = {} + result = {} + + for index in atoms_list.get_list(): + marked_structure = StructureData() + kinds = {kind.name: kind for kind in structure.kinds} + marked_structure.set_cell(structure.cell) + + for i, site in enumerate(structure.sites): + if i == index: + marked_kind = Kind(name=marker, symbols=site.kind_name) + marked_site = Site(kind_name=marked_kind.name, position=site.position) + marked_structure.append_kind(marked_kind) + marked_structure.append_site(marked_site) + output_params[f'site_{index}'] = {'symbol': site.kind_name, 'multiplicity': 1} + else: + if site.kind_name not in [kind.name for kind in marked_structure.kinds]: + marked_structure.append_kind(kinds[site.kind_name]) + new_site = Site(kind_name=site.kind_name, position=site.position) + marked_structure.append_site(new_site) + result[f'site_{index}'] = marked_structure + + result['output_parameters'] = orm.Dict(dict=output_params) + + return result diff --git a/src/aiida_quantumespresso/workflows/xps.py b/src/aiida_quantumespresso/workflows/xps.py index c0c13c481..1f6491456 100644 --- a/src/aiida_quantumespresso/workflows/xps.py +++ b/src/aiida_quantumespresso/workflows/xps.py @@ -28,21 +28,21 @@ def validate_inputs(inputs, _): """Validate the inputs before launching the WorkChain.""" structure = inputs['structure'] elements_present = [kind.name for kind in structure.kinds] - absorbing_elements_list = sorted(inputs['elements_list']) abs_atom_marker = inputs['abs_atom_marker'].value if abs_atom_marker in elements_present: raise ValidationError( f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the ' f'input structure ({elements_present}).' ) - - if inputs['calc_binding_energy'].value: - ce_list = sorted(inputs['correction_energies'].get_dict().keys()) - if ce_list != absorbing_elements_list: - raise ValidationError( - f'The ``correction_energies`` provided ({ce_list}) does not match the list of' - f' absorbing elements ({absorbing_elements_list})' - ) + if 'elements_list' in inputs: + absorbing_elements_list = sorted(inputs['elements_list']) + if inputs['calc_binding_energy'].value: + ce_list = sorted(inputs['correction_energies'].get_dict().keys()) + if ce_list != absorbing_elements_list: + raise ValidationError( + f'The ``correction_energies`` provided ({ce_list}) does not match the list of' + f' absorbing elements ({absorbing_elements_list})' + ) class XpsWorkChain(ProtocolMixin, WorkChain): @@ -81,7 +81,7 @@ def define(cls, spec): spec.expose_inputs( PwBaseWorkChain, namespace='ch_scf', - exclude=('kpoints', 'pw.structure'), + exclude=('pw.structure', ), namespace_options={ 'help': ('Input parameters for the basic xps workflow (core-hole SCF).'), 'validator': None @@ -170,6 +170,14 @@ def define(cls, spec): 'The list of elements to be considered for analysis, each must be valid elements of the periodic table.' ) ) + spec.input( + 'atoms_list', + valid_type=orm.List, + required=False, + help=( + 'The indices of atoms to be considered for analysis.' + ) + ) spec.input( 'calc_binding_energy', valid_type=orm.Bool, @@ -233,12 +241,14 @@ def define(cls, spec): spec.output( 'supercell_structure', valid_type=orm.StructureData, + required=False, help=('The supercell of ``outputs.standardized_structure`` used to generate structures for' ' XPS sub-processes.') ) spec.output( 'symmetry_analysis_data', valid_type=orm.Dict, + required=False, help='The output parameters from ``get_xspectra_structures()``.' ) spec.output( @@ -366,8 +376,8 @@ def get_treatment_filepath(cls): @classmethod def get_builder_from_protocol( cls, code, structure, pseudos, core_hole_treatments=None, protocol=None, - overrides=None, elements_list=None, options=None, - structure_preparation_settings=None, **kwargs + overrides=None, elements_list=None, atoms_list=None, options=None, + structure_preparation_settings=None, correction_energies=None, **kwargs ): """Return a builder prepopulated with inputs selected according to the chosen protocol. @@ -386,9 +396,6 @@ def get_builder_from_protocol( """ inputs = cls.get_protocol_inputs(protocol, overrides) - calc_binding_energy = kwargs.pop('calc_binding_energy', False) - correction_energies = kwargs.pop('correction_energies', orm.Dict()) - pw_args = (code, structure, protocol) # xspectra_args = (pw_code, xs_code, structure, protocol, upf2plotcore_code) @@ -412,8 +419,11 @@ def get_builder_from_protocol( builder.ch_scf = ch_scf builder.structure = structure builder.abs_atom_marker = abs_atom_marker - builder.calc_binding_energy = calc_binding_energy - builder.correction_energies = correction_energies + if correction_energies: + builder.correction_energies = orm.Dict(correction_energies) + builder.calc_binding_energy = orm.Bool(True) + else: + builder.calc_binding_energy = orm.Bool(False) builder.clean_workdir = orm.Bool(inputs['clean_workdir']) core_hole_pseudos = {} gipaw_pseudos = {} @@ -434,6 +444,12 @@ def get_builder_from_protocol( for element in elements_list: core_hole_pseudos[element] = pseudos[element]['core_hole'] gipaw_pseudos[element] = pseudos[element]['gipaw'] + elif atoms_list: + builder.atoms_list = orm.List(atoms_list) + for index in atoms_list: + element = structure.sites[index].kind_name + core_hole_pseudos[element] = pseudos[element]['core_hole'] + gipaw_pseudos[element] = pseudos[element]['gipaw'] # if no elements list is given, we instead initalise the pseudos dict with all # elements in the structure else: @@ -453,12 +469,18 @@ def get_builder_from_protocol( def setup(self): """Init required context variables.""" - custom_elements_list = self.inputs.get('elements_list', None) - if not custom_elements_list: + elements_list = self.inputs.get('elements_list', None) + atoms_list = self.inputs.get('atoms_list', None) + if elements_list: + self.ctx.elements_list = elements_list.get_list() + self.ctx.atoms_list = None + elif atoms_list: + self.ctx.atoms_list = atoms_list.get_list() + self.ctx.elements_list = None + else: structure = self.inputs.structure self.ctx.elements_list = [Kind.symbol for Kind in structure.kinds] - else: - self.ctx.elements_list = custom_elements_list.get_list() + def should_run_relax(self): @@ -511,48 +533,59 @@ def prepare_structures(self): formatted as { : } for each variable in the ``get_symmetry_dataset()`` method. """ + from aiida_quantumespresso.workflows.functions.get_marked_structures import get_marked_structures from aiida_quantumespresso.workflows.functions.get_xspectra_structures import get_xspectra_structures - elements_list = orm.List(self.ctx.elements_list) - inputs = { - 'absorbing_elements_list' : elements_list, - 'absorbing_atom_marker' : self.inputs.abs_atom_marker, - 'metadata' : { - 'call_link_label' : 'get_xspectra_structures' + input_structure = self.inputs.structure if 'relax' not in self.inputs else self.ctx.relaxed_structure + if self.ctx.elements_list: + elements_list = orm.List(self.ctx.elements_list) + inputs = { + 'absorbing_elements_list' : elements_list, + 'absorbing_atom_marker' : self.inputs.abs_atom_marker, + 'metadata' : { + 'call_link_label' : 'get_xspectra_structures' + } + } # populate this further once the schema for WorkChain options is figured out + if 'structure_preparation_settings' in self.inputs: + optional_cell_prep = self.inputs.structure_preparation_settings + for key, node in optional_cell_prep.items(): + inputs[key] = node + if 'spglib_settings' in self.inputs: + spglib_settings = self.inputs.spglib_settings + inputs['spglib_settings'] = spglib_settings + else: + spglib_settings = None + + result = get_xspectra_structures(input_structure, **inputs) + + supercell = result.pop('supercell') + out_params = result.pop('output_parameters') + if out_params.get_dict().get('structure_is_standardized', None): + standardized = result.pop('standardized_structure') + self.out('standardized_structure', standardized) + + # structures_to_process = {Key : Value for Key, Value in result.items()} + for site in ['output_parameters', 'supercell', 'standardized_structure']: + result.pop(site, None) + self.ctx.supercell = supercell + self.ctx.equivalent_sites_data = out_params['equivalent_sites_data'] + self.out('supercell_structure', supercell) + self.out('symmetry_analysis_data', out_params) + elif self.ctx.atoms_list: + atoms_list = orm.List(self.ctx.atoms_list) + inputs = { + 'atoms_list' : atoms_list, + 'marker' : self.inputs.abs_atom_marker, + 'metadata' : { + 'call_link_label' : 'get_marked_structures' + } } - } # populate this further once the schema for WorkChain options is figured out - if 'structure_preparation_settings' in self.inputs: - optional_cell_prep = self.inputs.structure_preparation_settings - for key, node in optional_cell_prep.items(): - inputs[key] = node - if 'spglib_settings' in self.inputs: - spglib_settings = self.inputs.spglib_settings - inputs['spglib_settings'] = spglib_settings - else: - spglib_settings = None - - if 'relax' in self.inputs: - relaxed_structure = self.ctx.relaxed_structure - result = get_xspectra_structures(relaxed_structure, **inputs) - else: - result = get_xspectra_structures(self.inputs.structure, **inputs) - - supercell = result.pop('supercell') - out_params = result.pop('output_parameters') - if out_params.get_dict().get('structure_is_standardized', None): - standardized = result.pop('standardized_structure') - self.out('standardized_structure', standardized) - - # structures_to_process = {Key : Value for Key, Value in result.items()} - for site in ['output_parameters', 'supercell', 'standardized_structure']: - result.pop(site, None) + result = get_marked_structures(input_structure, **inputs) + self.ctx.supercell = input_structure + self.ctx.equivalent_sites_data = result.pop('output_parameters').get_dict() structures_to_process = {f'{Key.split("_")[0]}_{Key.split("_")[1]}' : Value for Key, Value in result.items()} - self.ctx.supercell = supercell + self.report(f'structures_to_process: {structures_to_process}') self.ctx.structures_to_process = structures_to_process - self.ctx.equivalent_sites_data = out_params['equivalent_sites_data'] - - self.out('supercell_structure', supercell) - self.out('symmetry_analysis_data', out_params) def should_run_gs_scf(self): """If the 'calc_binding_energy' input namespace is True, we run a scf calculation for the supercell.""" @@ -566,9 +599,9 @@ def run_gs_scf(self): inputs.metadata.call_link_label = 'supercell_xps' inputs = prepare_process_inputs(PwBaseWorkChain, inputs) - equivalent_sites_data = self.ctx.equivalent_sites_data - for site in equivalent_sites_data: - abs_element = equivalent_sites_data[site]['symbol'] + # pseudos for all elements to be calculated should be replaced + for site in self.ctx.equivalent_sites_data: + abs_element = self.ctx.equivalent_sites_data[site]['symbol'] inputs.pw.pseudos[abs_element] = self.inputs.gipaw_pseudos[abs_element] running = self.submit(PwBaseWorkChain, **inputs) @@ -600,7 +633,6 @@ def run_all_scf(self): equivalent_sites_data = self.ctx.equivalent_sites_data abs_atom_marker = self.inputs.abs_atom_marker.value - for site in structures_to_process: inputs = AttributeDict(self.exposed_inputs(PwBaseWorkChain, namespace='ch_scf')) structure = structures_to_process[site] @@ -630,9 +662,10 @@ def run_all_scf(self): core_hole_pseudo = self.inputs.core_hole_pseudos[abs_element] inputs.pw.pseudos[abs_atom_marker] = core_hole_pseudo - # all element in the elements_list should be replaced - for element in self.inputs.elements_list: - inputs.pw.pseudos[element] = self.inputs.gipaw_pseudos[element] + # pseudos for all elements to be calculated should be replaced + for key in self.ctx.equivalent_sites_data: + abs_element = self.ctx.equivalent_sites_data[key]['symbol'] + inputs.pw.pseudos[abs_element] = self.inputs.gipaw_pseudos[abs_element] # remove pseudo if the only element is replaced by the marker inputs.pw.pseudos = {kind.name: inputs.pw.pseudos[kind.name] for kind in structure.kinds} @@ -674,11 +707,15 @@ def results(self): kwargs['correction_energies'] = self.inputs.correction_energies kwargs['metadata'] = {'call_link_label' : 'compile_final_spectra'} - equivalent_sites_data = orm.Dict(dict=self.ctx.equivalent_sites_data) - elements_list = orm.List(list=self.ctx.elements_list) + if self.ctx.elements_list: + elements_list = orm.List(list=self.ctx.elements_list) + else: + symbols = {value['symbol'] for value in self.ctx.equivalent_sites_data.values()} + elements_list = orm.List(list(symbols)) voight_gamma = self.inputs.voight_gamma voight_sigma = self.inputs.voight_sigma + equivalent_sites_data = orm.Dict(dict=self.ctx.equivalent_sites_data) result = get_spectra_by_element( elements_list, equivalent_sites_data, diff --git a/tests/workflows/functions/test_get_marked_structures.py b/tests/workflows/functions/test_get_marked_structures.py new file mode 100644 index 000000000..42d5c845f --- /dev/null +++ b/tests/workflows/functions/test_get_marked_structures.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +"""Tests for the `get_marked_structure` class.""" + + +def test_get_marked_structure(): + """Test the get_marked_structure function.""" + from aiida.orm import List, StructureData + from ase.build import molecule + + from aiida_quantumespresso.workflows.functions.get_marked_structures import get_marked_structures + + mol = molecule('CH3CH2OH') + mol.center(vacuum=2.0) + structure = StructureData(ase=mol) + indices = List(list=[0, 1, 2]) + output = get_marked_structures(structure, indices) + assert len(output) == 4 + assert output['site_0'].get_site_kindnames() == ['X', 'C', 'O', 'H', 'H', 'H', 'H', 'H', 'H'] + assert output['site_1'].get_site_kindnames() == ['C', 'X', 'O', 'H', 'H', 'H', 'H', 'H', 'H']