Skip to content

Commit

Permalink
removed asserts, added tests on sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-saunders-phil committed Jul 26, 2023
1 parent 6f0579a commit e6a7acd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
10 changes: 0 additions & 10 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,12 +2259,6 @@ class ICAR(Continuous):
constraint by finding the sum of the vector $\\phi$ and penalizing based on its
distance from zero.
======== ==========================================
Support :math:`x \\in \\mathbb{R}^k`
Mean :math:`0`
Variance :math:`T^{-1}` ?
======== ==========================================
Parameters
----------
W : ndarray of int
Expand Down Expand Up @@ -2365,14 +2359,10 @@ def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
# check on sigma

Check warning on line 2360 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2359-L2360

Added lines #L2359 - L2360 were not covered by tests
sigma = pt.as_tensor_variable(floatX(sigma))
sigma = Assert("sigma > 0")(sigma, pt.gt(sigma, 0))

Check warning on line 2362 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2362

Added line #L2362 was not covered by tests
# check on centering_strength

zero_sum_strength = pt.as_tensor_variable(floatX(zero_sum_strength))

Check warning on line 2365 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2364-L2365

Added lines #L2364 - L2365 were not covered by tests
zero_sum_strength = Assert("centering_strength > 0")(
zero_sum_strength, pt.gt(zero_sum_strength, 0)
)

return super().dist([W, node1, node2, N, sigma, zero_sum_strength], **kwargs)

Expand Down
10 changes: 9 additions & 1 deletion tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,7 +2093,15 @@ class TestICAR(BaseTestDistributionRandom):
"sigma": 2,
"zero_sum_strength": 0.001,
}
checks_to_run = ["check_pymc_params_match_rv_op"]
checks_to_run = ["check_pymc_params_match_rv_op", "check_rv_inferred_size"]

def check_rv_inferred_size(self):
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
sizes_expected = [(3,), (3,), (1, 3), (1, 3), (5, 3), (4, 5, 3), (2, 4, 2, 3)]
for size, expected in zip(sizes_to_check, sizes_expected):
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
expected_symbolic = tuple(pymc_rv.shape.eval())
assert expected_symbolic == expected

def test_icar_logp(self):
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
Expand Down

0 comments on commit e6a7acd

Please sign in to comment.