Skip to content

Commit 303f3df

Browse files
committed
Use custom WorkflowFactory to provide plugin install instructions
The `WorkflowFactory` from `aiida-core` is replaced with a custom version in the `aiida_common_workflows.plugins.factories` module. This function will call the factory from `aiida-core` but catch the `MissingEntryPointError` exception. In this case, if the entry point corresponds to a plugin implementation of one of the common workflows the exception is reraised but with a useful message that provides the user with the install command to install the necessary plugin package. While this should catch all cases of users trying to load a workflow for a plugin that is not installed through its entry point, it won't catch import errors that are raised when a module is imported directly from that plugin package. Therefore, these imports should not be placed at the top of modules, but placed inside functions/methods of the implementation as much as possible.
1 parent 510888f commit 303f3df

File tree

7 files changed

+98
-15
lines changed

7 files changed

+98
-15
lines changed
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
"""Module with utilities for working with the plugins provided by this plugin package."""
22
from .entry_point import get_entry_point_name_from_class, get_workflow_entry_point_names, load_workflow_entry_point
3+
from .factories import WorkflowFactory
34

4-
__all__ = ('get_workflow_entry_point_names', 'get_entry_point_name_from_class', 'load_workflow_entry_point')
5+
__all__ = (
6+
'WorkflowFactory',
7+
'get_workflow_entry_point_names',
8+
'get_entry_point_name_from_class',
9+
'load_workflow_entry_point',
10+
)

src/aiida_common_workflows/plugins/entry_point.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from aiida.plugins import entry_point
55

6+
from .factories import WorkflowFactory
7+
68
PACKAGE_PREFIX = 'common_workflows'
79

810
__all__ = ('get_workflow_entry_point_names', 'get_entry_point_name_from_class', 'load_workflow_entry_point')
@@ -38,5 +40,5 @@ def load_workflow_entry_point(workflow: str, plugin_name: str):
3840
:param plugin_name: name of the plugin implementation.
3941
:return: the workchain class of the plugin implementation of the common workflow.
4042
"""
41-
prefix = f'{PACKAGE_PREFIX}.{workflow}.{plugin_name}'
42-
return entry_point.load_entry_point('aiida.workflows', prefix)
43+
entry_point_name = f'{PACKAGE_PREFIX}.{workflow}.{plugin_name}'
44+
return WorkflowFactory(entry_point_name)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Factories to load entry points."""
2+
import typing as t
3+
4+
from aiida import plugins
5+
from aiida.common import exceptions
6+
7+
if t.TYPE_CHECKING:
8+
from aiida.engine import WorkChain
9+
from importlib_metadata import EntryPoint
10+
11+
__all__ = ('WorkflowFactory',)
12+
13+
14+
@t.overload
15+
def WorkflowFactory(entry_point_name: str, load: t.Literal[True] = True) -> t.Union[t.Type['WorkChain'], t.Callable]:
16+
...
17+
18+
19+
@t.overload
20+
def WorkflowFactory(entry_point_name: str, load: t.Literal[False]) -> 'EntryPoint':
21+
...
22+
23+
24+
def WorkflowFactory(entry_point_name: str, load: bool = True) -> t.Union['EntryPoint', t.Type['WorkChain'], t.Callable]: # noqa: N802
25+
"""Return the `WorkChain` sub class registered under the given entry point.
26+
27+
:param entry_point_name: the entry point name.
28+
:param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself.
29+
:return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction`
30+
:raises aiida.common.MissingEntryPointError: entry point was not registered
31+
:raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved
32+
:raises aiida.common.LoadingEntryPointError: entry point could not be loaded
33+
:raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
34+
"""
35+
common_workflow_prefixes = ('common_workflows.relax.', 'common_workflows.bands.')
36+
try:
37+
return plugins.WorkflowFactory(entry_point_name, load)
38+
except exceptions.MissingEntryPointError as exception:
39+
for prefix in common_workflow_prefixes:
40+
if entry_point_name.startswith(prefix):
41+
plugin_name = entry_point_name.removeprefix(prefix)
42+
raise exceptions.MissingEntryPointError(
43+
f'Could not load the entry point `{entry_point_name}`, probably because the plugin package is not '
44+
f'installed. Please install it with `pip install aiida-common-workflows[{plugin_name}]`.'
45+
) from exception
46+
else: # noqa: PLW0120
47+
raise

src/aiida_common_workflows/workflows/relax/abinit/workchain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from aiida import orm
44
from aiida.common import exceptions
55
from aiida.engine import calcfunction
6-
from aiida_abinit.workflows.base import AbinitBaseWorkChain
6+
from aiida.plugins import WorkflowFactory
77

88
from ..workchain import CommonRelaxWorkChain
99
from .generator import AbinitCommonRelaxInputGenerator
@@ -44,7 +44,7 @@ def get_total_magnetization(parameters):
4444
class AbinitCommonRelaxWorkChain(CommonRelaxWorkChain):
4545
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for Abinit."""
4646

47-
_process_class = AbinitBaseWorkChain
47+
_process_class = WorkflowFactory('abinit.base')
4848
_generator_class = AbinitCommonRelaxInputGenerator
4949

5050
def convert_outputs(self):

src/aiida_common_workflows/workflows/relax/castep/generator.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
import yaml
99
from aiida import engine, orm, plugins
1010
from aiida.common import exceptions
11-
from aiida_castep.data import get_pseudos_from_structure
12-
from aiida_castep.data.otfg import OTFGGroup
1311

1412
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
1513
from aiida_common_workflows.generators import ChoiceType, CodeType
1614

1715
from ..generator import CommonRelaxInputGenerator
1816

17+
if t.TYPE_CHECKING:
18+
from aiida_castep.data.otfg import OTFGGroup
19+
1920
KNOWN_BUILTIN_FAMILIES = ('C19', 'NCP19', 'QC5', 'C17', 'C9')
2021

2122
__all__ = ('CastepCommonRelaxInputGenerator',)
@@ -247,8 +248,8 @@ def generate_inputs(
247248
:param override: a dictionary to override specific inputs
248249
:return: input dictionary
249250
"""
250-
251251
from aiida.common.lang import type_check
252+
from aiida_castep.data.otfg import OTFGGroup
252253

