Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ICARRV and ICAR #6831

Merged
merged 12 commits into from
Aug 18, 2023
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
)
from pymc.distributions.multivariate import (
CAR,
ICAR,
Dirichlet,
DirichletMultinomial,
KroneckerNormal,
Expand Down Expand Up @@ -198,6 +199,7 @@
"Truncated",
"Censored",
"CAR",
"ICAR",
"PolyaGamma",
"HurdleGamma",
"HurdleLogNormal",
Expand Down
167 changes: 167 additions & 0 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"MatrixNormal",
"KroneckerNormal",
"CAR",
"ICAR",
"StickBreakingWeights",
]

Expand Down Expand Up @@ -2256,6 +2257,172 @@
)


class ICARRV(RandomVariable):
    """RandomVariable Op backing the :class:`ICAR` distribution.

    The Op takes six inputs, in this order: ``W``, ``node1``, ``node2``,
    ``N``, ``sigma`` and ``zero_sum_stdev``. Forward sampling is not
    available; ``rng_fn`` always raises ``NotImplementedError``.
    """

    name = "icar"
    _print_name = ("ICAR", "\\operatorname{ICAR}")
    dtype = "floatX"
    # One entry per input: W is a matrix, node1/node2 are vectors, the
    # remaining three inputs are scalars. The support itself is a vector.
    ndims_params = [2, 1, 1, 0, 0, 0]
    ndim_supp = 1

    def __call__(self, W, node1, node2, N, sigma, zero_sum_stdev, size=None, **kwargs):
        """Build the random variable, spelling out the expected inputs by name."""
        dist_inputs = (W, node1, node2, N, sigma, zero_sum_stdev)
        return super().__call__(*dist_inputs, size=size, **kwargs)

    def _supp_shape_from_params(self, dist_params, param_shapes=None):
        """Derive the support shape, using parameter 0 (``W``) as the reference."""
        reference_idx = 0
        return supp_shape_from_ref_param_shape(
            ndim_supp=self.ndim_supp,
            dist_params=dist_params,
            param_shapes=param_shapes,
            ref_param_idx=reference_idx,
        )

    @classmethod
    def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_stdev):
        """Always raise: drawing samples from this prior is not implemented."""
        raise NotImplementedError("Cannot sample from ICAR prior")

Check warning on line 2280 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2280

Added line #L2280 was not covered by tests


# Module-level instance of the ICARRV Op; the ICAR class uses it as its ``rv_op``.
icar = ICARRV()


