Skip to content

Commit

Permalink
Convert step methods to v4
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 26, 2021
1 parent 36589db commit c9fb000
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 56 deletions.
63 changes: 39 additions & 24 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
from itertools import chain

import aesara.scalar as aes
import aesara.tensor as at
import numpy as np
import pymc3 as pm
from aesara.compile import optdb
from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.basic import Variable, graph_inputs, vars_between
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant
from pymc3.aesaraf import change_rv_size
from pymc3.distributions.logp import logpt
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name

from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.distributions import DiscreteMarkovChainFactory, SwitchingProcessFactory
from pymc3_hmm.utils import compute_trans_freqs

big: float = 1e20
small: float = 1.0 / big


def conform_rv_shape(rv_var, shape):
    """Resize a ``RandomVariable`` graph so that it conforms to ``shape``.

    The trailing support dimensions of the variable (``ndim_supp``) are
    dropped from ``shape``, because ``change_rv_size`` expects only the
    "size" portion of the full shape.

    Parameters
    ----------
    rv_var
        A ``RandomVariable`` output whose ``owner.op`` exposes ``ndim_supp``.
    shape
        The full target shape (size dims followed by support dims).

    Returns
    -------
    A new ``RandomVariable`` with the requested size.
    """
    support_ndim = rv_var.owner.op.ndim_supp
    new_size = shape[:-support_ndim] if support_ndim > 0 else shape
    return change_rv_size(rv_var, new_size)


def ffbs_step(
gamma_0: np.ndarray,
Gammas: np.ndarray,
Expand Down Expand Up @@ -133,9 +144,9 @@ def __init__(self, vars, values=None, model=None):
if len(vars) > 1:
raise ValueError("This sampler only takes one variable.")

(var,) = pm.inputvars(vars)
(var,) = vars

if not isinstance(var.distribution, DiscreteMarkovChain):
if not var.owner or not isinstance(var.owner.op, DiscreteMarkovChainFactory):
raise TypeError("This sampler only samples `DiscreteMarkovChain`s.")

model = pm.modelcontext(model)
Expand All @@ -145,18 +156,26 @@ def __init__(self, vars, values=None, model=None):
self.dependent_rvs = [
v
for v in model.basic_RVs
if v is not var and var in graph_inputs([v.logpt])
if v is not var and var in vars_between(list(graph_inputs([v])), [v])
]

if not self.dependent_rvs:
raise ValueError(f"Could not find variables that depend on {var}")

dep_comps_logp_stacked = []
for i, dependent_rv in enumerate(self.dependent_rvs):
if isinstance(dependent_rv.distribution, SwitchingProcess):
if dependent_rv.owner and isinstance(
dependent_rv.owner.op, SwitchingProcessFactory
):
comp_logps = []

# Get the log-likelihood sequences for each state in this
# `SwitchingProcess` observation distribution
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))
for comp_dist in dependent_rv.owner.inputs[
4 : -len(dependent_rv.owner.op.shared_inputs)
]:
new_comp_dist = conform_rv_shape(comp_dist, dependent_rv.shape)
comp_logps.append(logpt(new_comp_dist, dependent_rv))

comp_logp_stacked = at.stack(comp_logps)
else:
Expand All @@ -168,14 +187,15 @@ def __init__(self, vars, values=None, model=None):

comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0)

# XXX: This isn't correct.
M = var.owner.inputs[2].eval(model.test_point)
N = model.test_point[var.name].shape[-1]
Gammas_var = var.owner.inputs[2]
gamma_0_var = var.owner.inputs[3]
M = model.fn(Gammas_var.shape)(model.test_point).item()
N = model.test_point[var.name].shape[0]
self.alphas = np.empty((M, N), dtype=float)

self.log_lik_states = model.fn(comp_logp_stacked)
self.gamma_0_fn = model.fn(var.distribution.gamma_0)
self.Gammas_fn = model.fn(var.distribution.Gammas)
self.gamma_0_fn = model.fn(gamma_0_var)
self.Gammas_fn = model.fn(Gammas_var)

def step(self, point):
gamma_0 = self.gamma_0_fn(point)
Expand All @@ -190,9 +210,8 @@ def step(self, point):

@staticmethod
def competence(var):
distribution = getattr(var.distribution, "parent_dist", var.distribution)

if isinstance(distribution, DiscreteMarkovChain):
if var.owner and isinstance(var.owner.op, DiscreteMarkovChainFactory):
return Competence.IDEAL
# elif isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types):
# return Competence.COMPATIBLE
Expand Down Expand Up @@ -242,7 +261,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None):
if isinstance(model_vars, Variable):
model_vars = [model_vars]

model_vars = list(chain.from_iterable([pm.inputvars(v) for v in model_vars]))
model_vars = list(model_vars)

# TODO: Are the rows in this matrix our `dir_priors`?
dir_priors = []
Expand All @@ -256,7 +275,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None):
state_seqs = [
v
for v in model.vars + model.observed_RVs
if isinstance(v.distribution, DiscreteMarkovChain)
if (v.owner.op and isinstance(v.owner.op, DiscreteMarkovChainFactory))
and all(d in graph_inputs([v.distribution.Gammas]) for d in dir_priors)
]

Expand Down Expand Up @@ -429,11 +448,7 @@ def astep(self, point, inputs):
@staticmethod
def competence(var):

# TODO: Check that the dependent term is a conjugate type.

distribution = getattr(var.distribution, "parent_dist", var.distribution)

if isinstance(distribution, pm.Dirichlet):
if var.owner and isinstance(var.owner.op, pm.Dirichlet):
return Competence.COMPATIBLE

