diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index 133a3ca965..a0efaec075 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -33,6 +33,7 @@ BlockedStep, CompoundStep, StatsBijection, + check_step_emits_tune, flat_statname, flatten_steps, ) @@ -207,11 +208,10 @@ def make_runmeta_and_point_fn( ) -> Tuple[mcb.RunMeta, PointFunc]: variables, point_fn = get_variables_and_point_fn(model, initial_point) - sample_stats = [ - mcb.Variable("tune", "bool"), - ] + check_step_emits_tune(step) # In PyMC the sampler stats are grouped by the sampler. + sample_stats = [] steps = flatten_steps(step) for s, sm in enumerate(steps): for statname, (dtype, shape) in sm.stats_dtypes_shapes.items(): diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py index 2e3693c785..2989030fa7 100644 --- a/tests/backends/test_mcbackend.py +++ b/tests/backends/test_mcbackend.py @@ -119,7 +119,7 @@ def test_make_runmeta_and_point_fn(simple_model): assert not vars["vector"].is_deterministic assert not vars["vector_interval__"].is_deterministic assert vars["matrix"].is_deterministic - assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0]) + assert len(rmeta.sample_stats) == len(step.stats_dtypes[0]) pass