Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DiscreteMarkovChain distribution #100

Merged
merged 44 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
c42c94a
add DiscreteMarkovChainRV
jessegrabowski Dec 16, 2022
fd472bd
remove `validate_transition_matrix`
jessegrabowski Dec 17, 2022
37ef8e0
remove `x0` argument
jessegrabowski Dec 17, 2022
b4e15db
Add reshape logic to `rv_op` based on size and `init_dist`
jessegrabowski Dec 17, 2022
a3408ba
Remove moot TODO comments
jessegrabowski Dec 17, 2022
22bfe17
Update and re-run example notebook
jessegrabowski Dec 18, 2022
3aef66d
Update `pytensor` alias to `pt`
jessegrabowski Dec 18, 2022
c2d5fc6
Remove moment method
jessegrabowski Dec 18, 2022
d77337c
Wrap tests into test class
jessegrabowski Dec 18, 2022
82978f0
Add test for default initial distribution warning
jessegrabowski Dec 18, 2022
b797af3
Replace `.dimshuffle` with `pt.moveaxis` in `rv_op`
jessegrabowski Dec 18, 2022
d827586
Fix scan error
jessegrabowski Dec 18, 2022
0267801
Add test for `change_dist_size`
jessegrabowski Dec 18, 2022
b7794ed
Add code example to `DiscreteMarkovChain` docstring
jessegrabowski Dec 18, 2022
ab00e6c
Update pymc_experimental/distributions/timeseries.py
jessegrabowski Dec 18, 2022
bfb6b43
Remove shape argument from default `init_dist`
jessegrabowski Dec 18, 2022
ad3a878
Use shape parameter in example code
jessegrabowski Dec 18, 2022
7912d02
Remove steps adjustment from `__new__`
jessegrabowski Dec 18, 2022
85233f6
Remove `.squeeze()` from scan output
jessegrabowski Dec 18, 2022
b42184d
Remove dimension check on `init_dist`
jessegrabowski Dec 18, 2022
4773aa8
Fix batch size detection
jessegrabowski Dec 19, 2022
a2b79f2
Add support for n_lags > 1
jessegrabowski Dec 20, 2022
7fb1e56
Fix shape of markov_chain when P has a batch_size and n_lags == 1.
jessegrabowski Dec 20, 2022
847a2f9
Add test to recover P when n_lags > 1
jessegrabowski Dec 20, 2022
c52191e
Fix `logp` shape checking wrong dimension of `init_dist`
jessegrabowski Dec 20, 2022
ed2983f
Updates imports following pymc-devs/pymc#6441
jessegrabowski Apr 15, 2023
bbaa81e
Add a moment function to `DiscreteMarkovRV`
jessegrabowski Apr 15, 2023
9ac136b
Raise `NotImplementedError` if `init_dist` is not `pm.Categorical`
jessegrabowski Apr 15, 2023
2673b12
Update example notebook with some new plots
jessegrabowski Apr 15, 2023
b2df5a2
Fix a bug that broke `n_lags` > 1
jessegrabowski Apr 15, 2023
daffdc8
Rename test function to correctly match test
jessegrabowski Apr 16, 2023
c35bc31
rebase from main
jessegrabowski Apr 17, 2023
170e20b
Add `timeseries.DiscreteMarkovChain` to `api_reference.rst`
jessegrabowski Apr 17, 2023
eb23686
Remove check on `init_dist`
jessegrabowski Apr 17, 2023
c24a86d
Add `DiscreteMarkovChain` to `distributions.__all__`
jessegrabowski Apr 17, 2023
ee141ef
Change example notebook title, add subtitles, add plots comparing res…
jessegrabowski Apr 17, 2023
f138cac
Pass `init_dist` to all tests to avoid `UserWarning`
jessegrabowski Apr 17, 2023
66a3198
Fix flakey `test_moment_function` test
jessegrabowski Apr 17, 2023
10e2817
Fix latex in docstring
jessegrabowski Apr 17, 2023
7ef7ae6
Apply suggestions from code review
jessegrabowski Apr 17, 2023
0624d2d
Fix latex in docstring
jessegrabowski Apr 17, 2023
f07cdba
Merge branch 'discrete-markov' of https://github.com/jessegrabowski/p…
jessegrabowski Apr 17, 2023
994e3fb
Fix latex in docstring
jessegrabowski Apr 17, 2023
3c96dc8
Fix warning in docstring
jessegrabowski Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,096 changes: 1,096 additions & 0 deletions notebooks/discrete_markov_chain.ipynb

Large diffs are not rendered by default.