class ICAR(Continuous):
    r"""
    The intrinsic conditional autoregressive prior. It is primarily used to model
    covariance between neighboring areas. It is a special case
    of the :class:`~pymc.CAR` distribution where alpha is set to 1.

    The log probability density function is

    .. math::
        f(\phi| W,\sigma) =
          -\frac{1}{2\sigma^{2}} \sum_{i\sim j} (\phi_{i} - \phi_{j})^2 -
          \frac{1}{2}\left(\frac{\sum_{i}{\phi_{i}}}{\epsilon N}\right)^{2} -
          \ln{\sqrt{2\pi}} - \ln{(\epsilon N)}

    where :math:`\epsilon` denotes ``zero_sum_stdev`` (0.001 by default) and
    :math:`N` is the length of the vector :math:`\phi`.

    The first term represents the spatial covariance component. Each :math:`\phi_{i}`
    is penalized based on the square distance from each of its neighbors. The notation
    :math:`i\sim j` indicates a sum over all pairs of neighboring elements. The last
    three terms are the Normal log density function where the mean is zero and the
    standard deviation is :math:`\epsilon N`. This component imposes a zero-sum
    constraint by finding the sum of the vector :math:`\phi` and penalizing based
    on its distance from zero.

    Parameters
    ----------
    W : array-like of int
        Symmetric adjacency matrix of 1s and 0s indicating adjacency between elements.

    sigma : scalar, default 1
        Standard deviation of the vector of phi's. Putting a prior on sigma
        will result in a centered parameterization. In most cases, it is
        preferable to use a non-centered parameterization by using the default
        value and multiplying the resulting phi's by sigma. See the example below.

    zero_sum_stdev : scalar, default 0.001
        Controls how strongly to enforce the zero-sum constraint. The sum of
        phi is normally distributed with a mean of zero and small standard deviation.
        This parameter is multiplied by the number of elements to give the standard
        deviation of that normal density.

    Examples
    --------
    This example illustrates how to switch between centered and non-centered
    parameterizations.

    .. code-block:: python

        import numpy as np
        import pymc as pm

        # 4x4 adjacency matrix
        # arranged in a square lattice

        W = np.array([
            [0,1,0,1],
            [1,0,1,0],
            [0,1,0,1],
            [1,0,1,0]
        ])

        # centered parameterization
        with pm.Model():
            sigma = pm.Exponential('sigma', 1)
            phi = pm.ICAR('phi', W=W, sigma=sigma)
            mu = phi

        # non-centered parameterization
        with pm.Model():
            sigma = pm.Exponential('sigma', 1)
            phi = pm.ICAR('phi', W=W)
            mu = sigma * phi

    References
    ----------
    .. [1] Morris, M., Wheeler-Martin, K., Simpson, D., Mooney, S. J.,
       Gelman, A., DiMaggio, C.
       "Bayesian hierarchical spatial models: Implementing the Besag York
       Mollié model in stan"
       Spatial and Spatio-temporal Epidemiology, Vol. 31, (Aug., 2019),
       pp 1-18
    .. [2] Banerjee, S., Carlin, B., Gelfand, A. Hierarchical Modeling
       and Analysis for Spatial Data. Second edition. CRC press. (2015)

    """

    rv_op = icar

    @classmethod
    def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
        """Validate ``W`` and build the ICAR random variable.

        Raises
        ------
        ValueError
            If ``W`` is not two-dimensional, not square, not symmetric, or
            contains entries other than 0 and 1.
        """
        # Coerce first so nested lists are accepted and reach the checks below
        # instead of failing with AttributeError on ``.ndim``.
        W = np.asarray(W)

        if W.ndim != 2:
            raise ValueError("W must be matrix with ndim=2")

        if W.shape[0] != W.shape[1]:
            raise ValueError("W must be a square matrix")

        if not np.allclose(W.T, W):
            raise ValueError("W must be a symmetric matrix")

        if np.any((W != 0) & (W != 1)):
            raise ValueError("W must be composed of only 1s and 0s")

        # Convert the adjacency matrix to an edgelist representation.
        # An edgelist is a pair of lists: if node i and node j are connected
        # then one list contains i and the other contains j at the same index.
        # Only the lower triangle is used because adjacency is an undirected
        # connection, so each edge must be counted once.
        node1, node2 = np.where(np.tril(W) == 1)

        node1 = pt.as_tensor_variable(node1, dtype=int)
        node2 = pt.as_tensor_variable(node2, dtype=int)

        W = pt.as_tensor_variable(W, dtype=int)

        # Already symbolic, so no further conversion is needed.
        N = pt.shape(W)[0]

        sigma = pt.as_tensor_variable(floatX(sigma))
        zero_sum_stdev = pt.as_tensor_variable(floatX(zero_sum_stdev))

        return super().dist([W, node1, node2, N, sigma, zero_sum_stdev], **kwargs)

    def moment(rv, size, W, node1, node2, N, sigma, zero_sum_stdev):
        """Return a length-``N`` vector of zeros as the initial point."""
        # NOTE(review): ``size`` is ignored, so for a batched ICAR the returned
        # shape will not include the batch dimensions -- confirm and broadcast
        # to ``size`` if batched use is meant to be supported.
        return pt.zeros(N)

    def logp(value, W, node1, node2, N, sigma, zero_sum_stdev):
        """Log-density: pairwise-difference penalty plus soft zero-sum constraint."""
        # Index and reduce along the last (support) axis so that values with
        # leading batch dimensions are handled; for a 1-d value this is
        # identical to ``value[node1]`` and a full sum.
        edge_diffs = value[..., node1] - value[..., node2]
        pairwise_difference = (-1 / (2 * sigma**2)) * pt.sum(pt.square(edge_diffs), axis=-1)

        # Normal(0, zero_sum_stdev * N) log-density evaluated at sum(value).
        zero_sum_sd = zero_sum_stdev * N
        zero_sum = (
            -0.5 * pt.pow(pt.sum(value, axis=-1) / zero_sum_sd, 2)
            - pt.log(pt.sqrt(2.0 * np.pi))
            - pt.log(zero_sum_sd)
        )

        return check_parameters(
            pairwise_difference + zero_sum,
            sigma > 0,
            zero_sum_stdev > 0,
            msg="sigma > 0, zero_sum_stdev > 0",
        )

