Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 12, 2024
1 parent 8cc8b25 commit c241a97
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,3 @@ cython_debug/
.vscode/
*.h5
output.json

32 changes: 19 additions & 13 deletions tests/comparison.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import uuid
import warnings

import ase.io
import chemfiles
import mdtraj
import pandas as pd
import numpy as np
import chemfiles
import pandas as pd
import pytest
import ase.io
import uuid
import znh5md
import warnings

import znh5md

# n_steps, n_atoms

Expand All @@ -16,6 +17,7 @@
WRITE = [(1000, 1000)]
READ = [(1000, 1000)]


def collect_file_sizes(tmp_path) -> str:
# compute the mean and standard deviation of the file sizes in the given directory
sizes = []
Expand All @@ -24,6 +26,7 @@ def collect_file_sizes(tmp_path) -> str:
# print mean and standard deviation in megabytes
return f"Mean: {np.mean(sizes) / 1e6:.2f} MB, Std: {np.std(sizes) / 1e6:.2f} MB"


def create_topology(n_atoms):
"""Create an MDTraj topology for a given number of atoms."""
data = pd.DataFrame(
Expand All @@ -38,6 +41,7 @@ def create_topology(n_atoms):
)
return mdtraj.Topology().from_dataframe(data)


def generate_frames(n_steps, n_atoms) -> list[ase.Atoms]:
"""Generate a list of Chemfiles frames with random atomic positions."""
frames = []
Expand All @@ -46,19 +50,22 @@ def generate_frames(n_steps, n_atoms) -> list[ase.Atoms]:
frames.append(atoms)
return frames


def convert_atoms_to_chemfiles(atoms: ase.Atoms) -> chemfiles.Frame:
"""Convert an ASE atoms object to a Chemfiles frame."""
frame = chemfiles.Frame()
frame.resize(len(atoms))
frame.positions[:] = atoms.positions
return frame


@pytest.fixture
def frames(request):
"""Fixture to create frames based on n_steps and n_atoms from the request."""
n_steps, n_atoms = request.param
return list(generate_frames(n_steps, n_atoms))


@pytest.mark.benchmark(group="write")
@pytest.mark.parametrize("frames", WRITE, indirect=True)
def test_write_chemfiles_pdb(tmp_path, frames, benchmark):
Expand Down Expand Up @@ -105,6 +112,7 @@ def write_znh5md():
benchmark(write_znh5md)
warnings.warn(collect_file_sizes(tmp_path))


@pytest.mark.benchmark(group="write")
@pytest.mark.parametrize("frames", WRITE, indirect=True)
def test_write_xtc(tmp_path, frames, benchmark):
Expand All @@ -114,14 +122,13 @@ def test_write_xtc(tmp_path, frames, benchmark):
def write_xtc():
"""Inner function for benchmarking."""
filename = tmp_path / f"{uuid.uuid4()}.xtc"
traj = mdtraj.Trajectory(
positions, topology
)
traj = mdtraj.Trajectory(positions, topology)
traj.save_xtc(filename.as_posix())

benchmark(write_xtc)
warnings.warn(collect_file_sizes(tmp_path))


@pytest.mark.benchmark(group="write")
@pytest.mark.parametrize("frames", WRITE, indirect=True)
def test_write_ase_traj(tmp_path, frames, benchmark):
Expand All @@ -145,6 +152,7 @@ def write_ase_xyz():
benchmark(write_ase_xyz)
warnings.warn(collect_file_sizes(tmp_path))


@pytest.mark.benchmark(group="read")
@pytest.mark.parametrize("frames", READ, indirect=True)
def test_read_ase_traj(tmp_path, frames, benchmark):
Expand All @@ -170,6 +178,7 @@ def read_ase_traj():

benchmark(read_ase_traj)


@pytest.mark.parametrize("compression", ["lzf", "gzip", None])
@pytest.mark.benchmark(group="read")
@pytest.mark.parametrize("frames", READ, indirect=True)
Expand All @@ -190,9 +199,7 @@ def test_read_xtc(tmp_path, frames, benchmark):
topology = create_topology(len(frames[0]))
positions = np.array([frame.positions for frame in frames])
filename = tmp_path / f"{uuid.uuid4()}.xtc"
traj = mdtraj.Trajectory(
positions, topology
)
traj = mdtraj.Trajectory(positions, topology)
traj.save_xtc(filename.as_posix())

def read_xtc():
Expand Down Expand Up @@ -223,7 +230,6 @@ def read_chemfiles_pdb():
benchmark(read_chemfiles_pdb)



@pytest.mark.benchmark(group="read")
@pytest.mark.parametrize("frames", READ, indirect=True)
def test_read_chemfiles_xyz(tmp_path, frames, benchmark):
Expand Down

0 comments on commit c241a97

Please sign in to comment.