return Competence.INCOMPATIBLE
75 changes: 43 additions & 32 deletions tests/test_step_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,24 @@ def test_ffbs_step():
def test_FFBSStep():

with pm.Model(), pytest.raises(ValueError):
P_rv = np.eye(2)[None, ...]
S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10)
S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0], shape=10)
P_rv = np.broadcast_to(np.eye(2), (10, 2, 2))
S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0])
S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0])
PoissonZeroProcess(
"Y_t", 9.0, S_rv + S_2_rv, observed=np.random.poisson(9.0, size=10)
)
# Only one variable can be sampled by this step method
ffbs = FFBSStep([S_rv, S_2_rv])

with pm.Model(), pytest.raises(TypeError):
S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], shape=10)
S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], size=10)
PoissonZeroProcess("Y_t", 9.0, S_rv, observed=np.random.poisson(9.0, size=10))
# Only `DiscreteMarkovChains` can be sampled with this step method
ffbs = FFBSStep([S_rv])

with pm.Model(), pytest.raises(TypeError):
P_rv = np.eye(2)[None, ...]
S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10)
P_rv = np.broadcast_to(np.eye(2), (10, 2, 2))
S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0])
pm.Poisson("Y_t", S_rv, observed=np.random.poisson(9.0, size=10))
# Only `SwitchingProcess`es can be used as dependent variables
ffbs = FFBSStep([S_rv])
Expand All @@ -124,15 +124,18 @@ def test_FFBSStep():
y_test = poiszero_sim["Y_t"]

with pm.Model() as test_model:
p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)
p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

P_tt = at.stack([p_0_rv, p_1_rv])
P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt))

pi_0_tt = compute_steady_state(P_rv)
P_rv = pm.Deterministic(
"P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape))
)

pi_0_tt = compute_steady_state(P_tt)

S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0])
S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt)

PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test)

Expand Down Expand Up @@ -162,11 +165,13 @@ def test_FFBSStep_extreme():
p_1_rv = poiszero_sim["p_1"]

P_tt = at.stack([p_0_rv, p_1_rv])
P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt))
P_rv = pm.Deterministic(
"P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape))
)

pi_0_tt = poiszero_sim["pi_0"]

S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0])
S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt)
S_rv.tag.test_value = (y_test > 0).astype(int)

# This prior is very far from the true value...
Expand Down Expand Up @@ -213,7 +218,7 @@ def test_FFBSStep_extreme():
def test_TransMatConjugateStep():

with pm.Model() as test_model, pytest.raises(ValueError):
p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
transmat = TransMatConjugateStep(p_0_rv)

np.random.seed(2032)
Expand All @@ -222,15 +227,17 @@ def test_TransMatConjugateStep():
y_test = poiszero_sim["Y_t"]

with pm.Model() as test_model:
p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)
p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

P_tt = at.stack([p_0_rv, p_1_rv])
P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt))
P_rv = pm.Deterministic(
"P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape))
)

pi_0_tt = compute_steady_state(P_rv)
pi_0_tt = compute_steady_state(P_tt)

S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0])
S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt)

PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test)

Expand Down Expand Up @@ -265,8 +272,8 @@ def test_TransMatConjugateStep_subtensors():
# Confirm that Dirichlet/non-Dirichlet mixed rows can be
# parsed
with pm.Model():
d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)
d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

p_0_rv = at.as_tensor([0, 0, 1])
p_1_rv = at.zeros(3)
Expand All @@ -275,8 +282,10 @@ def test_TransMatConjugateStep_subtensors():
p_2_rv = at.set_subtensor(p_1_rv[[1, 2]], d_1_rv)

P_tt = at.stack([p_0_rv, p_1_rv, p_2_rv])
P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt))
DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], shape=(10,))
P_rv = pm.Deterministic(
"P_tt", at.broadcast_to(P_tt, (10,) + tuple(P_tt.shape))
)
DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0])

transmat = TransMatConjugateStep(P_rv)

Expand All @@ -289,8 +298,8 @@ def test_TransMatConjugateStep_subtensors():

# Same thing, just with some manipulations of the transition matrix
with pm.Model():
d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)
d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

p_0_rv = at.as_tensor([0, 0, 1])
p_1_rv = at.zeros(3)
Expand All @@ -301,8 +310,10 @@ def test_TransMatConjugateStep_subtensors():
P_tt = at.horizontal_stack(
p_0_rv[..., None], p_1_rv[..., None], p_2_rv[..., None]
)
P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt.T))
DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], shape=(10,))
P_rv = pm.Deterministic(
"P_tt", at.broadcast_to(P_tt.T, (10,) + tuple(P_tt.T.shape))
)
DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0])

transmat = TransMatConjugateStep(P_rv)

Expand All @@ -315,8 +326,8 @@ def test_TransMatConjugateStep_subtensors():

# Use an observed `DiscreteMarkovChain` and check the conjugate results
with pm.Model():
d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2)
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2)
d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

p_0_rv = at.as_tensor([0, 0, 1])
p_1_rv = at.zeros(3)
Expand All @@ -327,9 +338,9 @@ def test_TransMatConjugateStep_subtensors():
P_tt = at.horizontal_stack(
p_0_rv[..., None], p_1_rv[..., None], p_2_rv[..., None]
)
P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt.T))
DiscreteMarkovChain(
"S_t", P_rv, np.r_[1, 0, 0], shape=(4,), observed=np.r_[0, 1, 0, 2]
P_rv = pm.Deterministic(
"P_tt", at.broadcast_to(P_tt.T, (4,) + tuple(P_tt.T.shape))
)
DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], observed=np.r_[0, 1, 0, 2])

transmat = TransMatConjugateStep(P_rv)

0 comments on commit c9fb000

Please sign in to comment.