253254
family_name = protocol['relax']['base']['pseudos_family']
254255
if isinstance(family_name, orm.Str):
@@ -285,7 +286,7 @@ def generate_inputs_relax(
285286
protocol: t.Dict,
286287
code: orm.Code,
287288
structure: orm.StructureData,
288-
otfg_family: OTFGGroup,
289+
otfg_family: 'OTFGGroup',
289290
override: t.Optional[t.Dict[str, t.Any]] = None,
290291
) -> t.Dict[str, t.Any]:
291292
"""Generate the inputs for the `CastepCommonRelaxWorkChain` for a given code, structure and pseudo potential family.
@@ -321,7 +322,7 @@ def generate_inputs_base(
321322
protocol: t.Dict,
322323
code: orm.Code,
323324
structure: orm.StructureData,
324-
otfg_family: OTFGGroup,
325+
otfg_family: 'OTFGGroup',
325326
override: t.Optional[t.Dict[str, t.Any]] = None,
326327
) -> t.Dict[str, t.Any]:
327328
"""Generate the inputs for the `CastepBaseWorkChain` for a given code, structure and pseudo potential family.
@@ -359,7 +360,7 @@ def generate_inputs_calculation(
359360
protocol: t.Dict,
360361
code: orm.Code,
361362
structure: orm.StructureData,
362-
otfg_family: OTFGGroup,
363+
otfg_family: 'OTFGGroup',
363364
override: t.Optional[t.Dict[str, t.Any]] = None,
364365
) -> t.Dict[str, t.Any]:
365366
"""Generate the inputs for the `CastepCalculation` for a given code, structure and pseudo potential family.
@@ -372,6 +373,7 @@ def generate_inputs_calculation(
372373
:return: the fully defined input dictionary.
373374
"""
374375
from aiida_castep.calculations.helper import CastepHelper
376+
from aiida_castep.data import get_pseudos_from_structure
375377

376378
override = {} if not override else override.get('calc', {})
377379
# This merge perserves the merged `parameters` in the override
@@ -415,9 +417,8 @@ def ensure_otfg_family(family_name, force_update=False):
415417
NOTE: CASTEP also supports UPF families, but it is not enabled here, since no UPS based protocol
416418
has been implemented.
417419
"""
418-
419420
from aiida.common import NotExistent
420-
from aiida_castep.data.otfg import upload_otfg_family
421+
from aiida_castep.data.otfg import OTFGGroup, upload_otfg_family
421422

422423
# Ensure family name is a str
423424
if isinstance(family_name, orm.Str):

src/aiida_common_workflows/workflows/relax/quantum_espresso/generator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import yaml
55
from aiida import engine, orm, plugins
6-
from aiida_quantumespresso.workflows.protocols.utils import recursive_merge
76

87
from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
98
from aiida_common_workflows.generators import ChoiceType, CodeType
@@ -108,8 +107,8 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
108107
109108
The keyword arguments will have been validated against the input generator specification.
110109
"""
111-
112110
from aiida_quantumespresso.common import types
111+
from aiida_quantumespresso.workflows.protocols.utils import recursive_merge
113112
from qe_tools import CONSTANTS
114113

115114
structure = kwargs['structure']

tests/test_minimal_install.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
installed. This guarantees that most of the code can be imported without any plugin packages being installed.
55
"""
66
import pytest
7+
from aiida.common import exceptions
8+
from aiida_common_workflows.plugins import WorkflowFactory, get_workflow_entry_point_names
79

810

911
@pytest.mark.minimal_install
@@ -18,3 +20,29 @@ def test_imports():
1820
import aiida_common_workflows.workflows
1921
import aiida_common_workflows.workflows.dissociation
2022
import aiida_common_workflows.workflows.eos # noqa: F401
23+
24+
25+
@pytest.mark.minimal_install
26+
@pytest.mark.parametrize('entry_point_name', get_workflow_entry_point_names('relax'))
27+
def test_workflow_factory_relax(entry_point_name):
28+
"""Test that trying to load common relax workflow implementations will raise if not installed.
29+
30+
The exception message should provide the pip command to install the require plugin package.
31+
"""
32+
plugin_name = entry_point_name.removeprefix('common_workflows.relax.')
33+
match = rf'.*plugin package is not installed.*`pip install aiida-common-workflows\[{plugin_name}\]`.*'
34+
with pytest.raises(exceptions.MissingEntryPointError, match=match):
35+
WorkflowFactory(entry_point_name)
36+
37+
38+
@pytest.mark.minimal_install
39+
@pytest.mark.parametrize('entry_point_name', get_workflow_entry_point_names('bands'))
40+
def test_workflow_factory_bands(entry_point_name):
41+
"""Test that trying to load common bands workflow implementations will raise if not installed.
42+
43+
The exception message should provide the pip command to install the require plugin package.
44+
"""
45+
plugin_name = entry_point_name.removeprefix('common_workflows.bands.')
46+
match = rf'.*plugin package is not installed.*`pip install aiida-common-workflows\[{plugin_name}\]`.*'
47+
with pytest.raises(exceptions.MissingEntryPointError, match=match):
48+
WorkflowFactory(entry_point_name)

0 commit comments

Comments
 (0)