Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Jul 19, 2023
1 parent 41cd3ab commit 74f42c3
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tests/logprob/test_censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,36 @@ def test_rounding(rounding_op):
logprob.eval({xr_vv: test_value}),
expected_logp,
)


def test_switch_encoding_both_branches():
x_rv = pt.random.normal(0.5, 1)
y_rv = pt.switch(x_rv < 0.3, 1, 2)

y_vv = y_rv.clone()
ref_scipy = st.norm(0.5, 1)

logprob = logp(y_rv, y_vv)
logp_fn = pytensor.function([y_vv], logprob)

assert logp_fn(3) == -np.inf

assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3))
assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3))


@pytest.mark.skip(reason="Logprob calculation for measurable branches not added")
def test_switch_encoding_second_branch_measurable():
x_rv = pt.random.normal(0.5, 1)
y_rv = pt.switch(x_rv < 0.3, 1, x_rv)

y_vv = y_rv.clone()
ref_scipy = st.norm(0.5, 1)

logprob = logp(y_rv, y_vv)
logp_fn = pytensor.function([y_vv], logprob)

assert logp_fn(3) == -np.inf

assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3))
assert np.isclose(logp_fn(0.2), -np.inf)

0 comments on commit 74f42c3

Please sign in to comment.