From b06d6c38908b00f0600b453fc74fd93014c91132 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 6 Sep 2024 23:16:54 +0200 Subject: [PATCH] Add test for scan logp with multiple valued output types --- tests/logprob/test_scan.py | 49 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py index edd6ee51a5..a9e64e459c 100644 --- a/tests/logprob/test_scan.py +++ b/tests/logprob/test_scan.py @@ -33,6 +33,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import itertools import numpy as np import pytensor @@ -502,3 +503,51 @@ def ref_logp(values, rho, sigma): logp_expr.eval({ma2_vv: ma2_test, rho: rho_test, sigma: sigma_test}), ref_logp(ma2_test, rho_test, sigma_test), ) + + +@pytest.mark.xfail(reason="Not implemented yet") +def test_scan_multiple_output_types(): + """Test we can derive the logp for a scan that contains recurring and non-recurring measurable outputs.""" + [xs, ys, zs], _ = pytensor.scan( + fn=lambda x_mu, y_tm1, z_tm2, z_tm1: ( + pt.random.normal(x_mu), + pt.random.normal(y_tm1), + pt.random.normal(z_tm1) + z_tm2, + ), + sequences=[pt.arange(10)], + outputs_info=[ + None, + pt.zeros(()), + dict(initial=pt.ones(2), taps=[-2, -1]), + ], + ) + + xs.name = "xs" + xs_value = xs.clone() + ys.name = "ys" + ys_value = ys.clone() + zs.name = "zs" + zs_value = zs.clone() + + logp_dict = conditional_logp({xs: xs_value, ys: ys_value, zs: zs_value}) + xs_logp = logp_dict[xs_value] + ys_logp = logp_dict[ys_value] + zs_logp = logp_dict[zs_value] + + assert_no_rvs([xs_logp, ys_logp, zs_logp]) + fn = pytensor.function( + [xs_value, ys_value, zs_value], + [xs_logp, ys_logp, zs_logp], + ) + + rng = np.random.default_rng(577) + test_value = rng.uniform(size=(10,)) + (xs_logp_eval, ys_logp_eval, zs_logp_eval) = fn(test_value, test_value, test_value) + np.testing.assert_allclose(xs_logp_eval, stats.norm.logpdf(test_value, np.arange(10))) + np.testing.assert_allclose(ys_logp_eval, stats.norm.logpdf(test_value, [0, *test_value[:-1]])) + np.testing.assert_allclose( + zs_logp_eval, + stats.norm.logpdf( + test_value, [a + b for a, b in itertools.pairwise([1, 1, *test_value[:-1]])] + ), + )