Skip to content

Commit

Permalink
add new tests for rvs
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash committed Jan 18, 2024
1 parent deeea41 commit f1613d1
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 69 deletions.
4 changes: 3 additions & 1 deletion jaxampler/_src/rvs/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __init__(self, p: Numeric | Any, name: Optional[str] = None) -> None:
super().__init__(name=name, shape=shape)

def check_params(self) -> None:
assert jnp.all(self._p >= 0.0), "All p must be greater than or equals to 0"
assert jnp.all(0.0 <= self._p) & jnp.all(
self._p <= 1.0
), "All p must be greater than or equals to 0 and less than or equals to 1"

@partial(jit, static_argnums=(0,))
def logpmf_x(self, x: Numeric) -> Numeric:
Expand Down
41 changes: 27 additions & 14 deletions tests/geometric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import jax
import jax.numpy as jnp
import pytest
from jax.scipy.stats import geom as jax_geom


Expand All @@ -25,30 +26,42 @@


class TestGeometric:
def test_shapes(self):
assert jnp.allclose(Geometric(p=0.5, name="test_logpmf_p0.5").logpmf_x(5), jax_geom.logpmf(5, 0.5))
def test_positive_p(self):
assert jnp.allclose(
Geometric(p=0.5, name="test_logpmf_p0.5").logpmf_x(5),
jax_geom.logpmf(5, 0.5),
)

# for different shapes
assert Geometric(p=[0.5, 0.1], name="test_logpmf_p2").logpmf_x(5).shape == (2,)
assert Geometric(p=[0.5, 0.1, 0.3], name="test_logpmf_p3").logpmf_x(5).shape == (3,)
def test_p_greater_than_1(self):
with pytest.raises(AssertionError):
Geometric(p=1.5, name="test_logpmf_p1.5")

# when probability is very small
assert jnp.allclose(Geometric(p=0.0001, name="test_logpmf_p0.0001").logpmf_x(5), jax_geom.logpmf(5, 0.0001))
def test_small_p(self):
assert jnp.allclose(
Geometric(p=0.0001, name="test_pmf_p0.0001").pmf_x(5),
jax_geom.pmf(5, 0.0001),
)
assert jnp.allclose(
Geometric(p=0.0001, name="test_logpmf_p0.0001").logpmf_x(5),
jax_geom.logpmf(5, 0.0001),
)

# when n is very large
def test_big_n(self):
assert jnp.allclose(
Geometric(p=0.1, name="test_pmf_p0.1").pmf_x(50),
jax_geom.pmf(50, 0.1),
)

def test_shapes(self):
assert Geometric(p=[0.5, 0.1], name="test_logpmf_p2").logpmf_x(5).shape == (2,)
assert Geometric(p=[0.5, 0.1, 0.3], name="test_logpmf_p3").logpmf_x(5).shape == (3,)
assert jnp.allclose(Geometric(p=0.1).logpmf_x(50), jax_geom.logpmf(50, 0.1))
assert jnp.allclose(Geometric(p=0.5, name="test_pmf_p0.5").pmf_x(5), jax_geom.pmf(5, 0.5))
assert jnp.allclose(
Geometric(p=[0.5, 0.1], name="test_pmf_p2").pmf_x(5), jax_geom.pmf(5, jnp.asarray([0.5, 0.1]))
)
assert Geometric(p=[[0.5, 0.1], [0.4, 0.1]], name="test_pmf_p2x3").pmf_x(5).shape == (2, 2)

# when probability is very small
assert jnp.allclose(Geometric(p=0.0001, name="test_pmf_p0.0001").pmf_x(5), jax_geom.pmf(5, 0.0001))

# when n is very large
assert jnp.allclose(Geometric(p=0.1, name="test_pmf_p0.1").pmf_x(50), jax_geom.pmf(50, 0.1))

def test_cdf_x(self):
geom_cdf = Geometric(p=0.2, name="test_cdf")
assert geom_cdf.cdf_x(-1) == 0.0
Expand Down
23 changes: 12 additions & 11 deletions tests/rayleigh_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,34 @@
import jax
import jax.numpy as jnp
import pytest
from jax.scipy.stats import uniform as jax_uniform

sys.path.append("../jaxampler")
from jaxampler.rvs import Rayleigh


class TestRayleigh:

def test_pdf(self):
def test_pdf_x(self):
assert Rayleigh(sigma=0.5).pdf_x(1) == 0.5413411

