Skip to content

Commit

Permalink
fix returned object when no vars to sample and extend=True
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril authored and ricardoV94 committed Sep 18, 2024
1 parent 4300be1 commit b9fbfed
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def sample_posterior_predictive(
if return_inferencedata and not extend_inferencedata:
return InferenceData()
elif return_inferencedata and extend_inferencedata:
return trace
return trace if idata is None else idata
return {}

vars_in_trace = get_vars_in_point_list(_trace, model)
Expand Down
5 changes: 5 additions & 0 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ def test_normal_scalar(self):
ppc = pm.sample_posterior_predictive(trace, var_names=[], return_inferencedata=False)
assert len(ppc) == 0

# test empty ppc with extend_inferencedata
assert isinstance(trace, InferenceData)
ppc = pm.sample_posterior_predictive(trace, var_names=[], extend_inferencedata=True)
assert ppc is trace

# test keep_size parameter
ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False)
assert ppc["a"].shape == (nchains, ndraws)
Expand Down

0 comments on commit b9fbfed

Please sign in to comment.