
Add var_names arg to PyMC compiled model #100

Open
fonnesbeck opened this issue Mar 25, 2024 · 3 comments
Labels
enhancement New feature or request

Comments

@fonnesbeck
Member

To allow customizing which variables are stored in the trace, it would be helpful to have a similar argument for CompiledPyMCModel so that unwanted variables can be excluded from the trace.

@fonnesbeck fonnesbeck added the enhancement New feature or request label Mar 25, 2024
@btschroer

btschroer commented Oct 22, 2024

👍 Yes, that would be really nice to have. I have a hierarchical model with lots of groups.

@aseyboldt, @fonnesbeck: Is it at all possible that my nutpie sampling is slowed down because my trace has to store lots of group-specific random variables? I only need samples of the population parameters.

Can you suggest how to circumvent this problem until var_names is available as an argument?
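One possible stopgap, sketched below with plain NumPy arrays standing in for the posterior (the variable names are made up), is to drop the large group-level variables from the result right after sampling. With an ArviZ InferenceData the same idea would be `idata.posterior[["mu"]]`. Note this only shrinks the stored result; it would not speed up sampling itself.

```python
import numpy as np

# Hypothetical draws keyed by variable name, as they might come out of sampling.
posterior = {
    "mu": np.zeros((4, 1000)),                  # population parameter
    "group_effects": np.zeros((4, 1000, 500)),  # large group-specific variable
}

# Keep only the population parameters, discarding the rest after sampling.
keep = {"mu"}
slim = {name: draws for name, draws in posterior.items() if name in keep}
```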

@aseyboldt
Member

I agree that this would be nice to have (and it wouldn't be that hard to implement; it only needs some changes in the compile_pymc_model function. If someone wants to give it a go, I'd be glad to help).

If the model you are looking at is anything like the one you posted in the other thread, though, I'd be surprised if storing the trace is the issue.

The simplest way to speed it up is probably to switch to float32 (set the environment variable PYTENSOR_FLAGS=floatX=float32).
You could also try running it on the GPU; if the dataset is large, that might help a lot.
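The float32 switch can be applied either on the command line (`PYTENSOR_FLAGS=floatX=float32 python model.py`) or from Python, as sketched below; the key constraint is that the flag must be set before pytensor is first imported in the process.

```python
import os

# Must run before pytensor is imported anywhere in the process, otherwise
# the flag is ignored and the default floatX (float64) is used.
os.environ["PYTENSOR_FLAGS"] = "floatX=float32"
```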

And then, I'd double-check the parametrization and make sure your predictors are not too correlated. An easy way to see whether that can help is to look at the number of gradient evaluations per draw. If that is large (say above 15 or 30), there is probably quite some room for improvement. This number is roughly proportional to the runtime, all else being equal, so if you can get it from 100 down to 10, that's a 10x speedup.
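As a rough sketch of that check: with NUTS, each integrator step costs one gradient evaluation, so the mean of n_steps from the sample stats is a reasonable proxy for gradients per draw. The array below is synthetic, with the (chain, draw) shape nutpie reports.

```python
import numpy as np

# Synthetic n_steps with shape (chain, draw); in practice this would be
# idata.sample_stats["n_steps"] from the sampler output.
n_steps = np.full((6, 1000), 15, dtype=np.uint64)

# Mean leapfrog steps per draw ~= gradient evaluations per draw.
grads_per_draw = float(n_steps.mean())
print(grads_per_draw)  # 15.0 for this synthetic trace
```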

@btschroer

Thanks @aseyboldt for the suggestions, that's really helpful. I also looked into compile_pymc_model, but it essentially just calls the "compile to numba" or "compile to jax" functions; I'm assuming the change would need to be made there, right? One point that isn't entirely clear to me is why we need to exclude variables at "compilation time". I thought we'd still want to sample e.g. group-specific random variables, but then exclude them when writing the trace, i.e. exclusion would happen at "runtime"?
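The "exclude at runtime" idea could be illustrated roughly as below: sample all variables, but filter each draw before it is written to the trace. This is a hypothetical helper for discussion, not nutpie's actual API.

```python
def filter_draw(draw, var_names):
    """Keep only the variables listed in var_names (hypothetical helper)."""
    return {name: value for name, value in draw.items() if name in var_names}

# All variables are still sampled; only the trace write is filtered.
draw = {"mu": 0.1, "sigma": 1.2, "group_effects": [0.3] * 500}
slim = filter_draw(draw, var_names=["mu", "sigma"])
```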

I checked the sampling stats (see output below), but I wasn't able to find the number of gradient evaluations. Is it correct to assume that it corresponds to 'n_steps'?

sample_stats
<xarray.Dataset> Size: 500kB
Dimensions:               (chain: 6, draw: 1000)
Coordinates:
  * chain                 (chain) int64 48B 0 1 2 3 4 5
  * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    depth                 (chain, draw) uint64 48kB 4 4 4 4 4 4 ... 4 4 4 4 4 4
    maxdepth_reached      (chain, draw) bool 6kB False False ... False False
    index_in_trajectory   (chain, draw) int64 48kB -8 8 12 -3 7 ... -6 4 5 3 11
    logp                  (chain, draw) float64 48kB 1.633e+04 ... 1.631e+04
    energy                (chain, draw) float64 48kB -1.617e+04 ... -1.614e+04
    diverging             (chain, draw) bool 6kB False False ... False False
    energy_error          (chain, draw) float64 48kB -0.4984 0.3992 ... -0.6857
    step_size             (chain, draw) float64 48kB 0.3677 0.3677 ... 0.354
    step_size_bar         (chain, draw) float64 48kB 0.3677 0.3677 ... 0.354
    mean_tree_accept      (chain, draw) float64 48kB 0.989 0.7761 ... 0.7476
    mean_tree_accept_sym  (chain, draw) float64 48kB 0.7821 0.8038 ... 0.8219
    n_steps               (chain, draw) uint64 48kB 15 15 15 15 ... 15 15 15 15
Attributes:
    created_at:     2024-10-22T11:45:06.969615+00:00
    arviz_version:  0.20.0
