Skip to content

Commit

Permalink
Fix duplicate "tune" stat in McBackend adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Jul 24, 2023
1 parent 38e87e2 commit 65eb592
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
BlockedStep,
CompoundStep,
StatsBijection,
check_step_emits_tune,
flat_statname,
flatten_steps,
)
Expand Down Expand Up @@ -207,11 +208,10 @@ def make_runmeta_and_point_fn(
) -> Tuple[mcb.RunMeta, PointFunc]:
variables, point_fn = get_variables_and_point_fn(model, initial_point)

sample_stats = [
mcb.Variable("tune", "bool"),
]
check_step_emits_tune(step)

# In PyMC the sampler stats are grouped by the sampler.
sample_stats = []
steps = flatten_steps(step)
for s, sm in enumerate(steps):
for statname, (dtype, shape) in sm.stats_dtypes_shapes.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/backends/test_mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_make_runmeta_and_point_fn(simple_model):
assert not vars["vector"].is_deterministic
assert not vars["vector_interval__"].is_deterministic
assert vars["matrix"].is_deterministic
assert len(rmeta.sample_stats) == 1 + len(step.stats_dtypes[0])
assert len(rmeta.sample_stats) == len(step.stats_dtypes[0])
pass


Expand Down

0 comments on commit 65eb592

Please sign in to comment.