diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7ec27c31fa..63588cd8cb 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -63,7 +63,7 @@ jobs: strategy: fail-fast: true matrix: - wflow_engine: [covalent, dask, parsl, prefect, redun, jobflow] + wflow_engine: [aiida, covalent, dask, parsl, prefect, redun, jobflow] defaults: run: diff --git a/pyproject.toml b/pyproject.toml index 1a84b8f7c8..f8346c6421 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ ] [project.optional-dependencies] +aiida = ["aiida-workgraph>=0.7.6", "aiida-gui-workgraph>=0.1.3"] covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-cloud>=0.39.0; platform_system!='Windows'"] dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"] db = ["maggma>=0.64.0"] diff --git a/src/quacc/settings.py b/src/quacc/settings.py index 95570dbeab..e0101f4ce2 100644 --- a/src/quacc/settings.py +++ b/src/quacc/settings.py @@ -59,7 +59,8 @@ class QuaccSettings(BaseSettings): # --------------------------- WORKFLOW_ENGINE: ( - Literal["covalent", "dask", "parsl", "prefect", "redun", "jobflow"] | None + Literal["aiida", "covalent", "dask", "parsl", "prefect", "redun", "jobflow"] + | None ) = Field(None, description=("The workflow manager to use, if any.")) # --------------------------- diff --git a/src/quacc/wflow_tools/decorators.py b/src/quacc/wflow_tools/decorators.py index f730b5038c..8ad37d9d05 100644 --- a/src/quacc/wflow_tools/decorators.py +++ b/src/quacc/wflow_tools/decorators.py @@ -189,6 +189,15 @@ def wrapper(*f_args, **f_kwargs): return wrapper else: return task(_func, **kwargs) + elif settings.WORKFLOW_ENGINE == "aiida": + from aiida_workgraph import task + + @wraps(_func) + def wrapper(*f_args, **f_kwargs): + decorated = task(_func, **kwargs) + return decorated(*f_args, **f_kwargs).result + + return wrapper else: return _func @@ -352,6 +361,10 @@ def workflow(a, b, c): return task(_func, namespace=_func.__module__, **kwargs) elif settings.WORKFLOW_ENGINE == "prefect": return _get_prefect_wrapped_flow(_func, settings, **kwargs) + elif settings.WORKFLOW_ENGINE == "aiida": + from aiida_workgraph import task + + return task.graph()(_func, **kwargs) else: return _func @@ -585,6 +598,10 @@ def wrapper(*f_args, **f_kwargs): from redun import task return task(_func, namespace=_func.__module__, **kwargs) + elif settings.WORKFLOW_ENGINE == "aiida": + from aiida_workgraph import task + + return task.graph()(_func, **kwargs) else: return _func diff --git a/tests/aiida/conftest.py b/tests/aiida/conftest.py new file mode 100644 index 0000000000..fd0a0d401b --- /dev/null +++ b/tests/aiida/conftest.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import subprocess +from importlib.util import find_spec +from pathlib import Path +from shutil import rmtree + +TEST_RESULTS_DIR = Path(__file__).parent / "_test_results" +TEST_SCRATCH_DIR = Path(__file__).parent / "_test_scratch" + +has_aiida = bool(find_spec("aiida_workgraph")) + +if has_aiida: + subprocess.run( + ["verdi", "presto", "--profile-name", "test_profile"], + capture_output=True, + text=True, + check=False, + ) + + def pytest_sessionstart(): + import os + + from aiida import load_profile + + file_dir = Path(__file__).parent + os.environ["QUACC_CONFIG_FILE"] = str(file_dir / "quacc.yaml") + os.environ["QUACC_RESULTS_DIR"] = str(TEST_RESULTS_DIR) + os.environ["QUACC_SCRATCH_DIR"] = str(TEST_SCRATCH_DIR) + + load_profile("test_profile", allow_switch=True) + + def pytest_sessionfinish(exitstatus): + rmtree(TEST_RESULTS_DIR, ignore_errors=True) + if exitstatus == 0: + rmtree(TEST_SCRATCH_DIR, ignore_errors=True) diff --git a/tests/aiida/quacc.yaml b/tests/aiida/quacc.yaml new file mode 100644 index 0000000000..7916f1b853 --- /dev/null +++ b/tests/aiida/quacc.yaml @@ -0,0 +1 @@ +WORKFLOW_ENGINE: aiida diff --git a/tests/aiida/test_syntax.py b/tests/aiida/test_syntax.py new file mode 100644 index 0000000000..a25573bb09 --- /dev/null +++ b/tests/aiida/test_syntax.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import pytest + +pytest.importorskip("aiida_workgraph") + + +from quacc import flow, job + + +def test_aiida_decorators(): + @job + def add(a, b): + return a + b + + @job + def mult(a, b): + return a * b + + @flow + def workflow(a, b, c): + return mult(add(a, b), c) + + assert workflow.run(1, 2, 3)["result"] == 9 diff --git a/tests/requirements-aiida.txt b/tests/requirements-aiida.txt new file mode 100644 index 0000000000..9204aa9f35 --- /dev/null +++ b/tests/requirements-aiida.txt @@ -0,0 +1,2 @@ +aiida-workgraph==0.7.6 +aiida-gui-workgraph==0.1.3