From 1e3f98632cc68b9d1a2258c29363fc99e574127a Mon Sep 17 00:00:00 2001 From: jung235 <96967431+jung235@users.noreply.github.com> Date: Thu, 19 Oct 2023 02:23:56 +0900 Subject: [PATCH] test: add test code (#5) --- tests/models/test_bm.py | 20 ++++++++++++++++++++ tests/tracer/test_ensemble.py | 18 ++++++++++++++++++ tests/utils/test_jitted.py | 11 ++++++++--- 3 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 tests/models/test_bm.py create mode 100644 tests/tracer/test_ensemble.py diff --git a/tests/models/test_bm.py b/tests/models/test_bm.py new file mode 100644 index 0000000..34b1fac --- /dev/null +++ b/tests/models/test_bm.py @@ -0,0 +1,20 @@ +import pytest + +from pydiffuser.models.bm import BrownianMotion, BrownianMotionConfig + + +def test_bm(): + model = BrownianMotion() + model.generate(dimension=3) + _, _, dim, _ = model.generate_info.values() + assert dim == 3 + + config = BrownianMotionConfig(dimension=1) + assert config.name == model.name + + model_v2 = BrownianMotion.from_config(config) + model_v2.generate(dimension=3) + _, _, dim, _ = model_v2.generate_info.values() + with pytest.raises(AssertionError): + assert dim == 3 + assert dim == 1 diff --git a/tests/tracer/test_ensemble.py b/tests/tracer/test_ensemble.py new file mode 100644 index 0000000..92791c3 --- /dev/null +++ b/tests/tracer/test_ensemble.py @@ -0,0 +1,18 @@ +import jax.numpy as jnp +import pytest + +from pydiffuser.tracer.ensemble import Ensemble +from pydiffuser.tracer.trajectory import Trajectory + + +def test_ensemble(): + ens = Ensemble(dt=1.0) + + tracer = Trajectory(dt=0.1, position_x1=[0, 1, 2], position_x2=[0, -1, -2]) + with pytest.raises(AssertionError): + ens.add(tracer) + + tracer = Trajectory(dt=1.0, position_x1=[0, 1, 2], position_x2=[0, -1, -2]) + ens.add(tracer) + microstate = jnp.array([[[3, -3], [4, -4], [5, -5]]]) + ens.update_microstate(microstate) diff --git a/tests/utils/test_jitted.py b/tests/utils/test_jitted.py index bd96659..c4bd6bc 100644 --- a/tests/utils/test_jitted.py +++ b/tests/utils/test_jitted.py @@ -4,7 +4,12 @@ def test_normalize(): - arr = jnp.array([3, 4]) + arr = jnp.array([[3, 4]]) normed_arr = normalize(arr=arr) - assert normed_arr[0] == 0.6 - assert normed_arr[1] == 0.8 + expected_arr = jnp.array([[0.6, 0.8]]) + assert jnp.all(normed_arr == expected_arr) + + arr = jnp.array([[3], [4]]) + normed_arr = normalize(arr=arr, axis=0) + expected_arr = jnp.array([[0.6], [0.8]]) + assert jnp.all(normed_arr == expected_arr)