diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 2aa3c889e3..b8380e0608 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -849,9 +849,9 @@ def dist(cls, *args, **kwargs):
 def posdef(AA):
     try:
         linalg.cholesky(AA)
-        return 1
+        return True
     except linalg.LinAlgError:
-        return 0
+        return False
 
 
 class PosDefMatrix(Op):
@@ -868,7 +868,7 @@ class PosDefMatrix(Op):
     def make_node(self, x):
         x = pt.as_tensor_variable(x)
         assert x.ndim == 2
-        o = TensorType(dtype="int8", shape=[])()
+        o = TensorType(dtype="bool", shape=[])()
         return Apply(self, [x], [o])
 
     # Python implementation:
@@ -876,7 +876,7 @@ def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
         try:
-            z[0] = np.array(posdef(x), dtype="int8")
+            z[0] = np.array(posdef(x), dtype="bool")
         except Exception:
             pm._log.exception("Failed to check if %s positive definite", x)
             raise
diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 6de6c9e388..5016b0897d 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -21,6 +21,7 @@
 
 import arviz as az
 import jax
+import jax.numpy as jnp
 import numpy as np
 import pytensor.tensor as pt
 
@@ -34,10 +35,10 @@
 from pytensor.raise_op import Assert
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.type import RandomType
-from pytensor.tensor.shape import SpecifyShape
 
 from pymc import Model, modelcontext
 from pymc.backends.arviz import find_constants, find_observations
+from pymc.distributions.multivariate import PosDefMatrix
 from pymc.initial_point import StartDict
 from pymc.logprob.utils import CheckParameterValue
 from pymc.sampling.mcmc import _init_jitter
@@ -62,7 +63,6 @@
 
 @jax_funcify.register(Assert)
 @jax_funcify.register(CheckParameterValue)
-@jax_funcify.register(SpecifyShape)
 def jax_funcify_Assert(op, **kwargs):
     # JAX does not allow asserts whose values aren't known during JIT compilation
     # within its JIT-ed code. Hence we need to make a simple pass-through
@@ -74,6 +74,15 @@ def assert_fn(value, *inps):
     return assert_fn
 
 
+@jax_funcify.register(PosDefMatrix)
+def jax_funcify_PosDefMatrix(op, **kwargs):
+    def posdefmatrix_fn(value, *inps):
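+        # jnp.linalg.cholesky does not raise on non-positive-definite input (unlike
+        # posdef in pymc.distributions.multivariate, which catches linalg.LinAlgError);
+        # it returns NaNs instead, so detect those and negate the result.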
+        no_pos_def = jnp.any(jnp.isnan(jnp.linalg.cholesky(value)))
+        return jnp.invert(no_pos_def)
+
+    return posdefmatrix_fn
+
+
 def _replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
     """Replace shared variables in graph by their constant values
 
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index 752f250291..174b4a7c79 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -1907,26 +1907,22 @@ def check_draws(self):
         def ref_rand(mu, rowcov, colcov):
             return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov)
 
-        with pm.Model():
-            matrixnormal = pm.MatrixNormal(
-                "matnormal",
-                mu=np.random.random((3, 3)),
-                rowcov=np.eye(3),
-                colcov=np.eye(3),
-            )
-            check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1)
-
-        ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))
+        matrixnormal = pm.MatrixNormal.dist(
+            mu=np.random.random((3, 3)),
+            rowcov=np.eye(3),
+            colcov=np.eye(3),
+        )
 
         p, f = delta, n_fails
         while p <= delta and f > 0:
-            matrixnormal_smp = check["matnormal"]
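+            # Redraw both the PyMC and the reference samples on every retry.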
+            matrixnormal_smp = pm.draw(matrixnormal)
+            ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))
 
             p = np.min(
                 [
                     st.ks_2samp(
-                        np.atleast_1d(matrixnormal_smp).flatten(),
-                        np.atleast_1d(ref_smp).flatten(),
+                        matrixnormal_smp.flatten(),
+                        ref_smp.flatten(),
                     )
                 ]
             )
@@ -2134,10 +2130,10 @@ def test_car_rng_fn(sparse):
 @pytest.mark.parametrize(
     "matrix, result",
     [
-        ([[1.0, 0], [0, 1]], 1),
-        ([[1.0, 2], [2, 1]], 0),
-        ([[1.0, 1], [1, 1]], 0),
-        ([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], 0),
+        ([[1.0, 0], [0, 1]], True),
+        ([[1.0, 2], [2, 1]], False),
+        ([[1.0, 1], [1, 1]], False),
+        ([[1, 0.99, 1], [0.99, 1, 0.999], [1, 0.999, 1]], False),
     ],
 )
 def test_posdef_symmetric(matrix, result):
diff --git a/tests/gp/test_hsgp_approx.py b/tests/gp/test_hsgp_approx.py
index b6f03a4acc..96465e9437 100644
--- a/tests/gp/test_hsgp_approx.py
+++ b/tests/gp/test_hsgp_approx.py
@@ -166,7 +166,7 @@ def test_parametrization_drop_first(self, model, cov_func, X1, drop_first):
                 assert n_coeffs == n_basis, "one was dropped when it shouldn't have been"
 
     @pytest.mark.parametrize("parameterization", ["centered", "noncentered"])
-    def test_prior(self, model, cov_func, X1, parameterization):
+    def test_prior(self, model, cov_func, X1, parameterization, rng):
         """Compare HSGP prior to unapproximated GP prior, pm.gp.Latent.  Draw samples from the
         prior and compare them using MMD two sample test.  Tests both centered and non-centered
         parameterizations.
@@ -178,7 +178,7 @@ def test_prior(self, model, cov_func, X1, parameterization):
             gp = pm.gp.Latent(cov_func=cov_func)
             f2 = gp.prior("f2", X=X1)
 
-            idata = pm.sample_prior_predictive(samples=1000)
+            idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
 
         samples1 = az.extract(idata.prior["f1"])["f1"].values.T
         samples2 = az.extract(idata.prior["f2"])["f2"].values.T
diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index 0a244877c2..bba5fb2925 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -18,6 +18,7 @@
 
 import arviz as az
 import jax
+import jax.numpy as jnp
 import numpy as np
 import pytensor
 import pytensor.tensor as pt
@@ -29,6 +30,7 @@
 
 import pymc as pm
 
+from pymc.distributions.multivariate import PosDefMatrix
 from pymc.sampling.jax import (
     _get_batched_jittered_initial_points,
     _get_log_likelihood,
@@ -49,6 +51,25 @@ def test_old_import_route():
     assert set(new_sj.__all__) <= set(dir(old_sj))
 
 
+def test_jax_PosDefMatrix():
+    x = pt.tensor(name="x", shape=(2, 2), dtype="float32")
+    matrix_pos_def = PosDefMatrix()
+    x_is_pos_def = matrix_pos_def(x)
+    f = pytensor.function(inputs=[x], outputs=[x_is_pos_def], mode="JAX")
+
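+    # Pairs of (input matrix, expected positive-definiteness as reported by the JAX backend).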
+    test_cases = [
+        (jnp.eye(2), True),
+        (jnp.zeros(shape=(2, 2)), False),
+        (jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), True),
+        (-1 * jnp.array([[1, -1.5], [0, 1.2]], dtype="float32"), False),
+        (jnp.array([[1, -1.5], [0, -1.2]], dtype="float32"), False),
+    ]
+
+    for input, expected in test_cases:
+        actual = f(input)[0]
+        assert jnp.array_equal(a1=actual, a2=expected)
+
+
 @pytest.mark.parametrize(
     "sampler",
     [