Skip to content

Commit

Permalink
Add test for scan logp with multiple valued output types
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 11, 2024
1 parent db0b218 commit b06d6c3
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/logprob/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import itertools

import numpy as np
import pytensor
Expand Down Expand Up @@ -502,3 +503,51 @@ def ref_logp(values, rho, sigma):
logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}),
ref_logp(ma2_test, rho_test, sigma_test),
)


@pytest.mark.xfail(reason="Not implemented yet")
def test_scan_multiple_output_types():
"""Test we can derive the logp for a scan that contains recurring and non-recurring measurable outputs."""
[xs, ys, zs], _ = pytensor.scan(
fn=lambda x_mu, y_tm1, z_tm2, z_tm1: (
pt.random.normal(x_mu),
pt.random.normal(y_tm1),
pt.random.normal(z_tm1) + z_tm2,
),
sequences=[pt.arange(10)],
outputs_info=[
None,
pt.zeros(()),
dict(initial=pt.ones(2), taps=[-2, -1]),
],
)

xs.name = "xs"
xs_value = xs.clone()
ys.name = "ys"
ys_value = ys.clone()
zs.name = "zs"
zs_value = zs.clone()

logp_dict = conditional_logp({xs: xs_value, ys: ys_value, zs: zs_value})
xs_logp = logp_dict[xs_value]
ys_logp = logp_dict[ys_value]
zs_logp = logp_dict[zs_value]

assert_no_rvs([xs_logp, ys_logp, zs_logp])
fn = pytensor.function(
[xs_value, ys_value, zs_value],
[xs_logp, ys_logp, zs_logp],
)

rng = np.random.default_rng(577)
test_value = rng.uniform(size=(10,))
(xs_logp_eval, ys_logp_eval, zs_logp_eval) = fn(test_value, test_value, test_value)
np.testing.assert_allclose(xs_logp_eval, stats.norm.logpdf(test_value, np.arange(10)))
np.testing.assert_allclose(ys_logp_eval, stats.norm.logpdf(test_value, [0, *test_value[:-1]]))
np.testing.assert_allclose(
zs_logp_eval,
stats.norm.logpdf(
test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])]
),
)

0 comments on commit b06d6c3

Please sign in to comment.