Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ICARRV and ICAR #6831

Merged
merged 12 commits into from
Aug 18, 2023
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
)
from pymc.distributions.multivariate import (
CAR,
ICAR,
Dirichlet,
DirichletMultinomial,
KroneckerNormal,
Expand Down Expand Up @@ -198,6 +199,7 @@
"Truncated",
"Censored",
"CAR",
"ICAR",
"PolyaGamma",
"HurdleGamma",
"HurdleLogNormal",
Expand Down
164 changes: 164 additions & 0 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"MatrixNormal",
"KroneckerNormal",
"CAR",
"ICAR",
"StickBreakingWeights",
]

Expand Down Expand Up @@ -2218,6 +2219,169 @@
)


class ICARRV(RandomVariable):
name = "icar"
ndim_supp = 1
ndims_params = [2, 1, 1, 0, 0, 0]
dtype = "floatX"
_print_name = ("ICAR", "\\operatorname{ICAR}")

def __call__(self, W, node1, node2, N, sigma, zero_sum_strength, size=None, **kwargs):
return super().__call__(W, node1, node2, N, sigma, zero_sum_strength, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_strength):
raise NotImplementedError("Cannot sample from ICAR prior")


icar = ICARRV()

Check warning on line 2238 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2238

Added line #L2238 was not covered by tests

class ICAR(Continuous):
r"""
The intrinsic conditional autoregressive prior. It is primarily used to model

Check warning on line 2242 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2242

Added line #L2242 was not covered by tests
covariance between neighboring areas on large datasets. It is a special case
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

on large datasets

Or small datasets!

of the :class:`~pymc.CAR` distribution where alpha is set to 1.

The log probability density function is

.. math::
f(\\phi| W,\\sigma) =
-\frac{1}{2\\sigma^{2}} \\sum_{i\\sim j} (\\phi_{i} - \\phi_{j})^2 -
\frac{1}{2}*\frac{\\sum_{i}{\\phi_{i}}}{0.001N}^{2} - \\ln{\\sqrt{2\\pi}} -
\\ln{0.001N}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are some backslashes twice and others once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, I wrote them with single slashes. Maybe one of the pre-commit checks added them by mistake?


The first term represents the spatial covariance component. Each $\\phi_{i}$ is penalized
based on the square distance from each of its neighbors. The notation $i\\sim j$
indicates a sum over all the neighbors of $\\phi_{i}$. The last three terms are the
Normal log density function where the mean is zero and the standard deviation is
$N * 0.001$ (where N is the length of the vector $\\phi$). This component imposed the zero-sum
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
$N * 0.001$ (where N is the length of the vector $\\phi$). This component imposed the zero-sum
$N * 0.001$ (where N is the length of the vector $\\phi$). This component imposes a zero-sum

constraint by finding the sum of the vector $\\phi$ and penalizing based on its
distance from zero.

Parameters
----------
W : ndarray of int
Symmetric adjacency matrix of 1s and 0s indicating adjacency between elements.
Must pass either W or both node1 and node2.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Must pass either W or both node1 and node2.


sigma : scalar, default 1
Standard deviation of the vector of phi's. Putting a prior on sigma
will result in a centered parameterization. In most cases, it is
preferable to use a non-centered parameterization by using the default
value and multiplying the resulting phi's by sigma. See the example below.

zero_sum_strength : scalar, default 0.001
Controls how strongly to enforce the zero-sum constraint. It sets the
standard deviation of a normal density function with mean zero.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could add a bit of detail here, like, "It puts an additional normal prior on the sum of the phi, such that the sum is normally distributed with mean zero and a small standard deviation, whose value is zero_sum_strength. Maybe zero_sum_stdev is a clearer name then?



Examples
--------
This example illustrates how to switch between centered and non-centered
parameterizations.

.. code-block:: python

import numpy as np
import pymc as pm

# 4x4 adjacency matrix
# arranged in a square lattice

W = np.array([[0,1,0,1],
[1,0,1,0],
[0,1,0,1],
[1,0,1,0]])

# centered parameterization

with pm.Model():
sigma = pm.Exponential('sigma',1)
phi = pm.ICAR('phi',W=W,sigma=sigma)

mu = phi

# non-centered parameterization

with pm.Model():
sigma = pm.Exponential('sigma',1)
phi = pm.ICAR('phi',W=W)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
W = np.array([[0,1,0,1],
[1,0,1,0],
[0,1,0,1],
[1,0,1,0]])
# centered parameterization
with pm.Model():
sigma = pm.Exponential('sigma',1)
phi = pm.ICAR('phi',W=W,sigma=sigma)
mu = phi
# non-centered parameterization
with pm.Model():
sigma = pm.Exponential('sigma',1)
phi = pm.ICAR('phi',W=W)
W = np.array([
[0, 1, 0, 1],
[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 1, 0]
])
# centered parameterization
with pm.Model():
sigma = pm.Exponential('sigma', 1)
phi = pm.ICAR('phi', W=W, sigma=sigma)
mu = phi
# non-centered parameterization
with pm.Model():
sigma = pm.Exponential('sigma', 1)
phi = pm.ICAR('phi', W=W)
mu = sigma * phi

adjusted some formatting here, mostly putting a space after comma


mu = sigma * phi

References
----------
.. Mitzi, M., Wheeler-Martin, K., Simpson, D., Mooney, J. S.,
Gelman, A., Dimaggio, C.
"Bayesian hierarchical spatial models: Implementing the Besag York
Mollié model in stan"
Spatial and Spatio-temporal Epidemiology, Vol. 31, (Aug., 2019),
pp 1-18
.. Banerjee, S., Carlin, B., Gelfand, A. Hierarchical Modeling
and Analysis for Spatial Data. Second edition. CRC press. (2015)

"""

