diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4cf8986f38..d72431194e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -413,7 +413,7 @@ jobs: floatx: [float32] python-version: ["3.11"] test-subset: - - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py + - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index cbd613abdc..0b8d494276 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -958,8 +958,10 @@ class SimplexTransform(RVTransform): name = "simplex" def forward(self, value, *inputs): + value = pt.as_tensor(value) log_value = pt.log(value) - shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1] + N = value.shape[-1].astype(value.dtype) + shift = pt.sum(log_value, -1, keepdims=True) / N return log_value[..., :-1] - shift def backward(self, value, *inputs): @@ -968,7 +970,9 @@ def backward(self, value, *inputs): return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True) def log_jac_det(self, value, *inputs): + value = pt.as_tensor(value) N = value.shape[-1] + 1 + N = N.astype(value.dtype) sum_value = pt.sum(value, -1, keepdims=True) value_sum_expanded = value + sum_value value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1) diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index a29bb84ddf..f0979938e3 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -44,10 +44,10 @@ # some transforms (stick breaking) require addition of small slack in order to be numerically # stable. The minimal addable slack for float32 is higher thus we need to be less strict -tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-6 +tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-5 -def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=None): +def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=None): x = constructor("x") x.tag.test_value = test if rv_var is None: @@ -57,18 +57,20 @@ def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=No # FIXME: What's being tested here? That the transformed graph can compile? forward_f = pytensor.function([x], transform.forward(x, *rv_inputs)) # test transform identity - identity_f = pytensor.function( - [x], transform.backward(transform.forward(x, *rv_inputs), *rv_inputs) - ) + z = transform.backward(transform.forward(x, *rv_inputs)) + assert z.type == x.type + identity_f = pytensor.function([x], z, *rv_inputs) for val in domain.vals: close_to(val, identity_f(val), tol) def check_vector_transform(transform, domain, rv_var=None): - return check_transform(transform, domain, pt.dvector, test=np.array([0, 0]), rv_var=rv_var) + return check_transform( + transform, domain, pt.vector, test=floatX(np.array([0, 0])), rv_var=rv_var + ) -def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None): +def get_values(transform, domain=R, constructor=pt.scalar, test=0, rv_var=None): x = constructor("x") x.tag.test_value = test if rv_var is None: @@ -81,7 +83,7 @@ def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None) def check_jacobian_det( transform, domain, - constructor=pt.dscalar, + constructor=pt.scalar, test=0, make_comparable=None, elemwise=False, @@ -119,22 +121,26 @@ def test_simplex(): check_vector_transform(tr.simplex, Simplex(2)) check_vector_transform(tr.simplex, Simplex(4)) - check_transform(tr.simplex, MultiSimplex(3, 2), constructor=pt.dmatrix, test=np.zeros((2, 2))) + check_transform( + tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2))) + ) def test_simplex_bounds(): - vals = get_values(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0])) + vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0]))) close_to(vals.sum(axis=1), 1, tol) close_to_logical(vals > 0, True, tol) close_to_logical(vals < 1, True, tol) - check_jacobian_det(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]) + check_jacobian_det( + tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1] + ) def test_simplex_accuracy(): - val = np.array([-30]) - x = pt.dvector("x") + val = floatX(np.array([-30])) + x = pt.vector("x") x.tag.test_value = val identity_f = pytensor.function([x], tr.simplex.forward(x, tr.simplex.backward(x, x))) close_to(val, identity_f(val), tol) @@ -148,10 +154,18 @@ def test_sum_to_1(): tr.SumTo1(2) check_jacobian_det( - tr.univariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1] + tr.univariate_sum_to_1, + Vector(Unit, 2), + pt.vector, + floatX(np.array([0, 0])), + lambda x: x[:-1], ) check_jacobian_det( - tr.multivariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1] + tr.multivariate_sum_to_1, + Vector(Unit, 2), + pt.vector, + floatX(np.array([0, 0])), + lambda x: x[:-1], ) @@ -159,17 +173,20 @@ def test_log(): check_transform(tr.log, Rplusbig) check_jacobian_det(tr.log, Rplusbig, elemwise=True) - check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True) + check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True) vals = get_values(tr.log) close_to_logical(vals > 0, True, tol) +@pytest.mark.skipif( + pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision" +) def test_log_exp_m1(): check_transform(tr.log_exp_m1, Rplusbig) check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True) - check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True) + check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True) vals = get_values(tr.log_exp_m1) close_to_logical(vals > 0, True, tol) @@ -179,7 +196,7 @@ def test_logodds(): check_transform(tr.logodds, Unit) check_jacobian_det(tr.logodds, Unit, elemwise=True) - check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.dvector, [0.5, 0.5], elemwise=True) + check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.vector, [0.5, 0.5], elemwise=True) vals = get_values(tr.logodds) close_to_logical(vals > 0, True, tol) @@ -191,7 +208,7 @@ def test_lowerbound(): check_transform(trans, Rplusbig) check_jacobian_det(trans, Rplusbig, elemwise=True) - check_jacobian_det(trans, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True) + check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True) vals = get_values(trans) close_to_logical(vals > 0, True, tol) @@ -202,7 +219,7 @@ def test_upperbound(): check_transform(trans, Rminusbig) check_jacobian_det(trans, Rminusbig, elemwise=True) - check_jacobian_det(trans, Vector(Rminusbig, 2), pt.dvector, [-1, -1], elemwise=True) + check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True) vals = get_values(trans) close_to_logical(vals < 0, True, tol) @@ -234,7 +251,7 @@ def test_interval_near_boundary(): pm.Uniform("x", initval=x0, lower=lb, upper=ub) log_prob = model.point_logps() - np.testing.assert_allclose(list(log_prob.values()), np.array([-52.68])) + np.testing.assert_allclose(list(log_prob.values()), floatX(np.array([-52.68]))) def test_circular(): @@ -257,19 +274,19 @@ def test_ordered(): tr.Ordered(2) check_jacobian_det( - tr.univariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False + tr.univariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False ) check_jacobian_det( - tr.multivariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False + tr.multivariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False ) - vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.dvector, np.zeros(3)) + vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3))) close_to_logical(np.diff(vals) >= 0, True, tol) def test_chain_values(): chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) - vals = get_values(chain_tranf, Vector(R, 5), pt.dvector, np.zeros(5)) + vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5))) close_to_logical(np.diff(vals) >= 0, True, tol) @@ -281,7 +298,7 @@ def test_chain_vector_transform(): @pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.") def test_chain_jacob_det(): chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered]) - check_jacobian_det(chain_tranf, Vector(R, 4), pt.dvector, np.zeros(4), elemwise=False) + check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False) class TestElementWiseLogp(SeededTest): @@ -432,7 +449,7 @@ def transform_params(*inputs): [ (0.0, 1.0, 2.0, 2), (-10, 0, 200, (2, 3)), - (np.zeros(3), np.ones(3), np.ones(3), (4, 3)), + (floatX(np.zeros(3)), floatX(np.ones(3)), floatX(np.ones(3)), (4, 3)), ], ) def test_triangular(self, lower, c, upper, size): @@ -449,7 +466,8 @@ def transform_params(*inputs): self.check_transform_elementwise_logp(model) @pytest.mark.parametrize( - "mu,kappa,size", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))] + "mu,kappa,size", + [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))], ) def test_vonmises(self, mu, kappa, size): model = self.build_model( @@ -549,7 +567,9 @@ def transform_params(*inputs): ) self.check_vectortransform_elementwise_logp(model) - @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))]) + @pytest.mark.parametrize( + "mu,kappa,size", [(0.0, 1.0, (2,)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))] + ) def test_vonmises_ordered(self, mu, kappa, size): initval = np.sort(np.abs(np.random.rand(*size))) model = self.build_model( @@ -566,7 +586,12 @@ def test_vonmises_ordered(self, mu, kappa, size): [ (0.0, 1.0, (2,), tr.simplex), (0.5, 5.5, (2, 3), tr.simplex), - (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])), + ( + floatX(np.zeros(3)), + floatX(np.ones(3)), + (4, 3), + tr.Chain([tr.univariate_sum_to_1, tr.logodds]), + ), ], ) def test_uniform_other(self, lower, upper, size, transform): @@ -583,8 +608,8 @@ def test_uniform_other(self, lower, upper, size, transform): @pytest.mark.parametrize( "mu,cov,size,shape", [ - (np.zeros(2), np.diag(np.ones(2)), None, (2,)), - (np.zeros(3), np.diag(np.ones(3)), (4,), (4, 3)), + (floatX(np.zeros(2)), floatX(np.diag(np.ones(2))), None, (2,)), + (floatX(np.zeros(3)), floatX(np.diag(np.ones(3))), (4,), (4, 3)), ], ) def test_mvnormal_ordered(self, mu, cov, size, shape): @@ -643,7 +668,7 @@ def test_2d_univariate_ordered(): ) log_p = model.compile_logp(sum=False)( - {"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))} + {"x_1d_ordered__": floatX(np.zeros((4,))), "x_2d_ordered__": floatX(np.zeros((10, 4)))} ) np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1]) @@ -667,7 +692,7 @@ def test_2d_multivariate_ordered(): ) log_p = model.compile_logp(sum=False)( - {"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))} + {"x_1d_ordered__": floatX(np.zeros((2,))), "x_2d_ordered__": floatX(np.zeros((2, 2)))} ) np.testing.assert_allclose(log_p[0], log_p[1]) @@ -690,7 +715,7 @@ def test_2d_univariate_sum_to_1(): ) log_p = model.compile_logp(sum=False)( - {"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))} + {"x_1d_sumto1__": floatX(np.zeros(3)), "x_2d_sumto1__": floatX(np.zeros((10, 3)))} ) np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1]) @@ -712,6 +737,6 @@ def test_2d_multivariate_sum_to_1(): ) log_p = model.compile_logp(sum=False)( - {"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))} + {"x_1d_sumto1__": floatX(np.zeros(1)), "x_2d_sumto1__": floatX(np.zeros((2, 1)))} ) np.testing.assert_allclose(log_p[0], log_p[1])