Skip to content

Commit

Permalink
Dependencies: Update pydantic~=2.4 (aiidateam#977)
Browse files Browse the repository at this point in the history
The requirement for `pydantic` was pinned to `v1` since `v2` has a lot
of backwards incompatible changes and it is difficult to provide a
version that is compatible with both versions.

As of `v2.5.0`, `aiida-core` also directly depends on `pydantic` and it
requires `~=2.4`, so here we apply the same requirement. The deprecated
code is replaced.

As a side effect, a test for `PwCalculation` started failing. It was
testing that no warnings were raised for specific inputs, but some
warnings _were_ being raised. These were not the warnings tested for in
the tests though, but raised by SQLAlchemy. AiiDA v2.5 upgraded to
SQLAlchemy v2 which as of v2.0.19 started emitting a warning. This is
ignored by a filter in `aiida-core`, but this is made undone by `pytest`.
The warning is filtered again in the `pyproject.toml` config for
`pytest` but this is not considered by `pytest.warns`. Therefore, the
warnings in `PwCalculation` are turned into the more specific
`UserWarning` such that the test can explicitly check for those.
  • Loading branch information
sphuber authored and bastonero committed Jan 6, 2025
1 parent 85e7fc1 commit f6a1233
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
'importlib_resources',
'jsonschema',
'numpy',
'pydantic~=1.10,>=1.10.8',
'pydantic~=2.0',
'packaging',
'qe-tools~=2.0',
'xmlschema~=2.0'
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_quantumespresso/calculations/pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ def validate_inputs(cls, value, port_namespace):
'Please set one of the following in the input parameters:\n'
" parameters['CONTROL']['restart_mode'] = 'restart'\n"
" parameters['ELECTRONS']['startingpot'] = 'file'\n"
" parameters['ELECTRONS']['startingwfc'] = 'file'\n"
" parameters['ELECTRONS']['startingwfc'] = 'file'\n", UserWarning
)

if calculation_type in ('nscf', 'bands'):
if 'parent_folder' not in value:
warnings.warn(
f'`parent_folder` not provided for `{calculation_type}` calculation. For work chains wrapping this '
'calculation, you can disable this warning by excluding the `parent_folder` when exposing the '
'inputs of the `PwCalculation`.'
'inputs of the `PwCalculation`.', UserWarning
)

@classproperty
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_quantumespresso/common/hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pylint: disable=no-name-in-module, invalid-name
from typing import List, Literal, Tuple

from pydantic import BaseModel, conint, constr, validator
from pydantic import BaseModel, conint, constr, field_validator

__all__ = ('HubbardParameters', 'Hubbard')

Expand Down Expand Up @@ -39,7 +39,7 @@ class HubbardParameters(BaseModel):
hubbard_type: Literal['Ueff', 'U', 'V', 'J', 'B', 'E2', 'E3']
"""Type of the Hubbard parameters used (`Ueff`, `U`, `V`, `J`, `B`, `E2`, `E3`)."""

@validator('atom_manifold', 'neighbour_manifold') # cls is mandatory to use
@field_validator('atom_manifold', 'neighbour_manifold') # cls is mandatory to use
def check_manifolds(cls, value): # pylint: disable=no-self-argument, no-self-use
"""Check the validity of the manifold input.
Expand Down
4 changes: 2 additions & 2 deletions src/aiida_quantumespresso/data/hubbard_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def hubbard(self) -> Hubbard:
:returns: a :class:`~aiida_quantumespresso.common.hubbard.Hubbard` instance.
"""
with self.base.repository.open(self._hubbard_filename, mode='rb') as handle:
return Hubbard.parse_raw(json.load(handle))
return Hubbard.model_validate_json(json.load(handle))

@hubbard.setter
def hubbard(self, hubbard: Hubbard):
"""Set the full Hubbard information."""
if not isinstance(hubbard, Hubbard):
raise ValueError('the input is not of type `Hubbard`')

serialized = json.dumps(hubbard.json())
serialized = json.dumps(hubbard.model_dump_json())
self.base.repository.put_object_from_bytes(serialized.encode('utf-8'), self._hubbard_filename)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_quantumespresso/utils/hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def reorder_atoms(self):
reordered = structure.clone() # to be set at the end
reordered.clear_kinds()

hubbard = structure.hubbard.copy()
hubbard = structure.hubbard.model_copy()
parameters = hubbard.to_list()