rv_op = icar

@classmethod
def dist(cls, W, sigma=1, zero_sum_strength=0.001, **kwargs):
# check that adjacency matrix is two dimensional,
# square,
# symmetrical
# and composed of 1s or 0s.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this block of comments doesn't add too much to the code doing the checks below, dont think its necessary


if not W.ndim == 2:
raise ValueError("W must be matrix with ndim=2")

if not W.shape[0] == W.shape[1]:
raise ValueError("W must be a square matrix")

if not np.allclose(W.T, W):
raise ValueError("W must be a symmetric matrix")

if np.any((W != 0) & (W != 1)):

Check warning on line 2344 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2343-L2344

Added lines #L2343 - L2344 were not covered by tests
raise ValueError("W must be composed of only 1s and 0s")

# convert adjacency matrix to edgelist representation

Check warning on line 2347 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2346-L2347

Added lines #L2346 - L2347 were not covered by tests
bwengals marked this conversation as resolved.
Show resolved Hide resolved

node1, node2 = np.where(np.tril(W) == 1)

Check warning on line 2350 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2349-L2350

Added lines #L2349 - L2350 were not covered by tests
node1 = pt.as_tensor_variable(node1, dtype=int)
node2 = pt.as_tensor_variable(node2, dtype=int)

Check warning on line 2353 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2352-L2353

Added lines #L2352 - L2353 were not covered by tests
W = pt.as_tensor_variable(W, dtype=int)

N = pt.shape(W)[0]
N = pt.as_tensor_variable(N)

Check warning on line 2357 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2357

Added line #L2357 was not covered by tests

# check on sigma
Copy link
Contributor

@bwengals bwengals Jul 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment and the comment check on centering strength are out of date now


Check warning on line 2360 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2359-L2360

Added lines #L2359 - L2360 were not covered by tests
sigma = pt.as_tensor_variable(floatX(sigma))

Check warning on line 2362 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2362

Added line #L2362 was not covered by tests
# check on centering_strength

zero_sum_strength = pt.as_tensor_variable(floatX(zero_sum_strength))

Check warning on line 2365 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2364-L2365

Added lines #L2364 - L2365 were not covered by tests

