Implement transforms for discrete variables #6102

Draft
wants to merge 1 commit into
base: main

ricardoV94
Member

@ricardoV94 ricardoV94 commented Sep 6, 2022

I am marking this as a draft, as I might very well be shooting myself in the foot... but it seems like it should work?

import arviz as az
import pymc as pm

# Each distribution appears twice: without a transform (transform=None)
# and with the default transform from this PR (names ending in "t").
with pm.Model() as m:
    pm.DiscreteUniform("w", 0, 10, transform=None)
    pm.DiscreteUniform("wt", 0, 10)

    pm.Poisson("x", 5, transform=None)
    pm.Poisson("xt", 5)

    pm.Geometric("y", 0.2, transform=None)
    pm.Geometric("yt", 0.2)

    pm.HyperGeometric("z", N=30, k=12, n=20, transform=None)
    pm.HyperGeometric("zt", N=30, k=12, n=20)

    trace = pm.sample(draws=5_000, chains=4)

# Sort by name so each untransformed/transformed pair sits side by side
az.summary(trace, var_names=sorted(map(str, m.free_RVs)))

One trace

     mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
w   5.015  3.177     0.0     10.0  ...    0.033    4555.0    4930.0    1.0
wt  4.984  3.174     0.0     10.0  ...    0.016   20072.0   20000.0    1.0
x   5.021  2.226     1.0      9.0  ...    0.024    4303.0    5023.0    1.0
xt  4.936  2.197     1.0      9.0  ...    0.021    6805.0    5489.0    1.0
y   4.912  4.323     1.0     13.0  ...    0.094    1553.0    1585.0    1.0
yt  4.947  4.447     1.0     13.0  ...    0.069    1943.0    1816.0    1.0
z   8.030  1.286     6.0     10.0  ...    0.014    4265.0    4769.0    1.0
zt  7.987  1.299     6.0     10.0  ...    0.012    6036.0    5622.0    1.0

Another trace

     mean     sd  hdi_3%  hdi_97%  ...  mcse_sd  ess_bulk  ess_tail  r_hat
w   5.006  3.168     0.0     10.0  ...    0.033    4689.0    4554.0    1.0
wt  4.995  3.162     0.0     10.0  ...    0.016   19678.0   19990.0    1.0
x   5.008  2.235     1.0      9.0  ...    0.025    4397.0    4587.0    1.0
xt  4.970  2.222     1.0      9.0  ...    0.021    6671.0    5428.0    1.0
y   4.906  4.368     1.0     13.0  ...    0.104    1445.0    1307.0    1.0
yt  5.018  4.515     1.0     13.0  ...    0.084    1940.0    2067.0    1.0
z   8.002  1.265     6.0     10.0  ...    0.013    4551.0    4953.0    1.0
zt  8.020  1.286     6.0     10.0  ...    0.012    5338.0    5512.0    1.0

I purposely chose parameters I thought would benefit from transforms. In general, the effective sample size seems to improve slightly with the transform. The clear outlier is the DiscreteUniform, which now has a 100% acceptance rate (which will probably drive Metropolis tuning crazy :P)
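To illustrate why a DiscreteUniform could end up with a 100% acceptance rate, here is a minimal NumPy sketch of a random-walk Metropolis sampler. It assumes the transform wraps out-of-range integers back into [lower, upper]; that is an assumption made for illustration, not necessarily what this PR implements, and `wrap_to_interval` and `metropolis_discrete_uniform` are hypothetical names that do not exist in PyMC.

```python
import numpy as np


def wrap_to_interval(u, lower, upper):
    """Map any integer u into [lower, upper] by wrapping around.

    Hypothetical backward transform, assumed for this sketch only.
    """
    width = upper - lower + 1
    return lower + np.mod(u - lower, width)


def metropolis_discrete_uniform(lower, upper, draws, wrap, seed=0):
    """Random-walk Metropolis on a DiscreteUniform(lower, upper) target."""
    rng = np.random.default_rng(seed)
    u = lower  # state on the (possibly unconstrained) integer line
    samples = np.empty(draws, dtype=int)
    accepted = 0
    for i in range(draws):
        # Symmetric integer jump drawn from {-2, -1, 0, 1, 2}
        proposal = u + rng.integers(-2, 3)
        x_new = wrap_to_interval(proposal, lower, upper) if wrap else proposal
        # The logp is constant inside the support and -inf outside, so a
        # proposal is accepted exactly when it lands inside the support.
        if lower <= x_new <= upper:
            u = proposal
            accepted += 1
        samples[i] = wrap_to_interval(u, lower, upper) if wrap else u
    return samples, accepted / draws


samples_wrapped, rate_wrapped = metropolis_discrete_uniform(0, 10, 5_000, wrap=True)
samples_raw, rate_raw = metropolis_discrete_uniform(0, 10, 5_000, wrap=False)
print(f"acceptance with wrapping:    {rate_wrapped:.3f}")
print(f"acceptance without wrapping: {rate_raw:.3f}")
```

With wrapping, every proposal maps to a value with the same probability as the current one, so nothing is ever rejected. Without it, proposals that step past either bound are rejected, which is what makes the untransformed sampler less efficient near the edges.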

Checklist

Major / Breaking Changes

  • Add interval transform for discrete variables.
    • As with transformed continuous variables, the respective value variables are now named f"{var}_dinterval" when manually evaluating the model logp/dlogp.

Bugfixes / New features

  • ...

Docs / Maintenance

  • ...

@codecov

codecov bot commented Sep 6, 2022

Codecov Report

Merging #6102 (da64bf4) into main (13e7c88) will decrease coverage by 31.84%.
The diff coverage is 64.28%.

Additional details and impacted files


@@             Coverage Diff             @@
##             main    #6102       +/-   ##
===========================================
- Coverage   92.02%   60.18%   -31.84%     
===========================================
  Files          95       95               
  Lines       16262    16328       +66     
===========================================
- Hits        14965     9827     -5138     
- Misses       1297     6501     +5204     
Impacted Files                        Coverage Δ
pymc/distributions/distribution.py    63.59% <40.00%> (-33.48%) ⬇️
pymc/distributions/discrete.py        44.35% <58.82%> (-54.76%) ⬇️
pymc/distributions/transforms.py      68.98% <74.19%> (-30.39%) ⬇️

... and 66 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the discrete_transforms branch 2 times, most recently from ba811b8 to 6a02057 on September 6, 2022 12:03
@ricardoV94
Member Author

Another use case showed up in https://discourse.pymc.io/t/variable-lag-and-slicing/12003 where proposing invalid values for discrete variables can cause the model to error out structurally.

@ricardoV94
Member Author

Reviving again. This is useful to avoid gotchas where categorical variables have a "structural role" in the model, such as being used in indexing, in which case the samplers will just crash with IndexErrors and the like when an invalid value is proposed. See https://discourse.pymc.io/t/abc-function-works-independently-but-fails-in-pm-simulator/12470/13

I don't see a good reason not to have transforms by default, the same way we do for Continuous variables.
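The indexing failure can be reproduced outside of any sampler. The following is a minimal NumPy sketch under stated assumptions: `group_means` and `lookup` are made-up names and values, and the modulo wrap stands in for whatever mapping a discrete transform would apply.

```python
import numpy as np

# Made-up lookup table with one entry per category (valid indices 0, 1, 2)
group_means = np.array([-1.0, 0.0, 2.5])


def lookup(idx):
    """Use a discrete variable structurally, as an index into an array."""
    return group_means[idx]


current = 2
proposal = current + 1  # a random-walk step past the last valid index

try:
    lookup(proposal)
    crashed = False
except IndexError:
    # An untransformed proposal outside the support fails before any
    # log-probability can reject it.
    crashed = True

# Mapping the proposal back into the support keeps the index valid
wrapped = proposal % len(group_means)
value = lookup(wrapped)
print(crashed, wrapped, value)
```

Note that a negative proposal such as -1 would not raise here; NumPy would silently index from the end, which is arguably a worse failure mode than a crash.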

@ricardoV94
Member Author

Another similar issue with discrete variables and indexing operations: #7066
