Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DiscreteMarkovChain distribution #100

Merged
merged 44 commits into from
Apr 20, 2023

Conversation

jessegrabowski
Copy link
Member

Add a distribution for discrete-state Markov chains. Support for deterministic and random initial conditions. Motivated by two recent threads on the discourse asking about this type of model: here and here. This work is derivative of the work presented by @junpenglao and @ricardoV94 in those threads.

Still a lot of work to do so I'm marking this as a draft. In no particular order:

  1. I use .eval() methods in several places to validate inputs, this strikes me as bad but I didn't know what else to do.
  2. Using dims currently breaks the model
  3. The time series dimension is currently specified by the steps argument, but it seems more natural if it were instead set by the size or dims argument (this is related to why dims breaks the model i think)
  4. pm.sample automatically assigns the markov chain RV to pm.Metropolis, don't know how to direct it to BinaryMetropolis or CategoricalMetropolis, depending on the size of the state space.
  5. There's some janky stuff with the steps argument. I internally subtract 1 to account for the fact that x0 will be appended to the scan -- otherwise the resulting chain is length steps + 1, which I don't think will match the user expectation. It gives the right answer but it doesn't feel very clean.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks amazing 🤩 Left some comments below that hopefully you find helpful.

pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

For the sampler, I think there's a way to make it so your distribution goes with Categorical sampler by default.

@ricardoV94 ricardoV94 added the enhancements New feature or request label Dec 17, 2022
validate `P` in `logp` via `check_parameters`
Create `test_discrete_markov_chain.py`
remove shape checks on x0
add `init_dist` argument
add support for `dtype` kwarg
`dims` argument now works

Add tests associated with `dims`
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great progress. Some comments/requests below.

Would be so cool to implement logp marginalization of these variables in a future PR! Should be as simple as putting the right DiscreteMarginalizedRV inside Scan? Should come out the same as the forward algorithm.

https://github.com/pymc-devs/pymc-experimental/blob/main/pymc_experimental/marginal_model.py

Another future improvement would be multiple lag dependencies?

Anyway I am getting carried away, this is already very exciting!

pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Show resolved Hide resolved
@ricardoV94
Copy link
Member

ricardoV94 commented Dec 18, 2022

For the sampler, I think there's a way to make it so your distribution goes with Categorical sampler by default.

Speaking of bad APIs...

Here is a solution? Like the test added here you can override the default list of step methods: pymc-devs/pymc@2dd4c8c

You could add a subclass of the categorical metropolis sampler that assigns ideal competence to this variable.

We should definitely refactor that approach, perhaps with dispatching, but for now that is the step assignment API for external libraries...

@jessegrabowski
Copy link
Member Author

For the sampler, I think there's a way to make it so your distribution goes with Categorical sampler by default.

Speaking of bad APIs...

Here is a solution? Like the test added here you can override the default list of step methods: pymc-devs/pymc@2dd4c8c

You could add a subclass of the categorical metropolis sampler that assigns ideal competence to this variable.

We should definitely refactor that approach, perhaps with dispatching, but for now that is the step assignment API for external libraries...

I was thinking if I could subclass from Categorical then it would be automatically assigned correctly, but I got some errors doing this.

@jessegrabowski
Copy link
Member Author

re: multiple lag dependencies, I think that would be nice as well and should be easy enough to add.
re: Marginalization, I was really hoping this would integrate with that PR as well. The HMM model in particular is quite difficult to sample (see the example notebook as well as discussion here), and I wonder if automatic marginalization could help the sampler along.

@ricardoV94
Copy link
Member

Marginalizing HMMs usually sample much better in my experience. We would need to increase the marginalization to support it, though. Right now it only works for vanilla discrete variables, not scans or this specific wrapped scan op either.

Remove unnecessary type handling on distributions.
Add `initval='prior'` as a default argument to `__new__`
@jessegrabowski
Copy link
Member Author

jessegrabowski commented Dec 18, 2022

