Commit 14e673f

Better coverage for float32 tests (#6780)
* create a failing test

* fix the bug

* simplify

* add float32 test to transforms
ferrine authored Jun 22, 2023
1 parent f91dd1c commit 14e673f
Showing 3 changed files with 67 additions and 38 deletions.
.github/workflows/tests.yml (2 changes: 1 addition & 1 deletion)
@@ -413,7 +413,7 @@ jobs:
        floatx: [float32]
        python-version: ["3.11"]
        test-subset:
-          - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py
+          - tests/sampling/test_mcmc.py tests/ode/test_ode.py tests/ode/test_utils.py tests/distributions/test_transform.py
      fail-fast: false
    runs-on: ${{ matrix.os }}
    env:
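The float32 job now runs the transform tests as well. For context, a minimal sketch, assuming the usual PyTensor CI setup, of how a floatx matrix value reaches the library; the truncated env: block above is where the real wiring lives, and nothing below is taken from it.

import pytensor

# Hypothetical illustration (not the contents of the elided `env:` block):
# CI typically exports PYTENSOR_FLAGS before invoking pytest, e.g.
#   PYTENSOR_FLAGS="floatX=float32" pytest tests/distributions/test_transform.py
# after which the test module sees the looser float32 tolerance below.
print(pytensor.config.floatX)  # "float64" by default; "float32" in this job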
pymc/logprob/transforms.py (6 changes: 5 additions & 1 deletion)
@@ -958,8 +958,10 @@ class SimplexTransform(RVTransform):
name = "simplex"

def forward(self, value, *inputs):
value = pt.as_tensor(value)
log_value = pt.log(value)
shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
N = value.shape[-1].astype(value.dtype)
shift = pt.sum(log_value, -1, keepdims=True) / N
return log_value[..., :-1] - shift

def backward(self, value, *inputs):
@@ -968,7 +970,9 @@ def backward(self, value, *inputs):
        return exp_value_max / pt.sum(exp_value_max, -1, keepdims=True)

    def log_jac_det(self, value, *inputs):
+        value = pt.as_tensor(value)
        N = value.shape[-1] + 1
+        N = N.astype(value.dtype)
        sum_value = pt.sum(value, -1, keepdims=True)
        value_sum_expanded = value + sum_value
        value_sum_expanded = pt.concatenate([value_sum_expanded, pt.zeros(sum_value.shape)], -1)
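Why the casts matter: under float32, dividing the float32 log-sum by the raw int64 shape entry upcast the result to float64, so SimplexTransform returned a tensor whose dtype no longer matched its input. A minimal sketch of the promotion, my illustration rather than code from the commit:

import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
log_x = pt.log(x)

# Old behaviour: float32 / int64 promotes to float64.
bad_shift = pt.sum(log_x, -1, keepdims=True) / x.shape[-1]
print(bad_shift.dtype)  # float64

# Fixed behaviour: cast the length to the value's dtype first.
n = x.shape[-1].astype(x.dtype)
good_shift = pt.sum(log_x, -1, keepdims=True) / n
print(good_shift.dtype)  # float32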
tests/distributions/test_transform.py (97 changes: 61 additions & 36 deletions)
@@ -44,10 +44,10 @@

# some transforms (stick breaking) require addition of small slack in order to be numerically
# stable. The minimal addable slack for float32 is higher thus we need to be less strict
-tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-6
+tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-5


-def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=None):
+def check_transform(transform, domain, constructor=pt.scalar, test=0, rv_var=None):
    x = constructor("x")
    x.tag.test_value = test
    if rv_var is None:
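For scale, a quick numpy check (illustrative, not part of the test suite) of why 1e-6 was too tight under float32:

import numpy as np

# float32 spacing near 1.0 is already ~1.2e-7, so a round trip through several
# float32 ops can easily miss a 1e-6 tolerance; 1e-5 leaves headroom.
print(np.finfo(np.float32).eps)  # ~1.1920929e-07
print(np.finfo(np.float64).eps)  # ~2.220446e-16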
@@ -57,18 +57,20 @@ def check_transform(transform, domain, constructor=pt.dscalar, test=0, rv_var=None):
    # FIXME: What's being tested here? That the transformed graph can compile?
    forward_f = pytensor.function([x], transform.forward(x, *rv_inputs))
    # test transform identity
-    identity_f = pytensor.function(
-        [x], transform.backward(transform.forward(x, *rv_inputs), *rv_inputs)
-    )
+    z = transform.backward(transform.forward(x, *rv_inputs), *rv_inputs)
+    assert z.type == x.type
+    identity_f = pytensor.function([x], z)
    for val in domain.vals:
        close_to(val, identity_f(val), tol)


def check_vector_transform(transform, domain, rv_var=None):
-    return check_transform(transform, domain, pt.dvector, test=np.array([0, 0]), rv_var=rv_var)
+    return check_transform(
+        transform, domain, pt.vector, test=floatX(np.array([0, 0])), rv_var=rv_var
+    )


-def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None):
+def get_values(transform, domain=R, constructor=pt.scalar, test=0, rv_var=None):
    x = constructor("x")
    x.tag.test_value = test
    if rv_var is None:
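The new assert z.type == x.type is the actual regression guard: pytensor TensorType equality covers dtype as well as shape/broadcast pattern, so a transform that silently upcasts float32 to float64 now fails before any numeric comparison. A small sketch, again illustrative:

import pytensor.tensor as pt

x32 = pt.vector("x", dtype="float32")
x64 = pt.vector("x", dtype="float64")
print(x32.type == x64.type)  # False: the dtypes differ
print(x32.type == pt.vector("y", dtype="float32").type)  # True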
@@ -81,7 +83,7 @@ def get_values(transform, domain=R, constructor=pt.dscalar, test=0, rv_var=None):
def check_jacobian_det(
    transform,
    domain,
-    constructor=pt.dscalar,
+    constructor=pt.scalar,
    test=0,
    make_comparable=None,
    elemwise=False,
@@ -119,22 +121,26 @@ def test_simplex():
    check_vector_transform(tr.simplex, Simplex(2))
    check_vector_transform(tr.simplex, Simplex(4))

-    check_transform(tr.simplex, MultiSimplex(3, 2), constructor=pt.dmatrix, test=np.zeros((2, 2)))
+    check_transform(
+        tr.simplex, MultiSimplex(3, 2), constructor=pt.matrix, test=floatX(np.zeros((2, 2)))
+    )


def test_simplex_bounds():
-    vals = get_values(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]))
+    vals = get_values(tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])))

    close_to(vals.sum(axis=1), 1, tol)
    close_to_logical(vals > 0, True, tol)
    close_to_logical(vals < 1, True, tol)

-    check_jacobian_det(tr.simplex, Vector(R, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1])
+    check_jacobian_det(
+        tr.simplex, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), lambda x: x[:-1]
+    )


def test_simplex_accuracy():
-    val = np.array([-30])
-    x = pt.dvector("x")
+    val = floatX(np.array([-30]))
+    x = pt.vector("x")
    x.tag.test_value = val
    identity_f = pytensor.function([x], tr.simplex.forward(x, tr.simplex.backward(x, x)))
    close_to(val, identity_f(val), tol)
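The recurring floatX(...) wrapper, presumably imported at the top of the test module (outside this diff), casts numpy values to pytensor.config.floatX so test literals match the configured precision. Roughly:

import numpy as np
from pymc.pytensorf import floatX  # assumed import; pymc's standard helper

# The same literal becomes float32 on a float32 build and float64 otherwise.
print(floatX(np.zeros(3)).dtype)  # matches pytensor.config.floatX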
@@ -148,28 +154,39 @@ def test_sum_to_1():
    tr.SumTo1(2)

    check_jacobian_det(
-        tr.univariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
+        tr.univariate_sum_to_1,
+        Vector(Unit, 2),
+        pt.vector,
+        floatX(np.array([0, 0])),
+        lambda x: x[:-1],
    )
    check_jacobian_det(
-        tr.multivariate_sum_to_1, Vector(Unit, 2), pt.dvector, np.array([0, 0]), lambda x: x[:-1]
+        tr.multivariate_sum_to_1,
+        Vector(Unit, 2),
+        pt.vector,
+        floatX(np.array([0, 0])),
+        lambda x: x[:-1],
    )


def test_log():
    check_transform(tr.log, Rplusbig)

    check_jacobian_det(tr.log, Rplusbig, elemwise=True)
-    check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(tr.log, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

    vals = get_values(tr.log)
    close_to_logical(vals > 0, True, tol)


+@pytest.mark.skipif(
+    pytensor.config.floatX == "float32", reason="Test is designed for 64bit precision"
+)
def test_log_exp_m1():
    check_transform(tr.log_exp_m1, Rplusbig)

    check_jacobian_det(tr.log_exp_m1, Rplusbig, elemwise=True)
-    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(tr.log_exp_m1, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

    vals = get_values(tr.log_exp_m1)
    close_to_logical(vals > 0, True, tol)
@@ -179,7 +196,7 @@ def test_logodds():
    check_transform(tr.logodds, Unit)

    check_jacobian_det(tr.logodds, Unit, elemwise=True)
-    check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.dvector, [0.5, 0.5], elemwise=True)
+    check_jacobian_det(tr.logodds, Vector(Unit, 2), pt.vector, [0.5, 0.5], elemwise=True)

    vals = get_values(tr.logodds)
    close_to_logical(vals > 0, True, tol)
@@ -191,7 +208,7 @@ def test_lowerbound():
    check_transform(trans, Rplusbig)

    check_jacobian_det(trans, Rplusbig, elemwise=True)
-    check_jacobian_det(trans, Vector(Rplusbig, 2), pt.dvector, [0, 0], elemwise=True)
+    check_jacobian_det(trans, Vector(Rplusbig, 2), pt.vector, [0, 0], elemwise=True)

    vals = get_values(trans)
    close_to_logical(vals > 0, True, tol)
@@ -202,7 +219,7 @@ def test_upperbound():
    check_transform(trans, Rminusbig)

    check_jacobian_det(trans, Rminusbig, elemwise=True)
-    check_jacobian_det(trans, Vector(Rminusbig, 2), pt.dvector, [-1, -1], elemwise=True)
+    check_jacobian_det(trans, Vector(Rminusbig, 2), pt.vector, [-1, -1], elemwise=True)

    vals = get_values(trans)
    close_to_logical(vals < 0, True, tol)
@@ -234,7 +251,7 @@ def test_interval_near_boundary():
        pm.Uniform("x", initval=x0, lower=lb, upper=ub)

    log_prob = model.point_logps()
-    np.testing.assert_allclose(list(log_prob.values()), np.array([-52.68]))
+    np.testing.assert_allclose(list(log_prob.values()), floatX(np.array([-52.68])))


def test_circular():
@@ -257,19 +274,19 @@ def test_ordered():
    tr.Ordered(2)

    check_jacobian_det(
-        tr.univariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
+        tr.univariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
    )
    check_jacobian_det(
-        tr.multivariate_ordered, Vector(R, 2), pt.dvector, np.array([0, 0]), elemwise=False
+        tr.multivariate_ordered, Vector(R, 2), pt.vector, floatX(np.array([0, 0])), elemwise=False
    )

-    vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.dvector, np.zeros(3))
+    vals = get_values(tr.univariate_ordered, Vector(R, 3), pt.vector, floatX(np.zeros(3)))
    close_to_logical(np.diff(vals) >= 0, True, tol)


def test_chain_values():
    chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
-    vals = get_values(chain_tranf, Vector(R, 5), pt.dvector, np.zeros(5))
+    vals = get_values(chain_tranf, Vector(R, 5), pt.vector, floatX(np.zeros(5)))
    close_to_logical(np.diff(vals) >= 0, True, tol)
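For readers unfamiliar with tr.Chain: it composes transforms, applying each member's forward in list order and backward in reverse, so backward(forward(x)) round-trips. A sketch with my own example values, not the test's:

import numpy as np
import pytensor
import pytensor.tensor as pt
import pymc.distributions.transforms as tr  # assumed alias, as in these tests

chain = tr.Chain([tr.logodds, tr.univariate_ordered])
x = pt.vector("x")
roundtrip = pytensor.function([x], chain.backward(chain.forward(x)))
# Increasing values in (0, 1) survive the logodds -> ordered round trip.
print(roundtrip(np.array([0.1, 0.2, 0.4, 0.8], dtype=x.dtype)))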


@@ -281,7 +298,7 @@ def test_chain_vector_transform():
@pytest.mark.xfail(reason="Fails due to precision issue. Values just close to expected.")
def test_chain_jacob_det():
    chain_tranf = tr.Chain([tr.logodds, tr.univariate_ordered])
-    check_jacobian_det(chain_tranf, Vector(R, 4), pt.dvector, np.zeros(4), elemwise=False)
+    check_jacobian_det(chain_tranf, Vector(R, 4), pt.vector, floatX(np.zeros(4)), elemwise=False)


class TestElementWiseLogp(SeededTest):
@@ -432,7 +449,7 @@ def transform_params(*inputs):
        [
            (0.0, 1.0, 2.0, 2),
            (-10, 0, 200, (2, 3)),
-            (np.zeros(3), np.ones(3), np.ones(3), (4, 3)),
+            (floatX(np.zeros(3)), floatX(np.ones(3)), floatX(np.ones(3)), (4, 3)),
        ],
    )
    def test_triangular(self, lower, c, upper, size):
@@ -449,7 +466,8 @@ def transform_params(*inputs):
        self.check_transform_elementwise_logp(model)

    @pytest.mark.parametrize(
-        "mu,kappa,size", [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (np.zeros(3), np.ones(3), (4, 3))]
+        "mu,kappa,size",
+        [(0.0, 1.0, 2), (-0.5, 5.5, (2, 3)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))],
    )
    def test_vonmises(self, mu, kappa, size):
        model = self.build_model(
@@ -549,7 +567,9 @@ def transform_params(*inputs):
        )
        self.check_vectortransform_elementwise_logp(model)

-    @pytest.mark.parametrize("mu,kappa,size", [(0.0, 1.0, (2,)), (np.zeros(3), np.ones(3), (4, 3))])
+    @pytest.mark.parametrize(
+        "mu,kappa,size", [(0.0, 1.0, (2,)), (floatX(np.zeros(3)), floatX(np.ones(3)), (4, 3))]
+    )
    def test_vonmises_ordered(self, mu, kappa, size):
        initval = np.sort(np.abs(np.random.rand(*size)))
        model = self.build_model(
@@ -566,7 +586,12 @@ def test_vonmises_ordered(self, mu, kappa, size):
        [
            (0.0, 1.0, (2,), tr.simplex),
            (0.5, 5.5, (2, 3), tr.simplex),
-            (np.zeros(3), np.ones(3), (4, 3), tr.Chain([tr.univariate_sum_to_1, tr.logodds])),
+            (
+                floatX(np.zeros(3)),
+                floatX(np.ones(3)),
+                (4, 3),
+                tr.Chain([tr.univariate_sum_to_1, tr.logodds]),
+            ),
        ],
    )
    def test_uniform_other(self, lower, upper, size, transform):
@@ -583,8 +608,8 @@ def test_uniform_other(self, lower, upper, size, transform):
    @pytest.mark.parametrize(
        "mu,cov,size,shape",
        [
-            (np.zeros(2), np.diag(np.ones(2)), None, (2,)),
-            (np.zeros(3), np.diag(np.ones(3)), (4,), (4, 3)),
+            (floatX(np.zeros(2)), floatX(np.diag(np.ones(2))), None, (2,)),
+            (floatX(np.zeros(3)), floatX(np.diag(np.ones(3))), (4,), (4, 3)),
        ],
    )
    def test_mvnormal_ordered(self, mu, cov, size, shape):
@@ -643,7 +668,7 @@ def test_2d_univariate_ordered():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_ordered__": np.zeros((4,)), "x_2d_ordered__": np.zeros((10, 4))}
+        {"x_1d_ordered__": floatX(np.zeros((4,))), "x_2d_ordered__": floatX(np.zeros((10, 4)))}
    )
    np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

@@ -667,7 +692,7 @@ def test_2d_multivariate_ordered():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_ordered__": np.zeros((2,)), "x_2d_ordered__": np.zeros((2, 2))}
+        {"x_1d_ordered__": floatX(np.zeros((2,))), "x_2d_ordered__": floatX(np.zeros((2, 2)))}
    )
    np.testing.assert_allclose(log_p[0], log_p[1])

@@ -690,7 +715,7 @@ def test_2d_univariate_sum_to_1():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_sumto1__": np.zeros(3), "x_2d_sumto1__": np.zeros((10, 3))}
+        {"x_1d_sumto1__": floatX(np.zeros(3)), "x_2d_sumto1__": floatX(np.zeros((10, 3)))}
    )
    np.testing.assert_allclose(np.tile(log_p[0], (10, 1)), log_p[1])

@@ -712,6 +737,6 @@ def test_2d_multivariate_sum_to_1():
    )

    log_p = model.compile_logp(sum=False)(
-        {"x_1d_sumto1__": np.zeros(1), "x_2d_sumto1__": np.zeros((2, 1))}
+        {"x_1d_sumto1__": floatX(np.zeros(1)), "x_2d_sumto1__": floatX(np.zeros((2, 1)))}
    )
    np.testing.assert_allclose(log_p[0], log_p[1])
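All four 2d tests use the same pattern: model.compile_logp(sum=False) returns a function from a point dict of (transformed) value variables to per-variable log-densities, and wrapping the inputs in floatX keeps evaluation in float32 on float32 builds. A stripped-down sketch with a hypothetical model, not one from the diff:

import numpy as np
import pymc as pm
from pymc.pytensorf import floatX  # assumed import, as above

with pm.Model() as model:
    pm.Normal("x", size=3)

logp_fn = model.compile_logp(sum=False)
# Keys are the value-variable names; here "x" because Normal is untransformed.
print(logp_fn({"x": floatX(np.zeros(3))}))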
