Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ jobs:
strategy:
fail-fast: true
matrix:
wflow_engine: [covalent, dask, parsl, prefect, redun, jobflow]
wflow_engine: [aiida, covalent, dask, parsl, prefect, redun, jobflow]

defaults:
run:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
]

[project.optional-dependencies]
aiida = ["aiida-workgraph>=0.7.6", "aiida-gui-workgraph>=0.1.3"]
covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-cloud>=0.39.0; platform_system!='Windows'"]
dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"]
db = ["maggma>=0.64.0"]
Expand Down
3 changes: 2 additions & 1 deletion src/quacc/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class QuaccSettings(BaseSettings):
# ---------------------------

WORKFLOW_ENGINE: (
Literal["covalent", "dask", "parsl", "prefect", "redun", "jobflow"] | None
Literal["aiida", "covalent", "dask", "parsl", "prefect", "redun", "jobflow"]
| None
) = Field(None, description=("The workflow manager to use, if any."))

# ---------------------------
Expand Down
17 changes: 17 additions & 0 deletions src/quacc/wflow_tools/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ def wrapper(*f_args, **f_kwargs):
return wrapper
else:
return task(_func, **kwargs)
elif settings.WORKFLOW_ENGINE == "aiida":
from aiida_workgraph import task

@wraps(_func)
def wrapper(*f_args, **f_kwargs):
decorated = task(_func, **kwargs)
return decorated(*f_args, **f_kwargs).result

return wrapper
else:
return _func

Expand Down Expand Up @@ -352,6 +361,10 @@ def workflow(a, b, c):
return task(_func, namespace=_func.__module__, **kwargs)
elif settings.WORKFLOW_ENGINE == "prefect":
return _get_prefect_wrapped_flow(_func, settings, **kwargs)
elif settings.WORKFLOW_ENGINE == "aiida":
from aiida_workgraph import task

return task.graph()(_func, **kwargs)
else:
return _func

Expand Down Expand Up @@ -585,6 +598,10 @@ def wrapper(*f_args, **f_kwargs):
from redun import task

return task(_func, namespace=_func.__module__, **kwargs)
elif settings.WORKFLOW_ENGINE == "aiida":
from aiida_workgraph import task

return task.graph()(_func, **kwargs)
else:
return _func

Expand Down
36 changes: 36 additions & 0 deletions tests/aiida/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import subprocess
from importlib.util import find_spec
from pathlib import Path
from shutil import rmtree

TEST_RESULTS_DIR = Path(__file__).parent / "_test_results"
TEST_SCRATCH_DIR = Path(__file__).parent / "_test_scratch"

has_aiida = bool(find_spec("aiida_workgraph"))

if has_aiida:
subprocess.run(
["verdi", "presto", "--profile-name", "test_profile"],
capture_output=True,
text=True,
check=False,
)

def pytest_sessionstart():
import os

from aiida import load_profile

file_dir = Path(__file__).parent
os.environ["QUACC_CONFIG_FILE"] = str(file_dir / "quacc.yaml")
os.environ["QUACC_RESULTS_DIR"] = str(TEST_RESULTS_DIR)
os.environ["QUACC_SCRATCH_DIR"] = str(TEST_SCRATCH_DIR)

load_profile("test_profile", allow_switch=True)

def pytest_sessionfinish(exitstatus):
rmtree(TEST_RESULTS_DIR, ignore_errors=True)
if exitstatus == 0:
rmtree(TEST_SCRATCH_DIR, ignore_errors=True)
1 change: 1 addition & 0 deletions tests/aiida/quacc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
WORKFLOW_ENGINE: aiida
24 changes: 24 additions & 0 deletions tests/aiida/test_syntax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import pytest

pytest.importorskip("aiida_workgraph")


from quacc import flow, job


def test_aiida_decorators():
@job
def add(a, b):
return a + b

@job
def mult(a, b):
return a * b

@flow
def workflow(a, b, c):
return mult(add(a, b), c)

assert workflow.run(1, 2, 3)["result"] == 9
2 changes: 2 additions & 0 deletions tests/requirements-aiida.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
aiida-workgraph==0.7.6
aiida-gui-workgraph==0.1.3
Loading