Check warning on line 2423 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2423

Added line #L2423 was not covered by tests


class StickBreakingWeightsRV(RandomVariable):
name = "stick_breaking_weights"
ndim_supp = 1
Expand Down
72 changes: 72 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,18 @@ def test_car_moment(self, mu, size, expected):
pm.CAR("x", mu=mu, W=W, alpha=alpha, tau=tau, size=size)
assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
    "W, expected",
    [
        (np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]]), np.array([0, 0, 0])),
        (np.array([[0, 1], [1, 0]]), np.array([0, 0])),
    ],
)
def test_icar_moment(self, W, expected):
    """The ICAR moment is a zero vector with one entry per row of ``W``."""
    with pm.Model() as model:
        # Registered on the model by name; no local handle is needed.
        pm.ICAR("x", W=W)
    assert_moment_is_expected(model, expected)

@pytest.mark.parametrize(
"nu, mu, cov, size, expected",
[
Expand Down Expand Up @@ -2091,6 +2103,66 @@ def check_draws_match_expected(self):
assert np.all(np.abs(draw(x, random_seed=rng) - np.array([0.5, 0, 2.0])) < 0.01)


class TestICAR(BaseTestDistributionRandom):
    """Parameter, shape, logp, sampling and validation checks for ``pm.ICAR``."""

    pymc_dist = pm.ICAR
    pymc_dist_params = {"W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]), "sigma": 2}
    # Keys follow the Op's input names, in input order.
    expected_rv_op_params = {
        "W": np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]),
        "node1": np.array([1, 2, 2]),
        "node2": np.array([0, 0, 1]),
        "N": 3,
        "sigma": 2,
        "zero_sum_stdev": 0.001,
    }
    checks_to_run = ["check_pymc_params_match_rv_op", "check_rv_inferred_size"]

    def check_rv_inferred_size(self):
        """The inferred shape is ``size`` followed by the support length (3)."""
        sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
        sizes_expected = [(3,), (3,), (1, 3), (1, 3), (5, 3), (4, 5, 3), (2, 4, 2, 3)]
        for size, expected in zip(sizes_to_check, sizes_expected):
            pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
            actual_shape = tuple(pymc_rv.shape.eval())
            assert actual_shape == expected

    def test_icar_logp(self):
        """logp matches a reference value on a 4-node square lattice."""
        W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])

        with pm.Model():
            RV = pm.ICAR("phi", W=W)

        logp_value = pm.logp(RV, np.array([0.01, -0.03, 0.02, 0.00])).eval()
        # Same tolerances as the pt.isclose defaults, but a failure now
        # reports the actual and expected values.
        np.testing.assert_allclose(
            logp_value, 4.60022238, rtol=1e-5, atol=1e-8, err_msg="logp inaccuracy"
        )

    def test_icar_rng_fn(self):
        """Drawing from ICAR raises because forward sampling is not implemented."""
        W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])

        RV = pm.ICAR.dist(W=W)

        with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
            pm.draw(RV)

    @pytest.mark.parametrize(
        "W,msg",
        [
            (np.array([0, 1, 0, 0]), "W must be matrix with ndim=2"),
            (np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1]]), "W must be a square matrix"),
            (
                np.array([[0, 1, 0, 0], [1, 0, 0, 1], [1, 0, 0, 1], [0, 1, 1, 0]]),
                "W must be a symmetric matrix",
            ),
            (
                np.array([[0, 1, 1, 0], [1, 0, 0, 0.5], [1, 0, 0, 1], [0, 0.5, 1, 0]]),
                "W must be composed of only 1s and 0s",
            ),
        ],
    )
    def test_icar_matrix_checks(self, W, msg):
        """Each malformed adjacency matrix is rejected with its specific message."""
        with pytest.raises(ValueError, match=msg):
            with pm.Model():
                pm.ICAR("phi", W=W)


@pytest.mark.parametrize("sparse", [True, False])
def test_car_rng_fn(sparse):
delta = 0.05 # limit for KS p-value
Expand Down
Loading