Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX implementation fol MatrixIsPositiveDefinite Op #6853

Merged
merged 5 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,9 +849,9 @@
def posdef(AA):
try:
linalg.cholesky(AA)
return 1
return True

Check warning on line 852 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L852

Added line #L852 was not covered by tests
except linalg.LinAlgError:
return 0
return False

Check warning on line 854 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L854

Added line #L854 was not covered by tests


class PosDefMatrix(Op):
Expand All @@ -868,15 +868,15 @@
def make_node(self, x):
x = pt.as_tensor_variable(x)
assert x.ndim == 2
o = TensorType(dtype="int8", shape=[])()
o = TensorType(dtype="bool", shape=[])()
return Apply(self, [x], [o])

# Python implementation:
def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
try:
z[0] = np.array(posdef(x), dtype="int8")
z[0] = np.array(posdef(x), dtype="bool")

Check warning on line 879 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L879

Added line #L879 was not covered by tests
except Exception:
pm._log.exception("Failed to check if %s positive definite", x)
raise
Expand Down
13 changes: 11 additions & 2 deletions pymc/sampling/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt

Expand All @@ -34,10 +35,10 @@
from pytensor.raise_op import Assert
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.shape import SpecifyShape

from pymc import Model, modelcontext
from pymc.backends.arviz import find_constants, find_observations
from pymc.distributions.multivariate import PosDefMatrix
from pymc.initial_point import StartDict
from pymc.logprob.utils import CheckParameterValue
from pymc.sampling.mcmc import _init_jitter
Expand All @@ -62,7 +63,6 @@

@jax_funcify.register(Assert)
@jax_funcify.register(CheckParameterValue)
@jax_funcify.register(SpecifyShape)
def jax_funcify_Assert(op, **kwargs):
# Jax does not allow assert whose values aren't known during JIT compilation
# within it's JIT-ed code. Hence we need to make a simple pass through
Expand All @@ -74,6 +74,15 @@ def assert_fn(value, *inps):
return assert_fn


@jax_funcify.register(PosDefMatrix)
def jax_funcify_PosDefMatrix(op, **kwargs):
def posdefmatrix_fn(value, *inps):
no_pos_def = jnp.any(jnp.isnan(jnp.linalg.cholesky(value)))
return jnp.invert(no_pos_def)

return posdefmatrix_fn


def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
"""Replace shared variables in graph by their constant values

Expand Down
30 changes: 13 additions & 17 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1907,26 +1907,22 @@ def check_draws(self):
def ref_rand(mu, rowcov, colcov):
return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov)

with pm.Model():
matrixnormal = pm.MatrixNormal(
"matnormal",
mu=np.random.random((3, 3)),
rowcov=np.eye(3),
colcov=np.eye(3),
)
check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1)

ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))
matrixnormal = pm.MatrixNormal.dist(
mu=np.random.random((3, 3)),
rowcov=np.eye(3),
colcov=np.eye(3),
)

p, f = delta, n_fails
while p <= delta and f > 0:
matrixnormal_smp = check["matnormal"]
matrixnormal_smp = pm.draw(matrixnormal)
ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))

p = np.min(
[
st.ks_2samp(
np.atleast_1d(matrixnormal_smp).flatten(),
np.atleast_1d(ref_smp).flatten(),
matrixnormal_smp.flatten(),
ref_smp.flatten(),
)
]
)
Expand Down Expand Up @@ -2134,10 +2130,10 @@ def test_car_rng_fn(sparse):
@pytest.mark.parametrize(
"matrix, result",
[
([[1.0, 0], [0, 1]], 1),
([[1.0, 2], [2, 1]], 0),
([[1.0, 1], [1, 1]], 0),
([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], 0),
([[1.0, 0], [0, 1]], True),
([[1.0, 2], [2, 1]], False),
([[1.0, 1], [1, 1]], False),
([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], False),
],
)
def test_posdef_symmetric(matrix, result):
Expand Down
4 changes: 2 additions & 2 deletions tests/gp/test_hsgp_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first):
assert n_coeffs == n_basis, "one was dropped when it shouldn't have been"

@pytest.mark.parametrize("parameterization", ["centered", "noncentered"])
def test_prior(self, model, cov_func, X1, parameterization):
def test_prior(self, model, cov_func, X1, parameterization, rng):
"""Compare HSGP prior to unapproximated GP prior, pm.gp.Latent. Draw samples from the
prior and compare them using MMD two sample test. Tests both centered and non-centered
parameterizations.
Expand All @@ -178,7 +178,7 @@ def test_prior(self, model, cov_func, X1, parameterization):
gp = pm.gp.Latent(cov_func=cov_func)
f2 = gp.prior("f2", X=X1)

idata = pm.sample_prior_predictive(samples=1000)
idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day we will get this 5 line PR merged. Just you wait :D


samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T
Expand Down
21 changes: 21 additions & 0 deletions tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import pytensor
import pytensor.tensor as pt
Expand All @@ -29,6 +30,7 @@

import pymc as pm

from pymc.distributions.multivariate import PosDefMatrix
from pymc.sampling.jax import (
_get_batched_jittered_initial_points,
_get_log_likelihood,
Expand All @@ -49,6 +51,25 @@ def test_old_import_route():
assert set(new_sj.__all__) <= set(dir(old_sj))


def test_jax_PosDefMatrix():
x = pt.tensor(name="x", shape=(2, 2), dtype="float32")
matrix_pos_def = PosDefMatrix()
x_is_pos_def = matrix_pos_def(x)
f = pytensor.function(inputs=[x], outputs=[x_is_pos_def], mode="JAX")

test_cases = [
(jnp.eye(2), True),
(jnp.zeros(shape=(2, 2)), False),
(jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), True),
(-1 * jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), False),
(jnp.array([[1, -1.5], [0, -1.2]], dtype="float32"), False),
]

for input, expected in test_cases:
actual = f(input)[0]
assert jnp.array_equal(a1=actual, a2=expected)


@pytest.mark.parametrize(
"sampler",
[
Expand Down
Loading