diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 41620d46d4..fdcfcf0781 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -279,17 +279,38 @@ def test_switch_encoding_both_branches(): assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3)) -def test_switch_encoding_second_branch_measurable(): - x_rv = pt.random.normal(0.5, 1) - y_rv = pt.switch(x_rv < 1, 1, x_rv) +@pytest.mark.parametrize( + "measurable_idx, test_values, exp_logp", + [ + (1, (0.9, 1, 1.5), (-np.inf, st.norm(0.5, 1).logcdf(1), st.norm(0.5, 1).logpdf(1.5))), + (0, (1.5, 1, 0.9), (-np.inf, st.norm(0.5, 1).logsf(1), st.norm(0.5, 1).logpdf(0.9))), + ], +) +def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_logp): + x_rv = pt.random.normal(0.5, 1) # should not be defined again ideally + branches = (1, x_rv) if measurable_idx == 1 else (x_rv, 1) + + y_rv = pt.switch(x_rv < 1, *branches) y_vv = y_rv.clone() - ref_scipy = st.norm(0.5, 1) logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) - assert logp_fn(0.5) == -np.inf + for i, j in zip(test_values, exp_logp): + assert np.isclose(logp_fn(i), j) - assert np.isclose(logp_fn(1), ref_scipy.logcdf(1)) - assert np.isclose(logp_fn(1.2), ref_scipy.logpdf(1.2)) + +def test_switch_encoding_discrete_fail(): + x_rv = pt.random.poisson(2) + y_rv = pt.switch(x_rv > 3, x_rv, 1) + + y_vv = x_rv.clone() + y_vv_test = 1 + + with pytest.raises( + NotImplementedError, + match="Logprob method not implemented", + ): + logp(y_rv, y_vv).eval({y_vv: y_vv_test})