Skip to content

Commit

Permalink
Remove stale xfail scan test
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 11, 2024
1 parent 585962d commit db0b218
Showing 1 changed file with 0 additions and 52 deletions.
52 changes: 0 additions & 52 deletions tests/logprob/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,58 +388,6 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
assert np.allclose(y_logp_val, y_logp_ref_val)


@pytest.mark.xfail(reason="see #148")
@pytensor.config.change_flags(compute_test_value="raise")
@pytest.mark.xfail(reason="see #148")
def test_initial_values():
srng = pt.random.RandomStream(seed=2320)

p_S_0 = np.array([0.9, 0.1])
S_0_rv = srng.categorical(p_S_0, name="S_0")
S_0_rv.tag.test_value = 0

Gamma_at = pt.matrix("Gamma")
Gamma_at.tag.test_value = np.array([[0, 1], [1, 0]])

s_0_vv = S_0_rv.clone()
s_0_vv.name = "s_0"

def step_fn(S_tm1, Gamma):
S_t = srng.categorical(Gamma[S_tm1], name="S_t")
return S_t

S_1T_rv, _ = pytensor.scan(
fn=step_fn,
outputs_info=[{"initial": S_0_rv, "taps": [-1]}],
non_sequences=[Gamma_at],
strict=True,
n_steps=10,
name="S_0T",
)

S_1T_rv.name = "S_1T"
s_1T_vv = S_1T_rv.clone()
s_1T_vv.name = "s_1T"

logp_parts = conditional_logp({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})

s_0_val = 0
s_1T_val = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1])
Gamma_val = np.array([[0.1, 0.9], [0.9, 0.1]])

exp_res = np.log(p_S_0[s_0_val])
s_prev = s_0_val
for s in s_1T_val:
exp_res += np.log(Gamma_val[s_prev, s])
s_prev = s

S_0T_logp = sum(v.sum() for v in logp_parts.values())
S_0T_logp_fn = pytensor.function([s_0_vv, s_1T_vv, Gamma_at], S_0T_logp)
res = S_0T_logp_fn(s_0_val, s_1T_val, Gamma_val)

assert res == pytest.approx(exp_res)


@pytest.mark.parametrize("remove_asserts", (True, False))
def test_mode_is_kept(remove_asserts):
mode = Mode().including("local_remove_all_assert") if remove_asserts else None
Expand Down

0 comments on commit db0b218

Please sign in to comment.