Commit cffc576
Merge branch 'main' into issue_XyData_documentation
ayushjariyal authored Feb 14, 2025
2 parents d11f3d7 + d2fbf21
Showing 9 changed files with 79 additions and 134 deletions.
2 changes: 1 addition & 1 deletion .docker/requirements.txt
@@ -1,4 +1,4 @@
 docker~=7.0.0
 pytest~=8.2.0
 requests~=2.32.0
-pytest-docker~=3.1.0
+pytest-docker~=3.2.0
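(For reference: `~=3.2.0` is a compatible-release specifier, equivalent to `pytest-docker >=3.2.0, ==3.2.*`, so patch releases are picked up automatically while a minor bump like this one needs an explicit edit.)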
4 changes: 2 additions & 2 deletions .docker/tests/test_aiida.py
@@ -6,7 +6,7 @@


 def test_correct_python_version_installed(aiida_exec, python_version):
-    info = json.loads(aiida_exec('mamba list --json --full-name python').decode())[0]
+    info = json.loads(aiida_exec('mamba list --json --full-name python', ignore_stderr=True).decode())[0]
     assert info['name'] == 'python'
     assert parse(info['version']) == parse(python_version)

@@ -15,7 +15,7 @@ def test_correct_pgsql_version_installed(aiida_exec, pgsql_version, variant):
     if variant == 'aiida-core-base':
         pytest.skip('PostgreSQL is not installed in the base image')

-    info = json.loads(aiida_exec('mamba list --json --full-name postgresql').decode())[0]
+    info = json.loads(aiida_exec('mamba list --json --full-name postgresql', ignore_stderr=True).decode())[0]
     assert info['name'] == 'postgresql'
     assert parse(info['version']).major == parse(pgsql_version).major

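The only change in these tests is the new `ignore_stderr=True` argument: `mamba` can print warnings to stderr, and if those bytes reach the caller the `json.loads` on the output fails. A minimal sketch of how a conftest fixture could support such a flag (names are hypothetical; the real `aiida_exec` fixture lives in the repository's Docker test conftest):

import subprocess

import pytest


@pytest.fixture
def aiida_exec(container_id):  # `container_id` is an assumed helper fixture
    """Run a command inside the container and return its stdout as bytes."""

    def _exec(command, ignore_stderr=False):
        result = subprocess.run(
            ['docker', 'exec', container_id, 'bash', '-c', command],
            capture_output=True,
            check=True,
        )
        if result.stderr and not ignore_stderr:
            raise RuntimeError(result.stderr.decode())
        # Only stdout is returned, so callers can safely json.loads() it
        return result.stdout

    return _exec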
2 changes: 0 additions & 2 deletions .github/workflows/ci-code.yml
@@ -126,8 +126,6 @@ jobs:
       with:
         python-version: '3.12'
         from-lock: 'true'
-        # NOTE: The `verdi devel check-undesired-imports` fails if
-        # the 'tui' extra is installed.
         extras: ''

     - name: Run verdi tests
12 changes: 8 additions & 4 deletions src/aiida/cmdline/commands/cmd_devel.py
@@ -61,12 +61,11 @@ def devel_check_load_time():
 def devel_check_undesired_imports():
     """Check that verdi does not import python modules it shouldn't.
     Note: The blacklist was taken from the list of packages in the 'atomic_tools' extra but can be extended.
     This is to keep the verdi CLI snappy, especially for tab-completion.
     """
     loaded_modules = 0

-    for modulename in [
-        'asyncio',
+    unwanted_modules = [
         'requests',
         'plumpy',
         'disk_objectstore',
@@ -78,7 +77,12 @@ def devel_check_undesired_imports():
         'spglib',
         'pymysql',
         'yaml',
-    ]:
+    ]
+    # trogon powers the optional TUI and uses asyncio.
+    # Check for asyncio only when the optional tui extras are not installed.
+    if 'trogon' not in sys.modules:
+        unwanted_modules += ['asyncio']
+    for modulename in unwanted_modules:
         if modulename in sys.modules:
             echo.echo_warning(f'Detected loaded module "{modulename}"')
             loaded_modules += 1
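One subtlety in the new code: the brackets in `unwanted_modules += ['asyncio']` matter, because `+=` on a list iterates its right-hand operand. A bare string would be split into characters:

>>> unwanted = ['requests', 'plumpy']
>>> unwanted += 'asyncio'      # wrong: a str is iterable, so it extends char by char
>>> unwanted
['requests', 'plumpy', 'a', 's', 'y', 'n', 'c', 'i', 'o']
>>> unwanted = ['requests', 'plumpy']
>>> unwanted += ['asyncio']    # correct: appends the single module name
>>> unwanted
['requests', 'plumpy', 'asyncio']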
74 changes: 16 additions & 58 deletions src/aiida/storage/psql_dos/orm/groups.py
@@ -167,17 +167,10 @@ def add_nodes(self, nodes, **kwargs):
         :note: all the nodes *and* the group itself have to be stored.

         :param nodes: a list of `BackendNode` instance to be added to this group
-        :param kwargs:
-            skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used
-                to create a direct SQL INSERT statement to the group-node relationship
-                table (to improve speed).
         """
         from sqlalchemy.dialects.postgresql import insert
-        from sqlalchemy.exc import IntegrityError

         super().add_nodes(nodes)
-        skip_orm = kwargs.get('skip_orm', False)

         def check_node(given_node):
             """Check if given node is of correct type and stored"""
@@ -188,31 +181,16 @@ def check_node(given_node):
                 raise ValueError('At least one of the provided nodes is unstored, stopping...')

         with utils.disable_expire_on_commit(self.backend.get_session()) as session:
-            if not skip_orm:
-                # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
-                dbnodes = self.model.dbnodes
-
-                for node in nodes:
-                    check_node(node)
-
-                    # Use pattern as suggested here:
-                    # http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
-                    try:
-                        with session.begin_nested():
-                            dbnodes.append(node.bare_model)
-                            session.flush()
-                    except IntegrityError:
-                        # Duplicate entry, skip
-                        pass
-            else:
-                ins_dict = []
-                for node in nodes:
-                    check_node(node)
-                    ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
+            ins_dict = []
+            for node in nodes:
+                check_node(node)
+                ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
+            if len(ins_dict) == 0:
+                return

-                table = self.GROUP_NODE_CLASS.__table__
-                ins = insert(table).values(ins_dict)
-                session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
+            table = self.GROUP_NODE_CLASS.__table__
+            ins = insert(table).values(ins_dict)
+            session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

             # Commit everything as up till now we've just flushed
             if not session.in_nested_transaction():
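The rewrite drops the per-node ORM savepoint dance in favour of a single bulk INSERT that leans on PostgreSQL's ON CONFLICT DO NOTHING, so rows that would violate the (dbnode_id, dbgroup_id) uniqueness are skipped instead of raising. A self-contained sketch of the pattern (the table and function here are stand-ins, not aiida's actual models):

# Sketch of the bulk "insert or ignore" pattern used above.
# `group_nodes` is a stand-in for self.GROUP_NODE_CLASS.__table__.
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
group_nodes = Table(
    'group_nodes',
    metadata,
    Column('dbnode_id', Integer, primary_key=True),
    Column('dbgroup_id', Integer, primary_key=True),
)


def bulk_add(session, group_id, node_ids):
    rows = [{'dbnode_id': node_id, 'dbgroup_id': group_id} for node_id in node_ids]
    if not rows:
        # insert().values([]) is invalid SQL, hence the early return in the diff as well
        return
    stmt = insert(group_nodes).values(rows)
    # Pairs already present in the table are silently skipped
    session.execute(stmt.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))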
@@ -224,45 +202,25 @@ def remove_nodes(self, nodes, **kwargs):
         :note: all the nodes *and* the group itself have to be stored.

         :param nodes: a list of `BackendNode` instance to be added to this group
-        :param kwargs:
-            skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL
-                DELETE statement to the group-node relationship table in order to improve speed.
         """
         from sqlalchemy import and_

         super().remove_nodes(nodes)

-        # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
-        dbnodes = self.model.dbnodes
-        skip_orm = kwargs.get('skip_orm', False)
-
         def check_node(node):
             if not isinstance(node, self.NODE_CLASS):
                 raise TypeError(f'invalid type {type(node)}, has to be {self.NODE_CLASS}')

             if node.id is None:
                 raise ValueError('At least one of the provided nodes is unstored, stopping...')

-        list_nodes = []
-
         with utils.disable_expire_on_commit(self.backend.get_session()) as session:
-            if not skip_orm:
-                for node in nodes:
-                    check_node(node)
-
-                    # Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error
-                    if node.bare_model in dbnodes:
-                        list_nodes.append(node.bare_model)
-
-                for node in list_nodes:
-                    dbnodes.remove(node)
-            else:
-                table = self.GROUP_NODE_CLASS.__table__
-                for node in nodes:
-                    check_node(node)
-                    clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
-                    statement = table.delete().where(clause)
-                    session.execute(statement)
+            table = self.GROUP_NODE_CLASS.__table__
+            for node in nodes:
+                check_node(node)
+                clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
+                statement = table.delete().where(clause)
+                session.execute(statement)

             if not session.in_nested_transaction():
                 session.commit()
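The removal path gets the same treatment: the ORM collection manipulation, with its per-node membership check, is replaced by direct DELETE statements, which are naturally idempotent. Roughly, each iteration issues the following (a sketch, reusing the stand-in `group_nodes` table from above):

from sqlalchemy import and_


def bulk_remove(session, group_id, node_ids):
    # A DELETE whose WHERE clause matches nothing is a no-op, so removing a
    # node that is not in the group simply does nothing, no existence check needed
    for node_id in node_ids:
        clause = and_(group_nodes.c.dbnode_id == node_id, group_nodes.c.dbgroup_id == group_id)
        session.execute(group_nodes.delete().where(clause))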
6 changes: 3 additions & 3 deletions src/aiida/transports/plugins/local.py
@@ -866,7 +866,7 @@ def rename(self, oldpath: TransportPath, newpath: TransportPath):
         :param str oldpath: existing name of the file or folder
         :param str newpath: new name for the file or folder

-        :raises OSError: if src/dst is not found
+        :raises OSError: if oldpath is not found or newpath already exists
         :raises ValueError: if src/dst is not a valid string
         """
         oldpath = str(oldpath)
@@ -877,8 +877,8 @@ def rename(self, oldpath: TransportPath, newpath: TransportPath):
             raise ValueError(f'Destination {newpath} is not a valid string')
         if not os.path.exists(oldpath):
             raise OSError(f'Source {oldpath} does not exist')
-        if not os.path.exists(newpath):
-            raise OSError(f'Destination {newpath} does not exist')
+        if os.path.exists(newpath):
+            raise OSError(f'Destination {newpath} already exists.')

         shutil.move(oldpath, newpath)

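Note the inverted condition: the old code demanded that the destination already exist, almost certainly a bug since shutil.move would then overwrite it, while the new code rejects an existing destination. A short usage sketch of the new semantics (the transport instance and paths are illustrative):

import pytest

# `transport` is assumed to be an open LocalTransport; all paths are hypothetical.
transport.rename('/tmp/report.txt', '/tmp/report_final.txt')  # fine: destination is free

with pytest.raises(OSError, match='does not exist'):
    transport.rename('/tmp/no_such_file.txt', '/tmp/whatever.txt')

with pytest.raises(OSError, match='already exists'):
    transport.rename('/tmp/report_final.txt', '/tmp/some_existing_file.txt')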
63 changes: 0 additions & 63 deletions tests/orm/implementation/test_groups.py
@@ -25,66 +25,3 @@ def test_creation_from_dbgroup(backend):

     assert group.pk == gcopy.pk
     assert group.uuid == gcopy.uuid
-
-
-def test_add_nodes_skip_orm():
-    """Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag."""
-    group = orm.Group(label='test_adding_nodes').store().backend_entity
-
-    node_01 = orm.Data().store().backend_entity
-    node_02 = orm.Data().store().backend_entity
-    node_03 = orm.Data().store().backend_entity
-    node_04 = orm.Data().store().backend_entity
-    node_05 = orm.Data().store().backend_entity
-    nodes = [node_01, node_02, node_03, node_04, node_05]
-
-    group.add_nodes([node_01], skip_orm=True)
-    group.add_nodes([node_02, node_03], skip_orm=True)
-    group.add_nodes((node_04, node_05), skip_orm=True)
-
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Try to add a node that is already present: there should be no problem
-    group.add_nodes([node_01], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-
-def test_add_nodes_skip_orm_batch():
-    """Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag and batches."""
-    nodes = [orm.Data().store().backend_entity for _ in range(100)]
-
-    # Add nodes to groups using different batch size. Check in the end the correct addition.
-    batch_sizes = (1, 3, 10, 1000)
-    for batch_size in batch_sizes:
-        group = orm.Group(label=f'test_batches_{batch_size!s}').store()
-        group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
-        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-
-def test_remove_nodes_bulk():
-    """Test node removal with `skip_orm=True`."""
-    group = orm.Group(label='test_removing_nodes').store().backend_entity
-
-    node_01 = orm.Data().store().backend_entity
-    node_02 = orm.Data().store().backend_entity
-    node_03 = orm.Data().store().backend_entity
-    node_04 = orm.Data().store().backend_entity
-    nodes = [node_01, node_02, node_03]
-
-    group.add_nodes(nodes)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Remove a node that is not in the group: nothing should happen
-    group.remove_nodes([node_04], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Remove one Node
-    nodes.remove(node_03)
-    group.remove_nodes([node_03], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
-
-    # Remove a list of Nodes and check
-    nodes.remove(node_01)
-    nodes.remove(node_02)
-    group.remove_nodes([node_01, node_02], skip_orm=True)
-    assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
25 changes: 24 additions & 1 deletion tests/orm/test_groups.py
@@ -149,13 +149,27 @@ def test_add_nodes(self):
         group.add_nodes(node_01)
         assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

+        # Try to add nothing: there should be no problem
+        group.add_nodes([])
+        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
+        nodes = [orm.Data().store().backend_entity for _ in range(100)]
+
+        # Add nodes to groups using different batch size. Check in the end the correct addition.
+        batch_sizes = (1, 3, 10, 1000)
+        for batch_size in batch_sizes:
+            group = orm.Group(label=f'test_batches_{batch_size!s}').store()
+            group.backend_entity.add_nodes(nodes, batch_size=batch_size)
+            assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
     def test_remove_nodes(self):
         """Test node removal."""
         node_01 = orm.Data().store()
         node_02 = orm.Data().store()
         node_03 = orm.Data().store()
         node_04 = orm.Data().store()
-        nodes = [node_01, node_02, node_03]
+        node_05 = orm.Data().store()
+        nodes = [node_01, node_02, node_03, node_05]
         group = orm.Group(label=uuid.uuid4().hex).store()

         # Add initial nodes
@@ -177,6 +191,15 @@
         group.remove_nodes([node_01, node_02])
         assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

+        # Remove to empty
+        nodes.remove(node_05)
+        group.remove_nodes([node_05])
+        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
+        # Try to remove nothing: there should be no problem
+        group.remove_nodes([])
+        assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
+
     def test_clear(self):
         """Test the `clear` method to remove all nodes."""
         node_01 = orm.Data().store()
25 changes: 25 additions & 0 deletions tests/transports/test_all_plugins.py
@@ -1232,3 +1232,28 @@ def test_asynchronous_execution(custom_transport, tmp_path):
     except ProcessLookupError:
         # If the process is already dead (or has never run), I just ignore the error
         pass
+
+
+def test_rename(custom_transport, tmp_path_remote):
+    """Test the rename function of the transport plugin."""
+    with custom_transport as transport:
+        old_file = tmp_path_remote / 'oldfile.txt'
+        new_file = tmp_path_remote / 'newfile.txt'
+        another_file = tmp_path_remote / 'anotherfile.txt'
+
+        # Create a test file to rename
+        old_file.touch()
+        another_file.touch()
+        assert old_file.exists()
+        assert another_file.exists()
+
+        # Perform rename operation
+        transport.rename(old_file, new_file)
+
+        # Verify rename was successful
+        assert not old_file.exists()
+        assert new_file.exists()
+
+        # Perform rename operation if new file already exists
+        with pytest.raises(OSError, match='already exist|destination exists'):
+            transport.rename(new_file, another_file)
