Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica committed Jun 10, 2024
1 parent d6ef06c commit d805ae4
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: mixed-line-ending
- id: name-tests-test
args: [--pytest-test-first]
- repo: https://github.com/crate-ci/typos
rev: typos-dict-v0.11.20
hooks:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"tests/test_*.py" = ["D", "S101", "PLR2004", "ANN201"]
"examples/*" = ["INP001"]
"examples/sir.py" = ["T201"]
"python/rebop/__init__.py" = ["D104"]
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for the python bindings."""
36 changes: 36 additions & 0 deletions tests/test_rebop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
import numpy.testing as npt
import pytest
import rebop
import xarray as xr


def sir_model(transmission: float = 1e-4, recovery: float = 0.01) -> rebop.Gillespie:
sir = rebop.Gillespie()
sir.add_reaction(transmission, ["S", "I"], ["I", "I"])
sir.add_reaction(recovery, ["I"], ["R"])
return sir


@pytest.mark.parametrize("seed", [None, *range(10)])
def test_sir(seed: int):
sir = sir_model()
ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=seed)
assert isinstance(ds, xr.Dataset)
npt.assert_array_equal(ds.time, np.arange(251))
assert all(ds.S >= 0)
assert all(ds.I >= 0)
assert all(ds.R >= 0)
assert all(ds.S <= 1000)
assert all(ds.I <= 1000)
assert all(ds.R <= 1000)
npt.assert_array_equal(ds.S + ds.I + ds.R, [1000] * 251)


def test_fixed_seed():
sir = sir_model()
ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=42)

assert ds.S[-1] == 0
assert ds.I[-1] == 166
assert ds.R[-1] == 834

0 comments on commit d805ae4

Please sign in to comment.