Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix max_iterations not overridden and add tests for PhBaseWorkChain protocols #984

Merged
merged 2 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/aiida_quantumespresso/workflows/ph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aiida.plugins import CalculationFactory

from aiida_quantumespresso.calculations.functions.merge_ph_outputs import merge_ph_outputs
from aiida_quantumespresso.common.types import ElectronicType
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin

PhCalculation = CalculationFactory('quantumespresso.ph')
Expand Down Expand Up @@ -64,7 +65,16 @@ def get_protocol_filepath(cls):
return files(ph_protocols) / 'base.yaml'

@classmethod
def get_builder_from_protocol(cls, code, parent_folder=None, protocol=None, overrides=None, options=None, **_):
def get_builder_from_protocol(
cls,
code,
parent_folder=None,
protocol=None,
overrides=None,
electronic_type=ElectronicType.METAL,
options=None,
**_
):
"""Return a builder prepopulated with inputs selected according to the chosen protocol.

:param code: the ``Code`` instance configured for the ``quantumespresso.ph`` plugin.
Expand All @@ -73,6 +83,7 @@ def get_builder_from_protocol(cls, code, parent_folder=None, protocol=None, over
:param overrides: optional dictionary of inputs to override the defaults of the protocol.
:param options: A dictionary of options that will be recursively set for the ``metadata.options`` input of all
the ``CalcJobs`` that are nested in this work chain.
:param electronic_type: indicate the electronic character of the system through ``ElectronicType`` instance.
:return: a process builder instance with all inputs defined ready for launch.
"""
from aiida_quantumespresso.workflows.protocols.utils import recursive_merge
Expand All @@ -81,9 +92,16 @@ def get_builder_from_protocol(cls, code, parent_folder=None, protocol=None, over
code = orm.load_code(code)

type_check(code, orm.AbstractCode)
type_check(electronic_type, ElectronicType)

if electronic_type not in [ElectronicType.METAL, ElectronicType.INSULATOR]:
raise NotImplementedError(f'electronic type `{electronic_type}` is not supported.')

inputs = cls.get_protocol_inputs(protocol, overrides)

if electronic_type is ElectronicType.INSULATOR:
inputs['ph']['parameters']['INPUTPH']['epsil'] = True

qpoints_mesh = inputs['ph'].pop('qpoints')
qpoints = orm.KpointsData()
qpoints.set_kpoints_mesh(qpoints_mesh)
Expand All @@ -104,6 +122,7 @@ def get_builder_from_protocol(cls, code, parent_folder=None, protocol=None, over
builder.ph['settings'] = orm.Dict(inputs['ph']['settings'])
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
builder.ph['qpoints'] = qpoints
builder.max_iterations = orm.Int(inputs['max_iterations'])
# pylint: enable=no-member

return builder
Expand Down
1 change: 1 addition & 0 deletions src/aiida_quantumespresso/workflows/protocols/ph/base.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
default_inputs:
clean_workdir: False
max_iterations: 5
ph:
metadata:
options:
Expand Down
1 change: 1 addition & 0 deletions src/aiida_quantumespresso/workflows/protocols/pw/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ default_inputs:
clean_workdir: False
kpoints_distance: 0.15
kpoints_force_parity: False
max_iterations: 5
meta_parameters:
conv_thr_per_atom: 0.2e-9
etot_conv_thr_per_atom: 1.e-5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ default_inputs:
clean_workdir: False
kpoints_distance: 0.15
kpoints_force_parity: False
max_iterations: 5
xspectra:
parameters:
INPUT_XSPECTRA:
Expand Down
1 change: 1 addition & 0 deletions src/aiida_quantumespresso/workflows/pw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def get_builder_from_protocol(
else:
builder.kpoints_distance = orm.Float(inputs['kpoints_distance'])
builder.kpoints_force_parity = orm.Bool(inputs['kpoints_force_parity'])
builder.max_iterations = orm.Int(inputs['max_iterations'])
# pylint: enable=no-member

return builder
Expand Down
1 change: 1 addition & 0 deletions src/aiida_quantumespresso/workflows/xspectra/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def get_builder_from_protocol(
if 'settings' in inputs['xspectra']:
builder.xspectra['settings'] = orm.Dict(inputs['xspectra']['settings'])
builder.clean_workdir = orm.Bool(inputs['clean_workdir'])
builder.max_iterations = orm.Int(inputs['max_iterations'])
# pylint: enable=no-member

return builder
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def generate_remote_data():
"""Return a `RemoteData` node."""

def _generate_remote_data(computer, remote_path, entry_point_name=None):
"""Return a `KpointsData` with a mesh of npoints in each direction."""
"""Return a `RemoteData` node."""
from aiida.common.links import LinkType
from aiida.orm import CalcJobNode, RemoteData
from aiida.plugins.entry_point import format_entry_point_string
Expand Down
Empty file.
98 changes: 98 additions & 0 deletions tests/workflows/protocols/ph/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# pylint: disable=no-member,redefined-outer-name
"""Tests for the ``PhBaseWorkChain.get_builder_from_protocol`` method."""
from aiida.engine import ProcessBuilder
import pytest

from aiida_quantumespresso.common.types import ElectronicType
from aiida_quantumespresso.workflows.ph.base import PhBaseWorkChain


def test_get_available_protocols():
"""Test ``PhBaseWorkChain.get_available_protocols``."""
protocols = PhBaseWorkChain.get_available_protocols()
assert sorted(protocols.keys()) == ['fast', 'moderate', 'precise']
assert all('description' in protocol for protocol in protocols.values())


def test_get_default_protocol():
"""Test ``PhBaseWorkChain.get_default_protocol``."""
assert PhBaseWorkChain.get_default_protocol() == 'moderate'


def test_default(fixture_code, data_regression, serialize_builder):
"""Test ``PhBaseWorkChain.get_builder_from_protocol`` for the default protocol."""
code = fixture_code('quantumespresso.ph')
builder = PhBaseWorkChain.get_builder_from_protocol(code)

assert isinstance(builder, ProcessBuilder)
data_regression.check(serialize_builder(builder))


def test_electronic_type(fixture_code):
"""Test ``PhBaseWorkChain.get_builder_from_protocol`` with ``electronic_type`` keyword."""
code = fixture_code('quantumespresso.ph')

with pytest.raises(NotImplementedError):
for electronic_type in [ElectronicType.AUTOMATIC]:
PhBaseWorkChain.get_builder_from_protocol(code, electronic_type=electronic_type)

builder = PhBaseWorkChain.get_builder_from_protocol(code, electronic_type=ElectronicType.INSULATOR)
parameters = builder.ph.parameters.get_dict() # pylint: disable=no-member

assert parameters['INPUTPH']['epsil']


def test_parameter_overrides(fixture_code):
"""Test specifying parameter ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.ph')

overrides = {'ph': {'parameters': {'INPUTHP': {'nmix_ph': 20}}}}
builder = PhBaseWorkChain.get_builder_from_protocol(code, overrides=overrides)
assert builder.ph.parameters['INPUTHP']['nmix_ph'] == 20 # pylint: disable=no-member


def test_settings_overrides(fixture_code):
"""Test specifying settings ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.ph')

overrides = {'ph': {'settings': {'cmdline': ['--kickass-mode']}}}
builder = PhBaseWorkChain.get_builder_from_protocol(code, overrides=overrides)
assert builder.ph.settings['cmdline'] == ['--kickass-mode'] # pylint: disable=no-member


def test_metadata_overrides(fixture_code):
"""Test specifying metadata ``overrides`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.ph')

overrides = {'ph': {'metadata': {'options': {'resources': {'num_machines': 1e90}, 'max_wallclock_seconds': 1}}}}
builder = PhBaseWorkChain.get_builder_from_protocol(code, overrides=overrides)
metadata = builder.ph.metadata # pylint: disable=no-member

assert metadata['options']['resources']['num_machines'] == 1e90
assert metadata['options']['max_wallclock_seconds'] == 1


def test_options(fixture_code):
"""Test specifying ``options`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.ph')

queue_name = 'super-fast'
withmpi = False # The protocol default is ``True``

options = {'queue_name': queue_name, 'withmpi': withmpi}
builder = PhBaseWorkChain.get_builder_from_protocol(code, options=options)
metadata = builder.ph.metadata # pylint: disable=no-member

assert metadata['options']['queue_name'] == queue_name
assert metadata['options']['withmpi'] == withmpi


def test_parent_folder(fixture_code, generate_remote_data, fixture_localhost, fixture_sandbox):
"""Test specifying ``options`` for the ``get_builder_from_protocol()`` method."""
code = fixture_code('quantumespresso.ph')
remote_folder = generate_remote_data(fixture_localhost, fixture_sandbox.abspath, 'quantumespresso.pw')

builder = PhBaseWorkChain.get_builder_from_protocol(code, parent_folder=remote_folder)

assert builder.ph.parent_folder == remote_folder # pylint: disable=no-member
20 changes: 20 additions & 0 deletions tests/workflows/protocols/ph/test_base/test_default.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
clean_workdir: false
max_iterations: 5
ph:
code: test.quantumespresso.ph@localhost
metadata:
options:
max_wallclock_seconds: 43200
resources:
num_machines: 1
withmpi: true
parameters:
INPUTPH:
tr2_ph: 1.0e-18
qpoints:
- - 3
- 3
- 3
- - 0.0
- 0.0
- 0.0
3 changes: 3 additions & 0 deletions tests/workflows/protocols/pw/test_bands/test_default.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
bands:
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down Expand Up @@ -39,6 +40,7 @@ relax:
base:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down Expand Up @@ -77,6 +79,7 @@ relax:
scf:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down
1 change: 1 addition & 0 deletions tests/workflows/protocols/pw/test_base/test_default.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
clean_workdir: false
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down
2 changes: 2 additions & 0 deletions tests/workflows/protocols/pw/test_relax/test_default.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
base:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down Expand Up @@ -36,6 +37,7 @@ base:
base_final_scf:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down
2 changes: 2 additions & 0 deletions tests/workflows/protocols/test_pdos/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dos:
nscf:
kpoints_distance: 0.1
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down Expand Up @@ -55,6 +56,7 @@ projwfc:
scf:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ kpoints:
- - 0.0
- 0.0
- 0.0
max_iterations: 5
xspectra:
code: test.quantumespresso.xspectra@localhost
core_wfc_data: '# number of core states 3 = 1 0; 2 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ get_powder_spectrum: false
scf:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ core:
clean_workdir: false
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down Expand Up @@ -79,6 +80,7 @@ relax:
base:
kpoints_distance: 0.15
kpoints_force_parity: false
max_iterations: 5
pw:
code: test.quantumespresso.pw@localhost
metadata:
Expand Down
Loading