198 changes: 198 additions & 0 deletions pymc_experimental/distributions/timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import pymc as pm
import numpy as np
import pytensor.tensor as at
jessegrabowski marked this conversation as resolved.
Show resolved Hide resolved
import pytensor

from pymc.distributions.distribution import (
Distribution,
Discrete,
SymbolicRandomVariable,
_moment,
)

from pymc.logprob.abstract import _logprob
from pymc.pytensorf import intX
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.graph.basic import Node

from pymc.distributions.shape_utils import (
_change_dist_size,
get_support_shape_1d,
)


def validate_transition_matrix(P):
    """
    Check that ``P`` is a valid transition matrix.

    A valid matrix is square, and every row sums to 1. Raises ``ValueError``
    otherwise.
    """
    # TODO: Can this eval be avoided?
    n_rows, n_cols = P.shape.eval()
    if n_rows != n_cols:
        raise ValueError(f'P must be square, found shape ({n_rows}, {n_cols})')

    # Each row is a conditional distribution over next states, so it must sum to 1.
    if not at.allclose(P.sum(axis=1), 1.0).eval():
        raise ValueError('All rows of P must sum to 1.')


class DiscreteMarkovChainRV(SymbolicRandomVariable):
    """Symbolic RV for a discrete Markov chain; output 0 is the updated rng, output 1 the chain."""

    # Callers get the chain itself (output index 1) rather than the rng update.
    default_output = 1
    _print_name = ('DiscreteMC', '\\operatorname{DiscreteMC}')

    def update(self, node: Node):
        # Map the rng input to the rng output so draws advance the random state.
        # NOTE(review): assumes the shared rng is the last input of the node — TODO confirm
        # against how SymbolicRandomVariable orders inputs.
        # TODO: Do I need this?
        return {node.inputs[-1]: node.outputs[0]}


class DiscreteMarkovChain(Distribution):
    r"""
    A Discrete Markov Chain is a sequence of random variables

    .. math::

        \{x_t\}_{t=0}^T

    where the transition probability :math:`P(x_t \mid x_{t-1})` depends only on the
    state of the system at :math:`x_{t-1}`.

    Parameters
    ----------
    P: tensor
        Matrix of transition probabilities between states. Rows must sum to 1.
        One of P or logit_P must be provided.
    logit_P: tensor, Optional
        Matrix of transition logits. Converted to probabilities via Softmax activation.
        One of P or logit_P must be provided.
    steps: tensor
        Length of the markov chain
    x0: tensor or RandomVariable
        Initial state of the system. If tensor, treated as deterministic.
    """

    rv_type = DiscreteMarkovChainRV

    def __new__(cls, *args, steps, **kwargs):
        # TODO: Allow steps to be None and infer chain length from shape?
        # TODO: Dims breaks the RV

        # Subtract 1 step to account for x0 given, better match user expectation of
        # len(markov_chain) = steps
        steps -= 1

        steps = get_support_shape_1d(
            support_shape=steps,
            shape=None,
            dims=kwargs.get('dims', None),
            observed=kwargs.get('observed', None),
            support_shape_offset=1
        )

        return super().__new__(cls, *args, steps=steps, **kwargs)

    @classmethod
    def dist(cls, P=None, logit_P=None, steps=None, x0=None, **kwargs):
        # Resolve the chain length from either `steps` or an explicit `shape` kwarg.
        steps = get_support_shape_1d(
            support_shape=steps, shape=kwargs.get('shape', None), support_shape_offset=1
        )
        if steps is None:
            raise ValueError("Must specify steps or shape parameter")
        if P is None and logit_P is None:
            raise ValueError('Must specify P or logit_P parameter')
        if P is not None and logit_P is not None:
            raise ValueError('Must specify only one of either P or logit_P parameter')

        if logit_P is not None:
            # Row-wise softmax turns unconstrained logits into a stochastic matrix.
            P = pm.math.softmax(logit_P, axis=1)
        P = at.as_tensor_variable(P)
        validate_transition_matrix(P)

        # TODO: Can this eval be avoided?
        n_states = P.shape[0].eval()

        if not isinstance(x0, TensorVariable):
            # NOTE(review): x0 defaults to None, and `at.as_tensor_variable(None)` raises —
            # confirm whether a default initial state should be supplied instead.
            x0 = at.as_tensor_variable(x0).astype(intX)

            # TODO: Can this eval be avoided?
            # NOTE(review): `at.lt(x0, n_states - 1)` rejects x0 == n_states - 1, which is a
            # valid state index — looks like an off-by-one (should be `at.lt(x0, n_states)`).
            if not at.all(at.lt(x0, n_states - 1)).eval():
                raise ValueError('At least one initial state is larger than the number of states in the Markov Chain')

        elif not isinstance(x0.owner.op, Discrete):
            raise ValueError('x0 must be a discrete distribution')

        else:
            # x0 is a discrete RV: its last owner input is assumed to hold the category
            # probabilities — TODO confirm this holds for distributions other than Categorical.
            x0_probs = x0.owner.inputs[-1].eval()
            n_cats = 1 if x0_probs.ndim == 0 else len(x0_probs)

            if not n_cats <= n_states:
                raise ValueError(
                    'x0 has support over a range of values larger than the number of states in the Markov Chain')

        return super().dist([P, logit_P, steps, x0], **kwargs)

    @classmethod
    def rv_op(cls, P, logit_P, steps, x0, size=None):
        # NOTE(review): `logit_P` is accepted but never used here; `dist` already folded
        # it into P before calling rv_op.
        if size is not None:
            batch_size = size
        else:
            batch_size = at.broadcast_shape(x0)
        # NOTE(review): `batch_size` is computed but never used below — confirm intent.

        # Symbolic stand-ins for the inputs; these form the inner graph of the op.
        x0_ = x0.type()
        P_ = P.type()
        steps_ = steps.type()

        state_rng = pytensor.shared(np.random.default_rng())

        def transition(previous_state, transition_probs, old_rng):
            # Draw the next state from the row of P indexed by the current state,
            # threading the rng through scan via the update dict.
            p = transition_probs[previous_state]
            next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
            return next_state, {old_rng: next_rng}

        markov_chain, state_updates = pytensor.scan(transition,
                                                    non_sequences=[P_, state_rng],
                                                    outputs_info=[x0_],
                                                    n_steps=steps_,
                                                    strict=True)

        (state_next_rng,) = tuple(state_updates.values())

        # Prepend x0, then move the leading (time) axis produced by scan to the end.
        discrete_mc_ = at.concatenate([x0_[None, ...], markov_chain], axis=0).dimshuffle(
            tuple(range(1, markov_chain.ndim)) + (0,)
        ).squeeze()

        discrete_mc_op = DiscreteMarkovChainRV(
            inputs=[P_, x0_, steps_],
            outputs=[state_next_rng, discrete_mc_],
            ndim_supp=1,
        )

        discrete_mc = discrete_mc_op(P, x0, steps)
        return discrete_mc


