Skip to content

Commit

Permalink
Tests for measurable branches and deny discrete
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Aug 7, 2023
1 parent 73d7979 commit d4be8b5
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions tests/logprob/test_censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,38 @@ def test_switch_encoding_both_branches():
assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3))


def test_switch_encoding_second_branch_measurable():
x_rv = pt.random.normal(0.5, 1)
y_rv = pt.switch(x_rv < 1, 1, x_rv)
@pytest.mark.parametrize(
"measurable_idx, test_values, exp_logp",
[
(1, (0.9, 1, 1.5), (-np.inf, st.norm(0.5, 1).logcdf(1), st.norm(0.5, 1).logpdf(1.5))),
(0, (1.5, 1, 0.9), (-np.inf, st.norm(0.5, 1).logsf(1), st.norm(0.5, 1).logpdf(0.9))),
],
)
def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_logp):
x_rv = pt.random.normal(0.5, 1) # should not be defined again ideally
branches = (1, x_rv) if measurable_idx == 1 else (x_rv, 1)

y_rv = pt.switch(x_rv < 1, *branches)

y_vv = y_rv.clone()
ref_scipy = st.norm(0.5, 1)

logprob = logp(y_rv, y_vv)

logp_fn = pytensor.function([y_vv], logprob)

assert logp_fn(0.5) == -np.inf
for i, j in zip(test_values, exp_logp):
assert np.isclose(logp_fn(i), j)

assert np.isclose(logp_fn(1), ref_scipy.logcdf(1))
assert np.isclose(logp_fn(1.2), ref_scipy.logpdf(1.2))

def test_switch_encoding_discrete_fail():
x_rv = pt.random.poisson(2)
y_rv = pt.switch(x_rv > 3, x_rv, 1)

y_vv = x_rv.clone()
y_vv_test = 1

with pytest.raises(
NotImplementedError,
match="Logprob method not implemented",
):
logp(y_rv, y_vv).eval({y_vv: y_vv_test})

0 comments on commit d4be8b5

Please sign in to comment.