def test_shapes(self):
assert jnp.allclose(Rayleigh(sigma=[0.5, 0.3]).pdf_x(1), jnp.array([0.5413411, 0.04295468]))
assert jnp.allclose(Rayleigh(sigma=[0.5, 0.5, 0.5]).pdf_x(1), jnp.array([0.5413411, 0.5413411, 0.5413411]))
assert jnp.allclose(
Rayleigh(sigma=[0.5, 0.3]).pdf_x(1),
jnp.array([0.5413411, 0.04295468]),
)
assert jnp.allclose(
Rayleigh(sigma=[0.5, 0.5, 0.5]).pdf_x(1),
jnp.array([0.5413411, 0.5413411, 0.5413411]),
)
assert jnp.allclose(
Rayleigh(sigma=[[0.3, 0.5], [0.5, 0.4]]).pdf_x(1),
jnp.array([[0.04295468, 0.5413411], [0.5413411, 0.2746058]]))
jnp.array([[0.04295468, 0.5413411], [0.5413411, 0.2746058]]),
)

def test_imcompatible_shapes(self):
def test_incompatible_shapes(self):
with pytest.raises(ValueError):
Rayleigh(sigma=[[0.3, 0.5], [0.5]])

def test_out_of_bound(self):
# when sigma is negative
with pytest.raises(AssertionError):
assert Rayleigh(sigma=-1)

Expand All @@ -64,6 +68,3 @@ def test_rvs(self):
# without key
result = tpl_rvs.rvs(shape)
assert result.shape, shape + tpl_rvs._shape


print(Rayleigh(sigma=500).cdf_x(30))
137 changes: 113 additions & 24 deletions tests/triangular_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,117 @@


class TestTriangular:
def test_shape(self):
# when x is less than low
assert jnp.allclose(Triangular(low=0, mode=5, high=10, name="triangular_0_to_10").pdf_x(-1), 0)
# when x is less than mid
def test_x_is_less_than_low(self):
assert jnp.allclose(
Triangular(low=0, mode=5, high=10, name="triangular_0_to_10").pdf_x(2),
Triangular(
low=0,
mode=5,
high=10,
name="triangular_0_to_10",
).pdf_x(-1),
0,
)

def test_x_is_less_than_mid(self):
assert jnp.allclose(
Triangular(
low=0,
mode=5,
high=10,
name="triangular_0_to_10",
).pdf_x(2),
jnp.exp(jnp.log(2) + jnp.log(2) - jnp.log(10) - jnp.log(5)),
)
# when x is equal to mid
assert jnp.allclose(Triangular(low=0, mode=5, high=10, name="triangular_0_to_10").pdf_x(5), 0.2)
# when x is greater than mid
assert jnp.allclose(Triangular(low=0.0, mode=5.0, high=10.0, name="triangular_0_to_10").pdf_x(7), 0.12)
# when x is greater than high
assert jnp.allclose(Triangular(low=0, mode=5, high=10, name="triangular_0_to_10").pdf_x(11), 0)

# when low is negative
assert jnp.allclose(Triangular(low=-10, mode=5, high=10, name="triangular_n10_to_10").pdf_x(2), 0.08)
# when both low and high are negative
assert jnp.allclose(Triangular(low=-10, mode=-5, high=-1, name="triangular_n10_to_n1").pdf_x(-9), 2 / 45)
def test_x_is_equal_to_mid(self):
assert jnp.allclose(
Triangular(
low=0,
mode=5,
high=10,
name="triangular_0_to_10",
).pdf_x(5),
0.2,
)

def test_x_is_greater_than_mid(self):
assert jnp.allclose(
Triangular(
low=0.0,
mode=5.0,
high=10.0,
name="triangular_0_to_10",
).pdf_x(7),
0.12,
)

def test_x_is_greater_than_high(self):
assert jnp.allclose(
Triangular(
low=0,
mode=5,
high=10,
name="triangular_0_to_10",
).pdf_x(11),
0,
)

# when low is greater than high
def test_low_is_negative(self):
assert jnp.allclose(
Triangular(
low=-10,
mode=5,
high=10,
name="triangular_n10_to_10",
).pdf_x(2),
0.08,
)

def test_low_and_high_are_negative(self):
assert jnp.allclose(
Triangular(
low=-10,
mode=-5,
high=-1,
name="triangular_n10_to_n1",
).pdf_x(-9),
2 / 45,
)