One of the scan re-writes is throw an error now, but only if I compile the model outside of a model context (e.g. with pm.draw). Curious if you know what might be causing it?

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: save_mem_new_scan
ERROR (pytensor.graph.rewriting.basic): node: for{cpu,scan_fn}(TensorConstant{(1,) of 9}, IncSubtensor{Set;:int64:}.0, RandomGeneratorSharedVariable(<Generator(PCG64) at 0x151E8EE40>), Softmax{axis=1}.0)
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/jessegrabowski/mambaforge/envs/pymc_dev/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1933, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/Users/jessegrabowski/mambaforge/envs/pymc_dev/lib/python3.9/site-packages/pytensor/graph/rewriting/basic.py", line 1092, in transform
    return self.fn(fgraph, node)
  File "/Users/jessegrabowski/mambaforge/envs/pymc_dev/lib/python3.9/site-packages/pytensor/scan/rewriting.py", line 1628, in save_mem_new_scan
    subtens = Subtensor(nw_slice)
  File "/Users/jessegrabowski/mambaforge/envs/pymc_dev/lib/python3.9/site-packages/pytensor/tensor/subtensor.py", line 692, in __init__
    self.idx_list = tuple(map(index_vars_to_types, idx_list))
  File "/Users/jessegrabowski/mambaforge/envs/pymc_dev/lib/python3.9/site-packages/pytensor/tensor/subtensor.py", line 592, in index_vars_to_types
    slice_a = index_vars_to_types(a, False)
  File "/Users/jessegrabowski/mambaforge/envs/pymc_dev/lib/python3.9/site-packages/pytensor/tensor/subtensor.py", line 613, in index_vars_to_types
    raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
pytensor.tensor.exceptions.AdvancedIndexingError: Invalid index type or slice for Subtensor

EDIT:: This was caused by ndims = 1 in steps = at.as_tensor_variable(intX(steps), ndims=1). I copied that from the AR distribution, not sure why it is causing a problem here.

@jessegrabowski jessegrabowski marked this pull request as ready for review December 18, 2022 20:59
@jessegrabowski
Copy link
Member Author

jessegrabowski commented Dec 18, 2022

I marked it as ready for review, although the problem of sampler assignment still isn't solved. If it was OK'd for the main PyMC code base, I'd just add the distribution to the list of competencies for CategoricalGibbsMetropolis I guess.

Another small issue is that pm.model_to_graphviz doesn't seem to recognize dependency between the hidden state chain and a sequence of state means when I did the HMM example.

Also still need to add support for multiple lags.

pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
pymc_experimental/distributions/timeseries.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

Also still need to add support for multiple lags.

That need not be a blocker unless you want it to

@jessegrabowski
Copy link
Member Author

I think I addressed pretty much everything. I changed an import in marginal_model.py because I did git rebase on my fork's main but not the branch i set up for this PR; I guess that was pretty stupid. I needed to change the imports for logp. I hope it's not too much of a hassle to fix it.

@ricardoV94
Copy link
Member

ricardoV94 commented Apr 17, 2023

@jessegrabowski I am happy code and test-wise. Do you need help fixing the conflicts and rebasing or are you up to it? We can merge aftewards

@twiecki
Copy link
Member

twiecki commented Apr 17, 2023

The title of the NB needs to be updated.

@jessegrabowski
Copy link
Member Author

The NB was a bit of an afterthought, just to show everyone the distribution works as intended. I could doll it up a bit and make a separate PR into pymc-examples?

@twiecki
Copy link
Member

twiecki commented Apr 17, 2023

@jessegrabowski I updated my comment after I saw that this was pymc-experimental, not pymc. So I think a NB here is a good idea, and you can later doll it up for pymc-examples. But the title needs to be fixed.

docs/api_reference.rst Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 merged commit 666ac8c into pymc-devs:main Apr 20, 2023
@ricardoV94 ricardoV94 changed the title add DiscreteMarkovChainRV Add DiscreteMarkovChain distribution Apr 20, 2023
@jessegrabowski jessegrabowski deleted the discrete-markov branch September 17, 2023 16:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants