Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/aiida/common/workgraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Adapter module for optional ``aiida-workgraph`` integration.

``aiida-workgraph`` is an **optional** dependency. This module centralises the
availability check and type detection so that other modules (e.g.
``aiida.engine.launch``) do not need to handle optional-dependency imports
themselves.

Type checkers will report missing-import warnings here; these are expected
and unavoidable for this optional-dependency pattern.
"""

# mypy: disable-error-code="unused-ignore"
# mypy raises import-not-found when aiida-workgraph is absent and
# import-untyped when it is present. Both codes are needed in type-ignore
# comments, but one will always be unused, so we suppress that check.
from __future__ import annotations

import typing as t

try:
    # Probe for the optional dependency; the module itself is not used here.
    import aiida_workgraph  # type: ignore[import-not-found,import-untyped]  # noqa: F401
except ImportError:
    # ``aiida-workgraph`` is not installed; callers must guard on this flag.
    WORKGRAPH_INSTALLED = False
else:
    WORKGRAPH_INSTALLED = True


def is_workgraph_instance(obj: t.Any) -> bool:
    """Return whether ``obj`` is an ``aiida_workgraph.WorkGraph`` instance.

    Centralising the import here spares call sites (e.g. ``aiida.engine.launch``)
    from needing ``from aiida_workgraph import WorkGraph`` and the accompanying
    ``try/except ImportError`` boilerplate.

    Callers should check ``WORKGRAPH_INSTALLED`` first, keeping the availability
    question separate from the type question.

    :param obj: The object to check.
    :return: True if obj is a WorkGraph instance.
    :raises ImportError: if aiida-workgraph is not installed.
    """
    # Deferred import: raises ImportError only when this helper is actually called.
    from aiida_workgraph import WorkGraph  # type: ignore[import-not-found,import-untyped]

    return isinstance(obj, WorkGraph)
77 changes: 66 additions & 11 deletions src/aiida/engine/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from aiida.common import InvalidOperation
from aiida.common.lang import type_check
from aiida.common.log import AIIDA_LOGGER
from aiida.common.workgraph import WORKGRAPH_INSTALLED, is_workgraph_instance
from aiida.engine.runners import Runner
from aiida.manage import manager
from aiida.orm import ProcessNode

Expand All @@ -41,9 +43,20 @@ def run(process: TYPE_RUN_PROCESS, inputs: dict[str, t.Any] | None = None, **kwa
:return: the outputs of the process
"""
if isinstance(process, Process):
runner = process.runner
else:
runner = manager.get_manager().get_runner()
return process.runner.run(process, inputs, **kwargs)

runner: Runner = manager.get_manager().get_runner()

if WORKGRAPH_INSTALLED and is_workgraph_instance(process):
# Typed as Any because WorkGraph is an optional dependency that cannot appear in TYPE_RUN_PROCESS.
workgraph: t.Any = process
# WorkGraph converts itself into a (ProcessClass, inputs) pair via the adapter pattern.
process_class, engine_inputs = workgraph.prepare_for_launch(inputs, **kwargs)
# Use run_get_node (not run) because we need the node for update_after_launch.
result, node = runner.run_get_node(process_class, engine_inputs)
# Store the process node reference on the WorkGraph so it can track the launched process.
workgraph.update_after_launch(node)
return result

return runner.run(process, inputs, **kwargs)

Expand All @@ -58,9 +71,16 @@ def run_get_node(
:return: tuple of the outputs of the process and the process node
"""
if isinstance(process, Process):
runner = process.runner
else:
runner = manager.get_manager().get_runner()
return process.runner.run_get_node(process, inputs, **kwargs)

runner: Runner = manager.get_manager().get_runner()

if WORKGRAPH_INSTALLED and is_workgraph_instance(process):
workgraph: t.Any = process
process_class, engine_inputs = workgraph.prepare_for_launch(inputs, **kwargs)
result, node = runner.run_get_node(process_class, engine_inputs)
workgraph.update_after_launch(node) # Store process node reference on the WorkGraph
return result, node

return runner.run_get_node(process, inputs, **kwargs)

Expand All @@ -73,9 +93,16 @@ def run_get_pk(process: TYPE_RUN_PROCESS, inputs: dict[str, t.Any] | None = None
:return: tuple of the outputs of the process and process node pk
"""
if isinstance(process, Process):
runner = process.runner
else:
runner = manager.get_manager().get_runner()
return process.runner.run_get_pk(process, inputs, **kwargs)

runner: Runner = manager.get_manager().get_runner()

if WORKGRAPH_INSTALLED and is_workgraph_instance(process):
workgraph: t.Any = process
process_class, engine_inputs = workgraph.prepare_for_launch(inputs, **kwargs)
result, node = runner.run_get_node(process_class, engine_inputs)
workgraph.update_after_launch(node) # Store process node reference on the WorkGraph
return ResultAndPk(result, node.pk)

return runner.run_get_pk(process, inputs, **kwargs)

Expand All @@ -86,6 +113,7 @@ def submit(
*,
wait: bool = False,
wait_interval: int = 5,
timeout: int | None = None,
**kwargs: t.Any,
) -> ProcessNode:
"""Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter.
Expand All @@ -100,14 +128,26 @@ def submit(
:param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which
point the function returns the calculation node.
:param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``.
:param timeout: optional timeout in seconds when ``wait=True``. If the process does not terminate within this time,
a ``TimeoutError`` is raised. If ``None`` (default), waits indefinitely.
:param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument.
:return: the calculation node of the process
:raises TimeoutError: if ``wait=True`` and the process does not terminate within ``timeout`` seconds.
"""
from aiida.common.docs import URL_NO_BROKER

inputs = prepare_inputs(inputs, **kwargs)
# Unlike the run functions, submit cannot early-return after prepare_for_launch because the workgraph
# post-processing (update_after_launch) must happen after the standard submit flow completes.
# Typed as Any because WorkGraph is an optional dependency that cannot appear in TYPE_SUBMIT_PROCESS.
workgraph: t.Any = None

# Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the
if WORKGRAPH_INSTALLED and is_workgraph_instance(process):
workgraph = process
process, inputs = workgraph.prepare_for_launch(inputs, **kwargs)
# Clear kwargs since they were already consumed by prepare_for_launch.
kwargs = {}

# Submitting from within another process requires ``self.submit`` unless it is a work function, in which case the
# current process in the scope should be an instance of ``FunctionProcess``.
if is_process_scoped() and not isinstance(Process.current(), FunctionProcess):
raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead')
Expand All @@ -125,6 +165,10 @@ def submit(

assert runner.persister is not None, 'runner does not have a persister'

# Moved below the WorkGraph guard: prepare_for_launch already consumes both inputs and kwargs,
# so prepare_inputs must run after to avoid double-merging.
inputs = prepare_inputs(inputs, **kwargs)

process_inited = instantiate_process(runner, process, **inputs)

# If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this
Expand All @@ -145,15 +189,26 @@ def submit(
node = process_inited.node

if not wait:
if workgraph is not None:
workgraph.update_after_launch(node)
return node

start_time = time.time()

while not node.is_terminated:
if timeout is not None and (time.time() - start_time) >= timeout:
msg = f'Process<{node.pk}> did not terminate within {timeout} seconds.'
raise TimeoutError(msg)

LOGGER.report(
f'Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. '
f'Waiting for {wait_interval} seconds.'
)
time.sleep(wait_interval)

if workgraph is not None:
workgraph.update_after_launch(node)

return node


Expand Down
73 changes: 73 additions & 0 deletions tests/engine/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,76 @@ def test_calcjob_dry_run_no_provenance(self):
assert 'folder' in node.dry_run_info
for filename in ['path', 'file_one', 'file_two']:
assert filename in os.listdir(node.dry_run_info['folder'])


class TestWorkGraphLaunchers:
    """Test WorkGraph support in launchers using mocks.

    These tests check that the launch functions recognise WorkGraph instances,
    convert them via ``prepare_for_launch``, and then follow the standard
    Process launch path. Every WorkGraph internal is mocked out.
    """

    @pytest.fixture
    def mock_workgraph(self, monkeypatch):
        """Create a mock WorkGraph and patch detection + runner."""
        from unittest.mock import MagicMock

        from aiida.engine.runners import ResultAndPk

        workgraph = MagicMock()

        # ``prepare_for_launch`` must yield a (process_class, inputs) pair.
        fake_process_class = MagicMock(spec=Process)
        fake_inputs = {'workgraph_data': {}, 'tasks': {}, 'graph_inputs': {}, 'metadata': {}}
        workgraph.prepare_for_launch.return_value = (fake_process_class, fake_inputs)

        fake_node = MagicMock(spec=orm.ProcessNode)
        fake_node.pk = 123
        fake_node.is_terminated = True
        fake_node.process_state = 'finished'

        # Turn on the WorkGraph code path and make detection match only our mock.
        monkeypatch.setattr('aiida.engine.launch.WORKGRAPH_INSTALLED', True)
        monkeypatch.setattr(
            'aiida.engine.launch.is_workgraph_instance',
            lambda obj: obj is workgraph,
        )

        # Replace the runner so the launcher calls return canned results.
        fake_runner = MagicMock()
        fake_runner.run.return_value = {'result': 42}
        fake_runner.run_get_node.return_value = ({'result': 42}, fake_node)
        fake_runner.run_get_pk.return_value = ResultAndPk({'result': 42}, fake_node.pk)
        monkeypatch.setattr(
            'aiida.engine.launch.manager.get_manager', lambda: MagicMock(get_runner=lambda: fake_runner)
        )

        return workgraph

    @pytest.mark.parametrize(
        'launcher',
        [launch.run, launch.run_get_node, launch.run_get_pk],
        ids=['run', 'run_get_node', 'run_get_pk'],
    )
    def test_workgraph_launcher(self, mock_workgraph, launcher):
        """All run launchers call prepare_for_launch and update_after_launch for WorkGraph."""
        launcher(mock_workgraph, inputs={'add': 1})
        mock_workgraph.prepare_for_launch.assert_called_once()
        mock_workgraph.update_after_launch.assert_called_once()

    def test_run_workgraph_passes_inputs_through(self, mock_workgraph):
        """Test that inputs (including metadata) are passed directly to prepare_for_launch."""
        launch.run(mock_workgraph, inputs={'add': 1, 'metadata': {'label': 'test'}})
        mock_workgraph.prepare_for_launch.assert_called_once_with(
            {'add': 1, 'metadata': {'label': 'test'}},
        )

    def test_run_workgraph_with_kwargs(self, mock_workgraph):
        """Test that kwargs are forwarded to prepare_for_launch."""
        launch.run(mock_workgraph, add=1, metadata={'label': 'test'})
        mock_workgraph.prepare_for_launch.assert_called_once_with(
            None,
            add=1,
            metadata={'label': 'test'},
        )
Loading