diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 2aa3c889e3..b8380e0608 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -849,9 +849,9 @@ def dist(cls, *args, **kwargs): def posdef(AA): try: linalg.cholesky(AA) - return 1 + return True except linalg.LinAlgError: - return 0 + return False class PosDefMatrix(Op): @@ -868,7 +868,7 @@ class PosDefMatrix(Op): def make_node(self, x): x = pt.as_tensor_variable(x) assert x.ndim == 2 - o = TensorType(dtype="int8", shape=[])() + o = TensorType(dtype="bool", shape=[])() return Apply(self, [x], [o]) # Python implementation: @@ -876,7 +876,7 @@ def perform(self, node, inputs, outputs): (x,) = inputs (z,) = outputs try: - z[0] = np.array(posdef(x), dtype="int8") + z[0] = np.array(posdef(x), dtype="bool") except Exception: pm._log.exception("Failed to check if %s positive definite", x) raise diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 6de6c9e388..5016b0897d 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -21,6 +21,7 @@ import arviz as az import jax +import jax.numpy as jnp import numpy as np import pytensor.tensor as pt @@ -34,10 +35,10 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.random.type import RandomType -from pytensor.tensor.shape import SpecifyShape from pymc import Model, modelcontext from pymc.backends.arviz import find_constants, find_observations +from pymc.distributions.multivariate import PosDefMatrix from pymc.initial_point import StartDict from pymc.logprob.utils import CheckParameterValue from pymc.sampling.mcmc import _init_jitter @@ -62,7 +63,6 @@ @jax_funcify.register(Assert) @jax_funcify.register(CheckParameterValue) -@jax_funcify.register(SpecifyShape) def jax_funcify_Assert(op, **kwargs): # Jax does not allow assert whose values aren't known during JIT compilation # within it's JIT-ed code. Hence we need to make a simple pass through @@ -74,6 +74,15 @@ def assert_fn(value, *inps): return assert_fn +@jax_funcify.register(PosDefMatrix) +def jax_funcify_PosDefMatrix(op, **kwargs): + def posdefmatrix_fn(value, *inps): + no_pos_def = jnp.any(jnp.isnan(jnp.linalg.cholesky(value))) + return jnp.invert(no_pos_def) + + return posdefmatrix_fn + + def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]: """Replace shared variables in graph by their constant values diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 752f250291..174b4a7c79 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1907,26 +1907,22 @@ def check_draws(self): def ref_rand(mu, rowcov, colcov): return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov) - with pm.Model(): - matrixnormal = pm.MatrixNormal( - "matnormal", - mu=np.random.random((3, 3)), - rowcov=np.eye(3), - colcov=np.eye(3), - ) - check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1) - - ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3)) + matrixnormal = pm.MatrixNormal.dist( + mu=np.random.random((3, 3)), + rowcov=np.eye(3), + colcov=np.eye(3), + ) p, f = delta, n_fails while p <= delta and f > 0: - matrixnormal_smp = check["matnormal"] + matrixnormal_smp = pm.draw(matrixnormal) + ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3)) p = np.min( [ st.ks_2samp( - np.atleast_1d(matrixnormal_smp).flatten(), - np.atleast_1d(ref_smp).flatten(), + matrixnormal_smp.flatten(), + ref_smp.flatten(), ) ] ) @@ -2134,10 +2130,10 @@ def test_car_rng_fn(sparse): @pytest.mark.parametrize( "matrix, result", [ - ([[1.0, 0], [0, 1]], 1), - ([[1.0, 2], [2, 1]], 0), - ([[1.0, 1], [1, 1]], 0), - ([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], 0), + ([[1.0, 0], [0, 1]], True), + ([[1.0, 2], [2, 1]], False), + ([[1.0, 1], [1, 1]], False), + ([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], False), ], ) def test_posdef_symmetric(matrix, result): diff --git a/tests/gp/test_hsgp_approx.py b/tests/gp/test_hsgp_approx.py index b6f03a4acc..96465e9437 100644 --- a/tests/gp/test_hsgp_approx.py +++ b/tests/gp/test_hsgp_approx.py @@ -166,7 +166,7 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first): assert n_coeffs == n_basis, "one was dropped when it shouldn't have been" @pytest.mark.parametrize("parameterization", ["centered", "noncentered"]) - def test_prior(self, model, cov_func, X1, parameterization): + def test_prior(self, model, cov_func, X1, parameterization, rng): """Compare HSGP prior to unapproximated GP prior, pm.gp.Latent. Draw samples from the prior and compare them using MMD two sample test. Tests both centered and non-centered parameterizations. @@ -178,7 +178,7 @@ def test_prior(self, model, cov_func, X1, parameterization): gp = pm.gp.Latent(cov_func=cov_func) f2 = gp.prior("f2", X=X1) - idata = pm.sample_prior_predictive(samples=1000) + idata = pm.sample_prior_predictive(samples=1000, random_seed=rng) samples1 = az.extract(idata.prior["f1"])["f1"].values.T samples2 = az.extract(idata.prior["f2"])["f2"].values.T diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 0a244877c2..bba5fb2925 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -18,6 +18,7 @@ import arviz as az import jax +import jax.numpy as jnp import numpy as np import pytensor import pytensor.tensor as pt @@ -29,6 +30,7 @@ import pymc as pm +from pymc.distributions.multivariate import PosDefMatrix from pymc.sampling.jax import ( _get_batched_jittered_initial_points, _get_log_likelihood, @@ -49,6 +51,25 @@ def test_old_import_route(): assert set(new_sj.__all__) <= set(dir(old_sj)) +def test_jax_PosDefMatrix(): + x = pt.tensor(name="x", shape=(2, 2), dtype="float32") + matrix_pos_def = PosDefMatrix() + x_is_pos_def = matrix_pos_def(x) + f = pytensor.function(inputs=[x], outputs=[x_is_pos_def], mode="JAX") + + test_cases = [ + (jnp.eye(2), True), + (jnp.zeros(shape=(2, 2)), False), + (jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), True), + (-1 * jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), False), + (jnp.array([[1, -1.5], [0, -1.2]], dtype="float32"), False), + ] + + for input, expected in test_cases: + actual = f(input)[0] + assert jnp.array_equal(a1=actual, a2=expected) + + @pytest.mark.parametrize( "sampler", [