From d805ae49ee218bf3323a1c188e11f8160148d2fa Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Sun, 28 Jan 2024 03:11:25 -0500 Subject: [PATCH] Add tests --- .pre-commit-config.yaml | 2 ++ pyproject.toml | 1 + tests/__init__.py | 1 + tests/test_rebop.py | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 40 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_rebop.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 12f9306..c1ccb75 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,8 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - id: mixed-line-ending + - id: name-tests-test + args: [--pytest-test-first] - repo: https://github.com/crate-ci/typos rev: typos-dict-v0.11.20 hooks: diff --git a/pyproject.toml b/pyproject.toml index 6c6c1c0..4985297 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] +"tests/test_*.py" = ["D", "S101", "PLR2004", "ANN201"] "examples/*" = ["INP001"] "examples/sir.py" = ["T201"] "python/rebop/__init__.py" = ["D104"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..5bec482 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the python bindings.""" diff --git a/tests/test_rebop.py b/tests/test_rebop.py new file mode 100644 index 0000000..dae3915 --- /dev/null +++ b/tests/test_rebop.py @@ -0,0 +1,36 @@ +import numpy as np +import numpy.testing as npt +import pytest +import rebop +import xarray as xr + + +def sir_model(transmission: float = 1e-4, recovery: float = 0.01) -> rebop.Gillespie: + sir = rebop.Gillespie() + sir.add_reaction(transmission, ["S", "I"], ["I", "I"]) + sir.add_reaction(recovery, ["I"], ["R"]) + return sir + + +@pytest.mark.parametrize("seed", [None, *range(10)]) +def test_sir(seed: int): + sir = sir_model() + ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=seed) + assert isinstance(ds, xr.Dataset) + npt.assert_array_equal(ds.time, np.arange(251)) + assert all(ds.S >= 0) + assert all(ds.I >= 0) + assert all(ds.R >= 0) + assert all(ds.S <= 1000) + assert all(ds.I <= 1000) + assert all(ds.R <= 1000) + npt.assert_array_equal(ds.S + ds.I + ds.R, [1000] * 251) + + +def test_fixed_seed(): + sir = sir_model() + ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=42) + + assert ds.S[-1] == 0 + assert ds.I[-1] == 166 + assert ds.R[-1] == 834