|
| 1 | +"""Implementation of `aiida_common_workflows.common.relax.generator.CommonRelaxInputGenerator` for pyscf.""" |
| 2 | +import pathlib |
| 3 | +import warnings |
| 4 | + |
| 5 | +import yaml |
| 6 | +from aiida import engine, orm, plugins |
| 7 | + |
| 8 | +from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType |
| 9 | +from aiida_common_workflows.generators import ChoiceType, CodeType |
| 10 | + |
| 11 | +from ..generator import CommonRelaxInputGenerator |
| 12 | + |
| 13 | +__all__ = ('PyscfCommonRelaxInputGenerator',) |
| 14 | + |
| 15 | +StructureData = plugins.DataFactory('structure') |
| 16 | + |
| 17 | + |
| 18 | +class PyscfCommonRelaxInputGenerator(CommonRelaxInputGenerator): |
| 19 | + """Input generator for the common relax workflow implementation of pyscf.""" |
| 20 | + |
| 21 | + def __init__(self, *args, **kwargs): |
| 22 | + """Construct an instance of the input generator, validating the class attributes.""" |
| 23 | + process_class = kwargs.get('process_class', None) |
| 24 | + super().__init__(*args, **kwargs) |
| 25 | + self._initialize_protocols() |
| 26 | + |
| 27 | + def _initialize_protocols(self): |
| 28 | + """Initialize the protocols class attribute by parsing them from the configuration file.""" |
| 29 | + with (pathlib.Path(__file__).parent / 'protocol.yml').open() as handle: |
| 30 | + self._protocols = yaml.safe_load(handle) |
| 31 | + self._default_protocol = 'moderate' |
| 32 | + |
| 33 | + @classmethod |
| 34 | + def define(cls, spec): |
| 35 | + """Define the specification of the input generator. |
| 36 | +
|
| 37 | + The ports defined on the specification are the inputs that will be accepted by the ``get_builder`` method. |
| 38 | + """ |
| 39 | + super().define(spec) |
| 40 | + spec.inputs['spin_type'].valid_type = ChoiceType((SpinType.NONE, SpinType.COLLINEAR)) |
| 41 | + spec.inputs['relax_type'].valid_type = ChoiceType((RelaxType.NONE, RelaxType.POSITIONS)) |
| 42 | + spec.inputs['electronic_type'].valid_type = ChoiceType((ElectronicType.METAL, ElectronicType.INSULATOR)) |
| 43 | + spec.inputs['engines']['relax']['code'].valid_type = CodeType('pyscf.base') |
| 44 | + |
| 45 | + def _construct_builder( |
| 46 | + self, |
| 47 | + structure, |
| 48 | + engines, |
| 49 | + protocol, |
| 50 | + spin_type, |
| 51 | + relax_type, |
| 52 | + electronic_type, |
| 53 | + magnetization_per_site=None, |
| 54 | + **kwargs, |
| 55 | + ) -> engine.ProcessBuilder: |
| 56 | + """Construct a process builder based on the provided keyword arguments. |
| 57 | +
|
| 58 | + The keyword arguments will have been validated against the input generator specification. |
| 59 | + """ |
| 60 | + if not self.is_valid_protocol(protocol): |
| 61 | + raise ValueError( |
| 62 | + f'selected protocol {protocol} is not valid, please choose from: {", ".join(self.get_protocol_names())}' |
| 63 | + ) |
| 64 | + |
| 65 | + protocol_inputs = self.get_protocol(protocol) |
| 66 | + parameters = protocol_inputs.pop('parameters') |
| 67 | + |
| 68 | + if relax_type == RelaxType.NONE: |
| 69 | + parameters.pop('optimizer') |
| 70 | + |
| 71 | + if spin_type == SpinType.COLLINEAR: |
| 72 | + parameters['mean_field']['method'] = 'DKS' |
| 73 | + parameters['mean_field']['collinear'] = 'mcol' |
| 74 | + |
| 75 | + num_electrons = structure.get_pymatgen_molecule().nelectrons |
| 76 | + |
| 77 | + if spin_type == SpinType.NONE and num_electrons % 2 == 1: |
| 78 | + raise ValueError('structure has odd number of electrons, please select `spin_type = SpinType.COLLINEAR`') |
| 79 | + |
| 80 | + if spin_type == SpinType.COLLINEAR: |
| 81 | + if magnetization_per_site is None: |
| 82 | + multiplicity = 1 |
| 83 | + else: |
| 84 | + warnings.warn('magnetization_per_site site-resolved info is disregarded, only total spin is processed.') |
| 85 | + # ``magnetization_per_site`` is in units of Bohr magnetons, multiple by 0.5 to get atomic units |
| 86 | + total_spin = 0.5 * abs(sum(magnetization_per_site)) |
| 87 | + multiplicity = 2 * total_spin + 1 |
| 88 | + |
| 89 | + # In case of even/odd electrons, find closest odd/even multiplicity |
| 90 | + if num_electrons % 2 == 0: |
| 91 | + # round guess to nearest odd integer |
| 92 | + spin_multiplicity = int(round((multiplicity - 1) / 2) * 2 + 1) |
| 93 | + else: |
| 94 | + # round guess to nearest even integer; 0 goes to 2 |
| 95 | + spin_multiplicity = max([int(round(multiplicity / 2) * 2), 2]) |
| 96 | + |
| 97 | + parameters['structure']['spin'] = int((spin_multiplicity - 1) / 2) |
| 98 | + |
| 99 | + builder = self.process_class.get_builder() |
| 100 | + builder.pyscf.code = engines['relax']['code'] |
| 101 | + builder.pyscf.structure = structure |
| 102 | + builder.pyscf.parameters = orm.Dict(parameters) |
| 103 | + builder.pyscf.metadata.options = engines['relax']['options'] |
| 104 | + |
| 105 | + return builder |
0 commit comments