diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4ea4da55ca..c776e3c7fb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,6 +27,8 @@ make format # runs isort make test # linting and unit tests ``` +`make test` needs `pip install funsor` and `brew install graphviz`. + If you've modified core pyro code, examples, or tutorials, you can run more comprehensive tests locally (after first adding any new files to the appropriate `tests/` script) ```sh make test-examples # test examples/ diff --git a/pyro/infer/discrete.py b/pyro/infer/discrete.py index a89294abd2..f95ae972ae 100644 --- a/pyro/infer/discrete.py +++ b/pyro/infer/discrete.py @@ -32,7 +32,7 @@ class SamplePosteriorMessenger(ReplayMessenger): # This acts like ReplayMessenger but additionally replays cond_indep_stack. def _pyro_sample(self, msg): - if msg["infer"].get("enumerate") == "parallel": + if msg["infer"].get("enumerate") in ["parallel", "sequential"]: super()._pyro_sample(msg) if msg["name"] in self.trace: msg["cond_indep_stack"] = self.trace.nodes[msg["name"]]["cond_indep_stack"] diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 1046e30910..384b28c109 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -123,7 +123,7 @@ def enumerate_site(msg): class EnumMessenger(Messenger): """ Enumerates in parallel over discrete sample sites marked - ``infer={"enumerate": "parallel"}``. + ``infer={"enumerate": "parallel"}`` or ``infer={"enumerate": "sequential"}``. :param int first_available_dim: The first tensor dimension (counting from the right) that is available for parallel enumeration. This @@ -166,7 +166,10 @@ def _pyro_sample(self, msg): param_dims.update(self._value_dims[name]) self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] self._param_dims[msg["name"]] = param_dims - if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": + if msg["is_observed"] or msg["infer"].get("enumerate") not in [ + "parallel", + "sequential", + ]: return # Compute an enumerated value (at an arbitrary dim). diff --git a/tests/infer/test_discrete.py b/tests/infer/test_discrete.py index 416ecba73e..e6cffd3edb 100644 --- a/tests/infer/test_discrete.py +++ b/tests/infer/test_discrete.py @@ -317,6 +317,67 @@ def model(num_particles=1, z=None): assert_equal(actual_z_mean, expected_z_mean, prec=1e-2) +@pytest.mark.parametrize( + "infer,temperature,enum", + [ + (infer_discrete, 0, ("sequential", "parallel")), + (infer_discrete, 1, ("sequential", "parallel")), + pytest.param( + infer_discrete, + 0, + ("parallel", "other"), + marks=pytest.mark.xfail(reason="expected failed case without this fix"), + ), + pytest.param( + infer_discrete, + 0, + ("sequential", "other"), + marks=pytest.mark.xfail(reason="expected failed case without this fix"), + ), + pytest.param( + infer_discrete, + 1, + ("parallel", "other"), + marks=pytest.mark.xfail(reason="expected failed case without this fix"), + ), + pytest.param( + infer_discrete, + 1, + ("sequential", "other"), + marks=pytest.mark.xfail(reason="expected failed case without this fix"), + ), + ], +) +def test_enum(infer, temperature, enum): + assert len(enum) == 2 + first_available_dim = -1 + p = torch.tensor(0.5) + + @config_enumerate + def model(x_ch_obs=None, enum_option=None): + y = pyro.sample( + "y_pre", + dist.Binomial(probs=p, total_count=1), + infer={"enumerate": enum_option}, + ) + d_ch = dist.Normal(y, 1.0) + pyro.sample("x_ch_pre", d_ch, obs=x_ch_obs) + return y + + y_posts = [] + for enum_option in enum: + data_obs = {"x_ch_obs": torch.tensor(1.0), "enum_option": enum_option} + model_discrete = infer_discrete( + model, first_available_dim=first_available_dim, temperature=temperature + ) + y_post = [] + for ii in range(10**4): + y_post.append(model_discrete(**data_obs)) + smpl = torch.stack(y_post) + y_posts.append(smpl.mean()) + assert_equal(y_posts[0], y_posts[1], prec=1e-2) + + @pytest.mark.parametrize("length", [1, 2, 10, 100]) @pytest.mark.parametrize( "infer,temperature", diff --git a/tests/infer/test_enum.py b/tests/infer/test_enum.py index be176c9ca9..f9edd17b81 100644 --- a/tests/infer/test_enum.py +++ b/tests/infer/test_enum.py @@ -118,15 +118,18 @@ def guide(): def test_enumerate_sequential_model(): + values = [] + def model(): - pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "sequential"}) + x = pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "sequential"}) + values.append(x) def guide(): pass - with pytest.raises(NotImplementedError): - elbo = TraceEnum_ELBO(max_plate_nesting=0) - elbo.loss(model, guide) + elbo = TraceEnum_ELBO(max_plate_nesting=0) + elbo.loss(model, guide) + assert len(values) == 1, values # The usual dist.Bernoulli avoids NANs by clamping log prob. This unsafe version diff --git a/tests/infer/test_valid_models.py b/tests/infer/test_valid_models.py index 3c6cfa8cdc..57ab412d98 100644 --- a/tests/infer/test_valid_models.py +++ b/tests/infer/test_valid_models.py @@ -1570,12 +1570,7 @@ def model(): def guide(): pass - assert_error( - model, - guide, - TraceEnum_ELBO(max_plate_nesting=0), - match="At site .*, model-side sequential enumeration is not implemented", - ) + assert_ok(model, guide, TraceEnum_ELBO(max_plate_nesting=0)) def test_enum_in_model_plate_reuse_ok():