def test_low_is_equal_to_high(self):
with pytest.raises(AssertionError):
Triangular(low=10, mode=5, high=0, name="triangular_10_to_0")
# when low is greater than mid
Triangular(
low=10,
mode=5,
high=0,
name="triangular_10_to_0",
)

def test_low_is_greater_than_mid(self):
with pytest.raises(AssertionError):
Triangular(low=1, mode=0, high=10, name="triangular_0_to_10")
# when mid is greater than high
Triangular(
low=1,
mode=0,
high=10,
name="triangular_0_to_10",
)

def test_mid_is_greater_than_high(self):
with pytest.raises(AssertionError):
Triangular(low=10, mode=30, high=20, name="triangular_10_to_20")
Triangular(
low=10,
mode=30,
high=20,
name="triangular_10_to_20",
)

def test_cdf_x(self):
triangular_cdf = Triangular(low=0, mode=5, high=10, name="cdf_0_to_10")
triangular_cdf = Triangular(
low=0,
mode=5,
high=10,
name="cdf_0_to_10",
)
# when x is equal to mode
assert triangular_cdf.cdf_x(5) == 0.5
# when x is greater than mid
Expand All @@ -67,7 +146,12 @@ def test_cdf_x(self):
assert triangular_cdf.cdf_x(-1) == 0

## when low is negative
triangular_cdf = Triangular(low=-5, mode=5, high=10, name="cdf_n5_to_10")
triangular_cdf = Triangular(
low=-5,
mode=5,
high=10,
name="cdf_n5_to_10",
)
# when x is equal to mode
assert triangular_cdf.cdf_x(5) == 0.6666666
# when x is greater than mid
Expand All @@ -78,7 +162,12 @@ def test_cdf_x(self):
assert triangular_cdf.cdf_x(-6) == 0

# when both high and low are negative
triangular_cdf = Triangular(low=-5, mode=-2.5, high=-1, name="cdf_n5_to_n1")
triangular_cdf = Triangular(
low=-5,
mode=-2.5,
high=-1,
name="cdf_n5_to_n1",
)
# when x is equal to mid
assert triangular_cdf.cdf_x(-2.5) == 0.625
# when x is greater than mid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,45 @@
import jax
import jax.numpy as jnp
import pytest
from jax.scipy.stats import uniform as jax_uniform

sys.path.append("../jaxampler")
from jaxampler.rvs import TruncPowerLaw


class TestTruncPowerLaw:

def test_pdf(self):
assert TruncPowerLaw(alpha=0.5, low=0.1, high=10).pdf_x(1) == 0.047481652
# when alpha is negative
assert TruncPowerLaw(alpha=-1, low=0.1, high=10).pdf_x(1) == 0.21714725

def test_shapes(self):
assert jnp.allclose(
TruncPowerLaw(alpha=[0.5, 0.1], low=[0.1, 0.2], high=[10, 10]).pdf_x(1), jnp.array([0.04748165,
0.08857405]))
TruncPowerLaw(alpha=[0.5, 0.1], low=[0.1, 0.2], high=[10, 10]).pdf_x(1), jnp.array([0.04748165, 0.08857405])
)
assert jnp.allclose(
TruncPowerLaw(alpha=[0.5, 0.1, 0.2], low=[0.1, 0.2, 0.2], high=[10, 10, 10]).pdf_x(1),
jnp.array([0.04748165, 0.08857405, 0.07641375]))
jnp.array([0.04748165, 0.08857405, 0.07641375]),
)

def test_imcompatible_shapes(self):
def test_incompatible_shapes(self):
with pytest.raises(ValueError):
TruncPowerLaw(alpha=[0.5, 0.1, 0.9], low=[0.1, 0.2], high=[10, 10])
TruncPowerLaw(
alpha=[0.5, 0.1, 0.9],
low=[0.1, 0.2],
high=[10, 10],
)
with pytest.raises(ValueError):
TruncPowerLaw(alpha=[0.5, 0.1, 0.9], low=[0.1, 0.2, 0.9], high=[10, 10])
TruncPowerLaw(
alpha=[0.5, 0.1, 0.9],
low=[0.1, 0.2, 0.9],
high=[10, 10],
)
with pytest.raises(ValueError):
TruncPowerLaw(alpha=[0.5, 0.1], low=[0.1, 0.2, 0.9], high=[10, 10])
TruncPowerLaw(
alpha=[0.5, 0.1],
low=[0.1, 0.2, 0.9],
high=[10, 10],
)

def test_out_of_bound(self):
# when x is less than low
Expand Down
Loading

0 comments on commit f1613d1

Please sign in to comment.