Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a method to compile an aesara forward sampling function #5759

Merged
merged 1 commit into from
May 17, 2022

Conversation

lucianopaz
Copy link
Contributor

@lucianopaz lucianopaz commented May 13, 2022

Closes #5302

This PR adds a low level method to compile the aesara function responsible for generating forward samples. It introduces the notion of volatile variables in a graph, which allows it to determine which values it should take from an inferred posterior during forward sampling. Volatile variables are variables whose values could change between a run of pm.sample and future calls to pm.sample_posterior_predictive. These are:

  • Variables in the outputs list (the ones that should actually be returned by the compiled function)
  • SharedVariable instances
  • RandomVariable instances that have volatile parameters
  • RandomVariables that are not in the vars_in_trace list (this means that they don't have values in the inferred posterior, so they should be drawn from their a priori assumed distribution)
  • Variables in the givens_dict, because setting values through givens is considered disruptive with respect to the values the variable could have taken during pm.sample.
  • Any other type of variable that depends on volatile variables.

Tasks that remain before this PR is ready to merge:

  • Use compile_forward_sampling_function in sample_prior_predictive
    - [ ] Use compile_forward_sampling_function in sample_posterior_predictive_w

@ricardoV94
Copy link
Member

sample_posterior_predictive_w is not refactored yet so we don't need to incorporate it yet

@codecov
Copy link

codecov bot commented May 13, 2022

Codecov Report

Merging #5759 (483dbfb) into main (ab593b1) will increase coverage by 0.00%.
The diff coverage is 97.87%.

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #5759   +/-   ##
=======================================
  Coverage   89.27%   89.27%           
=======================================
  Files          74       74           
  Lines       13803    13813   +10     
=======================================
+ Hits        12322    12332   +10     
  Misses       1481     1481           
Impacted Files Coverage Δ
pymc/sampling.py 88.40% <97.50%> (+0.23%) ⬆️
pymc/util.py 77.24% <100.00%> (+0.99%) ⬆️
pymc/model.py 86.37% <0.00%> (-0.19%) ⬇️
pymc/sampling_jax.py 96.96% <0.00%> (ø)
pymc/aesaraf.py 91.62% <0.00%> (+0.15%) ⬆️

pymc/aesaraf.py Outdated Show resolved Hide resolved
pymc/aesaraf.py Outdated Show resolved Hide resolved
pymc/aesaraf.py Outdated Show resolved Hide resolved
pymc/aesaraf.py Outdated Show resolved Hide resolved
pymc/util.py Show resolved Hide resolved
pymc/aesaraf.py Outdated Show resolved Hide resolved
pymc/tests/test_aesaraf.py Outdated Show resolved Hide resolved
pymc/aesaraf.py Outdated Show resolved Hide resolved
@lucianopaz lucianopaz force-pushed the compile_ppc_sampler branch from 67c2a13 to 75a3b4a Compare May 13, 2022 09:26
pymc/aesaraf.py Outdated Show resolved Hide resolved
@lucianopaz lucianopaz force-pushed the compile_ppc_sampler branch 2 times, most recently from 7ce1a0f to 10e06bc Compare May 16, 2022 08:26
@lucianopaz lucianopaz marked this pull request as ready for review May 16, 2022 08:26
@lucianopaz
Copy link
Contributor Author

@ricardoV94, this is ready for review now. We'll address the sample_posterior_predictive_w function in another PR

pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/tests/test_sampling.py Outdated Show resolved Hide resolved
pymc/util.py Outdated Show resolved Hide resolved
pymc/tests/test_aesaraf.py Outdated Show resolved Hide resolved
pymc/distributions/distribution.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

Looks good, just trying to push to simplify code as much as possible. Left some comments above @lucianopaz

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested some simplifications locally, do you spot any issues?

pymc/sampling.py Outdated Show resolved Hide resolved
pymc/sampling.py Outdated Show resolved Hide resolved
pymc/sampling.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 added this to the v4.0.0 milestone May 16, 2022
@lucianopaz lucianopaz force-pushed the compile_ppc_sampler branch from 10e06bc to c346391 Compare May 17, 2022 06:56
pymc/sampling.py Outdated Show resolved Hide resolved
@lucianopaz lucianopaz force-pushed the compile_ppc_sampler branch from c346391 to b27a40c Compare May 17, 2022 15:09
@lucianopaz lucianopaz force-pushed the compile_ppc_sampler branch from b27a40c to 483dbfb Compare May 17, 2022 16:05
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome @lucianopaz! This was actually one of the trickiest blockers for V4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Deterministics not resampled in posterior predictive
2 participants