@_change_dist_size.register(DiscreteMarkovChainRV)
def change_mc_size(op, dist, new_size, expand=False):
    """
    Rebuild a DiscreteMarkovChain RV with a new batch size.

    Parameters
    ----------
    op: DiscreteMarkovChainRV
        The op being resized (dispatch key; unused in the body).
    dist: TensorVariable
        The existing chain RV whose inputs are reused.
    new_size: tuple
        The requested batch size.
    expand: bool
        If True, prepend ``new_size`` to the existing batch dimensions instead
        of replacing them.
    """
    if expand:
        old_size = dist.shape[:-1]
        new_size = tuple(new_size) + tuple(old_size)

    # The op was built with inputs=[P_, x0_, steps_] (see DiscreteMarkovChain.rv_op),
    # while rv_op's signature is (P, logit_P, steps, x0) — so the inputs cannot be
    # splatted positionally. logit_P is unused inside rv_op, so None is safe.
    # Also note rv_op is defined on the Distribution class, not on DiscreteMarkovChainRV.
    P, x0, steps = dist.owner.inputs[:3]
    return DiscreteMarkovChain.rv_op(P, None, steps, x0, size=new_size)


@_logprob.register(DiscreteMarkovChainRV)
def discrete_mc_logp(
    op, values, P, x0, steps, state_rng, **kwargs
):
    """Log-probability of an observed chain: sum of log transition probabilities."""

    (value,) = values
    # GARCH11 swaps the time axis to the front, not sure why this is necessary
    # Index P by each consecutive (from-state, to-state) pair and sum over time (last axis).
    mc_logprob = at.log(P[value[..., :-1], value[..., 1:]]).sum(axis=-1)
    # NOTE(review): the probability of the initial state x0 is not included here —
    # confirm whether it is accounted for elsewhere or deliberately omitted.

    return mc_logprob

@_moment.register(DiscreteMarkovChainRV)
def discrete_markov_chain_moment(
    op, rv, P, x0, steps, state_rng
):
    """Placeholder moment: a chain of zeros (state 0) with the RV's shape and dtype."""
    # TODO: What is the mean of the chain?
    return at.zeros_like(rv)
Show resolved Hide resolved