From cc642d0454564c949c79390cff5a5e172c92f01a Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 14 May 2024 11:57:05 +0200 Subject: [PATCH] Changes for pydantic v2 --- aiida_submission_controller/base.py | 14 ++++++---- aiida_submission_controller/from_group.py | 33 ++++++++++++----------- pyproject.toml | 3 +-- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/aiida_submission_controller/base.py b/aiida_submission_controller/base.py index 3a51bcb..984b4bc 100644 --- a/aiida_submission_controller/base.py +++ b/aiida_submission_controller/base.py @@ -6,7 +6,7 @@ from aiida import engine, orm from aiida.common import NotExistent -from pydantic import BaseModel, validator +from pydantic import BaseModel, validator, field_validator from rich import print from rich.console import Console from rich.table import Table @@ -33,7 +33,7 @@ def add_to_nested_dict(nested_dict, key, value): return extras_dict -def validate_group_exists(value: str) -> str: +def _validate_group_exists(value: str) -> str: """Validator that makes sure the ``Group`` with the provided label exists.""" try: orm.Group.collection.get(label=value) @@ -56,12 +56,15 @@ class BaseSubmissionController(BaseModel): unique_extra_keys: Optional[tuple] = None """Tuple of keys defined in the extras that uniquely define each process to be run.""" - _validate_group_exists = validator("group_label", allow_reuse=True)(validate_group_exists) + @field_validator('group_label') + @classmethod + def validate_group_exists(cls, v: str) -> str: + return _validate_group_exists(v) @property def group(self): """Return the AiiDA ORM Group instance that is managed by this class.""" - return orm.Group.objects.get(label=self.group_label) + return orm.Group.collection.get(label=self.group_label) def get_query(self, process_projections, only_active=False): """Return a QueryBuilder object to get all processes in the group associated to this. @@ -233,10 +236,11 @@ def submit_new_batch(self, dry_run=False, sort=False, verbose=False): except Exception as exc: CMDLINE_LOGGER.error(f"Failed to submit work chain for extras <{workchain_extras}>: {exc}") + raise else: CMDLINE_LOGGER.report(f"Submitted work chain <{wc_node}> for extras <{workchain_extras}>.") - wc_node.set_extra_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras)) + wc_node.base.extras.set_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras)) self.group.add_nodes([wc_node]) submitted[workchain_extras] = wc_node diff --git a/aiida_submission_controller/from_group.py b/aiida_submission_controller/from_group.py index 1eea6a2..7e787f6 100644 --- a/aiida_submission_controller/from_group.py +++ b/aiida_submission_controller/from_group.py @@ -3,9 +3,9 @@ from typing import Optional from aiida import orm -from pydantic import validator +from pydantic import field_validator -from .base import BaseSubmissionController, validate_group_exists +from .base import BaseSubmissionController, _validate_group_exists class FromGroupSubmissionController(BaseSubmissionController): # pylint: disable=abstract-method @@ -22,12 +22,15 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl order_by: Optional[dict] = None """Ordering applied to the query of the nodes in the parent group.""" - _validate_group_exists = validator("parent_group_label", allow_reuse=True)(validate_group_exists) + @field_validator('group_label') + @classmethod + def validate_group_exists(cls, v: str) -> str: + return _validate_group_exists(v) @property def parent_group(self): """Return the AiiDA ORM Group instance of the parent group.""" - return orm.Group.objects.get(label=self.parent_group_label) + return orm.Group.collection.get(label=self.parent_group_label) def get_parent_node_from_extras(self, extras_values): """Return the Node instance (in the parent group) from the (unique) extras identifying it.""" @@ -35,14 +38,14 @@ def get_parent_node_from_extras(self, extras_values): assert len(extras_values) == len(extras_projections), f"The extras must be of length {len(extras_projections)}" filters = dict(zip(extras_projections, extras_values)) - qbuild = orm.QueryBuilder() - qbuild.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group") - qbuild.append(orm.Node, project="*", filters=filters, tag="process", with_group="group") - qbuild.limit(2) - results = qbuild.all(flat=True) + qb = orm.QueryBuilder() + qb.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group") + qb.append(orm.Node, project="*", filters=filters, tag="process", with_group="group") + qb.limit(2) + results = qb.all(flat=True) if len(results) != 1: raise ValueError( - "I would have expected only 1 result for extras={extras}, I found {'>1' if len(qbuild) else '0'}" + "I would have expected only 1 result for extras={extras}, I found {'>1' if len(qb) else '0'}" ) return results[0] @@ -57,9 +60,9 @@ def get_all_extras_to_submit(self): """ extras_projections = self.get_process_extra_projections() - qbuild = orm.QueryBuilder() - qbuild.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group") - qbuild.append( + qb = orm.QueryBuilder() + qb.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group") + qb.append( orm.Node, project=extras_projections, filters=self.filters, @@ -68,9 +71,9 @@ def get_all_extras_to_submit(self): ) if self.order_by is not None: - qbuild.order_by(self.order_by) + qb.order_by(self.order_by) - results = qbuild.all() + results = qb.all() # I return a set of results as required by the API # First, however, convert to a list of tuples otherwise diff --git a/pyproject.toml b/pyproject.toml index 0b85b55..8d10c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,7 @@ classifiers = [ requires-python = ">=3.6" dependencies = [ - "aiida-core>=1.0", - "pydantic~=1.10.4", + "aiida-core~=2.5", "rich", ]