sites = structure.sites
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_quantumespresso/workflows/xspectra/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_builder_from_protocol(
if isinstance(code, str):
code = orm.load_code(code)

type_check(code, orm.Code)
type_check(code, orm.AbstractCode)

inputs = cls.get_protocol_inputs(protocol, overrides)

Expand Down
6 changes: 3 additions & 3 deletions tests/calculations/test_pw.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,15 @@ def test_pw_validate_inputs_restart_base(

# Add `parent_folder` but no restart tags -> warning
inputs['parent_folder'] = remote_data
with pytest.warns(Warning, match='`parent_folder` input was provided for the'):
with pytest.warns(UserWarning, match='`parent_folder` input was provided for the'):
generate_calc_job(fixture_sandbox, entry_point_name, inputs)

# Set `restart_mode` to `'restart'` -> no warning
parameters['CONTROL']['restart_mode'] = 'restart'
inputs['parameters'] = orm.Dict(parameters)
with pytest.warns(None) as warnings:
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert len(warnings.list) == 0
assert len([w for w in warnings.list if w.category is UserWarning]) == 0, [w.message for w in warnings.list]
parameters['CONTROL'].pop('restart_mode')

# Set `startingwfc` or `startingpot` to `'file'` -> no warning
Expand All @@ -305,7 +305,7 @@ def test_pw_validate_inputs_restart_base(
inputs['parameters'] = orm.Dict(parameters)
with pytest.warns(None) as warnings:
generate_calc_job(fixture_sandbox, entry_point_name, inputs)
assert len(warnings.list) == 0
assert len([w for w in warnings.list if w.category is UserWarning]) == 0
parameters['ELECTRONS'].pop(restart_setting)


Expand Down
22 changes: 11 additions & 11 deletions tests/common/test_hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_hubbard():

def test_safe_hubbard_parameters(get_hubbard_parameters):
"""Test valid inputs are stored correctly for py:meth:`HubbardParameters`."""
params = get_hubbard_parameters().dict()
params = get_hubbard_parameters().model_dump()
assert params == VALID_PARAMETERS


Expand All @@ -60,7 +60,7 @@ def test_from_to_list_parameters(get_hubbard_parameters):
hp_tuple = (0, '3d', 1, '2p', 5.0, (0, 0, 0), 'U')
assert param.to_tuple() == hp_tuple
param = HubbardParameters.from_tuple(hp_tuple)
assert param.dict() == VALID_PARAMETERS
assert param.model_dump() == VALID_PARAMETERS


@pytest.mark.parametrize(
Expand All @@ -76,7 +76,7 @@ def test_from_to_list_parameters(get_hubbard_parameters):
)
def test_valid_hubbard_parameters(get_hubbard_parameters, overrides):
"""Test valid inputs for py:meth:`HubbardParameters`."""
hp_dict = get_hubbard_parameters(overrides=overrides).dict()
hp_dict = get_hubbard_parameters(overrides=overrides).model_dump()
new_dict = deepcopy(VALID_PARAMETERS)
new_dict.update(overrides)
assert hp_dict == new_dict
Expand All @@ -85,12 +85,12 @@ def test_valid_hubbard_parameters(get_hubbard_parameters, overrides):
@pytest.mark.parametrize(('overrides', 'match'), (
({
'atom_index': -1
}, r'ensure this value is greater than or equal to 0'),
}, r'Input should be greater than or equal to 0'),
(
{
'atom_index': 0.5
},
r'value is not a valid integer',
r'Input should be a valid integer',
),
(
{
Expand All @@ -108,31 +108,31 @@ def test_valid_hubbard_parameters(get_hubbard_parameters, overrides):
{
'atom_manifold': '3d-3p-2s'
},
r'ensure this value has at most 5 characters',
r'String should have at most 5 characters',
),
(
{
'translation': (0, 0)
},
r'wrong tuple length 2, expected 3',
r'translation\.2\n\s+Field required',
),
(
{
'translation': (0, 0, 0, 0)
},
r'wrong tuple length 4, expected 3',
r'Tuple should have at most 3 items after validation, not 4',
),
(
{
'translation': (0, 0, -1.5)
},
r'value is not a valid integer',
r'Input should be a valid integer',
),
(
{
'hubbard_type': 'L'
},
r"permitted: 'Ueff', 'U', 'V', 'J', 'B', 'E2', 'E3'",
r"Input should be 'Ueff', 'U', 'V', 'J', 'B', 'E2' or 'E3'",
),
))
def test_invalid_hubbard_parameters(get_hubbard_parameters, overrides, match):
Expand All @@ -150,7 +150,7 @@ def test_from_to_list_hubbard(get_hubbard):
assert hubbard.to_list() == hubbard_list

hubbard = Hubbard.from_list(hubbard_list)
assert hubbard.dict() == {
assert hubbard.model_dump() == {
'parameters': [VALID_PARAMETERS, VALID_PARAMETERS],
'projectors': 'ortho-atomic',
'formulation': 'dudarev',
Expand Down

0 comments on commit f6a1233

Please sign in to comment.