return super().dist([W, node1, node2, N, sigma, zero_sum_strength], **kwargs)

def moment(rv, size, W, node1, node2, N, sigma, zero_sum_strength):

Check warning on line 2369 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2369

Added line #L2369 was not covered by tests
return pt.zeros(N)

def logp(value, W, node1, node2, N, sigma, zero_sum_strength):
pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(

Check warning on line 2373 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2373

Added line #L2373 was not covered by tests
pt.square(value[node1] - value[node2])
)

Check warning on line 2375 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2375

Added line #L2375 was not covered by tests
soft_center = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as before RE wording, "center" vs zero sum. I think the latter is better and is more consistent with zero_sum_strength

-0.5 * pt.pow(pt.sum(value) / (zero_sum_strength * N), 2)
- pt.log(pt.sqrt(2.0 * np.pi))

Check warning on line 2378 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2378

Added line #L2378 was not covered by tests
- pt.log(zero_sum_strength * N)
)

Check warning on line 2381 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2381

Added line #L2381 was not covered by tests
return check_parameters(pairwise_difference + soft_center, sigma > 0, msg="sigma > 0")


Check warning on line 2384 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2384

Added line #L2384 was not covered by tests
class StickBreakingWeightsRV(RandomVariable):
name = "stick_breaking_weights"
ndim_supp = 1
Expand Down
72 changes: 72 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,18 @@ def test_car_moment(self, mu, size, expected):
pm.CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size)
assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
"W, expected",
[
(np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]), np.array([0, 0, 0])),
(np.array([[0, 1], [1, 0]]), np.array([0, 0])),
],
)
def test_icar_moment(self, W, expected):
with pm.Model() as model:
RV = pm.ICAR("x", W=W)
assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
"nu, mu, cov, size, expected",
[
Expand Down Expand Up @@ -2070,6 +2082,66 @@ def check_draws_match_expected(self):
assert np.all(np.abs(draw(x, random_seed=rng) - np.array([0.5, 0, 2.0])) < 0.01)


class TestICAR(BaseTestDistributionRandom):
pymc_dist = pm.ICAR
pymc_dist_params = {"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), "sigma": 2}
expected_rv_op_params = {
"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]),
"node1": np.array([1, 2, 2]),
"node2": np.array([0, 0, 1]),
"N": 3,
"sigma": 2,
"zero_sum_strength": 0.001,
}
checks_to_run = ["check_pymc_params_match_rv_op", "check_rv_inferred_size"]

def check_rv_inferred_size(self):
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
sizes_expected = [(3,), (3,), (1, 3), (1, 3), (5, 3), (4, 5, 3), (2, 4, 2, 3)]
for size, expected in zip(sizes_to_check, sizes_expected):
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
expected_symbolic = tuple(pymc_rv.shape.eval())
assert expected_symbolic == expected

def test_icar_logp(self):
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])

with pm.Model() as m:
RV = pm.ICAR("phi", W=W)

assert pt.isclose(
pm.logp(RV, np.array([0.01, -0.03, 0.02, 0.00])).eval(), np.array(4.60022238)
).eval(), "logp inaccuracy"

def test_icar_rng_fn(self):
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])

RV = pm.ICAR.dist(W=W)

with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
pm.draw(RV)

@pytest.mark.parametrize(
"W,msg",
[
(np.array([0, 1, 0, 0]), "W must be matrix with ndim=2"),
(np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1]]), "W must be a square matrix"),
(
np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]]),
"W must be a symmetric matrix",
),
(
np.array([[0, 1, 1, 0], [1, 0, 0, 0.5], [1, 0, 0, 1], [0, 0.5, 1, 0]]),
"W must be composed of only 1s and 0s",
),
],
)
def test_icar_matrix_checks(self, W, msg):
with pytest.raises(ValueError, match=msg):
with pm.Model():
pm.ICAR("phi", W=W)


@pytest.mark.parametrize("sparse", [True, False])
def test_car_rng_fn(sparse):
delta = 0.05 # limit for KS p-value
Expand Down
Loading