From c9fb00012abe9a1491e8f2f86df131ef1b92899c Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 25 May 2021 22:37:38 -0500 Subject: [PATCH] Convert step methods to v4 --- pymc3_hmm/step_methods.py | 63 ++++++++++++++++++++------------ tests/test_step_methods.py | 75 ++++++++++++++++++++++---------------- 2 files changed, 82 insertions(+), 56 deletions(-) diff --git a/pymc3_hmm/step_methods.py b/pymc3_hmm/step_methods.py index f1d2517..3b2d69a 100644 --- a/pymc3_hmm/step_methods.py +++ b/pymc3_hmm/step_methods.py @@ -1,11 +1,9 @@ -from itertools import chain - import aesara.scalar as aes import aesara.tensor as at import numpy as np import pymc3 as pm from aesara.compile import optdb -from aesara.graph.basic import Variable, graph_inputs +from aesara.graph.basic import Variable, graph_inputs, vars_between from aesara.graph.fg import FunctionGraph from aesara.graph.op import get_test_value as test_value from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer @@ -13,16 +11,29 @@ from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.subtensor import AdvancedIncSubtensor1 from aesara.tensor.var import TensorConstant +from pymc3.aesaraf import change_rv_size +from pymc3.distributions.logp import logpt from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence from pymc3.util import get_untransformed_name -from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess +from pymc3_hmm.distributions import DiscreteMarkovChainFactory, SwitchingProcessFactory from pymc3_hmm.utils import compute_trans_freqs big: float = 1e20 small: float = 1.0 / big +def conform_rv_shape(rv_var, shape): + ndim_supp = rv_var.owner.op.ndim_supp + if ndim_supp > 0: + new_size = shape[:-ndim_supp] + else: + new_size = shape + + rv_var = change_rv_size(rv_var, new_size) + return rv_var + + def ffbs_step( gamma_0: np.ndarray, Gammas: np.ndarray, @@ -133,9 +144,9 @@ def __init__(self, vars, values=None, model=None): if len(vars) > 1: raise ValueError("This sampler only takes one variable.") - (var,) = pm.inputvars(vars) + (var,) = vars - if not isinstance(var.distribution, DiscreteMarkovChain): + if not var.owner or not isinstance(var.owner.op, DiscreteMarkovChainFactory): raise TypeError("This sampler only samples `DiscreteMarkovChain`s.") model = pm.modelcontext(model) @@ -145,18 +156,26 @@ def __init__(self, vars, values=None, model=None): self.dependent_rvs = [ v for v in model.basic_RVs - if v is not var and var in graph_inputs([v.logpt]) + if v is not var and var in vars_between(list(graph_inputs([v])), [v]) ] + if not self.dependent_rvs: + raise ValueError(f"Could not find variables that depend on {var}") + dep_comps_logp_stacked = [] for i, dependent_rv in enumerate(self.dependent_rvs): - if isinstance(dependent_rv.distribution, SwitchingProcess): + if dependent_rv.owner and isinstance( + dependent_rv.owner.op, SwitchingProcessFactory + ): comp_logps = [] # Get the log-likelihoood sequences for each state in this # `SwitchingProcess` observations distribution - for comp_dist in dependent_rv.distribution.comp_dists: - comp_logps.append(comp_dist.logp(dependent_rv)) + for comp_dist in dependent_rv.owner.inputs[ + 4 : -len(dependent_rv.owner.op.shared_inputs) + ]: + new_comp_dist = conform_rv_shape(comp_dist, dependent_rv.shape) + comp_logps.append(logpt(new_comp_dist, dependent_rv)) comp_logp_stacked = at.stack(comp_logps) else: @@ -168,14 +187,15 @@ def __init__(self, vars, values=None, model=None): comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0) - # XXX: This isn't correct. - M = var.owner.inputs[2].eval(model.test_point) - N = model.test_point[var.name].shape[-1] + Gammas_var = var.owner.inputs[2] + gamma_0_var = var.owner.inputs[3] + M = model.fn(Gammas_var.shape)(model.test_point).item() + N = model.test_point[var.name].shape[0] self.alphas = np.empty((M, N), dtype=float) self.log_lik_states = model.fn(comp_logp_stacked) - self.gamma_0_fn = model.fn(var.distribution.gamma_0) - self.Gammas_fn = model.fn(var.distribution.Gammas) + self.gamma_0_fn = model.fn(gamma_0_var) + self.Gammas_fn = model.fn(Gammas_var) def step(self, point): gamma_0 = self.gamma_0_fn(point) @@ -190,9 +210,8 @@ def step(self, point): @staticmethod def competence(var): - distribution = getattr(var.distribution, "parent_dist", var.distribution) - if isinstance(distribution, DiscreteMarkovChain): + if var.owner and isinstance(var.owner.op, DiscreteMarkovChainFactory): return Competence.IDEAL # elif isinstance(distribution, pm.Bernoulli) or (var.dtype in pm.bool_types): # return Competence.COMPATIBLE @@ -242,7 +261,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None): if isinstance(model_vars, Variable): model_vars = [model_vars] - model_vars = list(chain.from_iterable([pm.inputvars(v) for v in model_vars])) + model_vars = list(model_vars) # TODO: Are the rows in this matrix our `dir_priors`? dir_priors = [] @@ -256,7 +275,7 @@ def __init__(self, model_vars, values=None, model=None, rng=None): state_seqs = [ v for v in model.vars + model.observed_RVs - if isinstance(v.distribution, DiscreteMarkovChain) + if (v.owner.op and isinstance(v.owner.op, DiscreteMarkovChainFactory)) and all(d in graph_inputs([v.distribution.Gammas]) for d in dir_priors) ] @@ -429,11 +448,7 @@ def astep(self, point, inputs): @staticmethod def competence(var): - # TODO: Check that the dependent term is a conjugate type. - - distribution = getattr(var.distribution, "parent_dist", var.distribution) - - if isinstance(distribution, pm.Dirichlet): + if var.owner and isinstance(var.owner.op, pm.Dirichlet): return Competence.COMPATIBLE return Competence.INCOMPATIBLE diff --git a/tests/test_step_methods.py b/tests/test_step_methods.py index cf84af4..83380fe 100644 --- a/tests/test_step_methods.py +++ b/tests/test_step_methods.py @@ -96,9 +96,9 @@ def test_ffbs_step(): def test_FFBSStep(): with pm.Model(), pytest.raises(ValueError): - P_rv = np.eye(2)[None, ...] - S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10) - S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0], shape=10) + P_rv = np.broadcast_to(np.eye(2), (10, 2, 2)) + S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0]) + S_2_rv = DiscreteMarkovChain("S_2_t", P_rv, np.r_[0.0, 1.0]) PoissonZeroProcess( "Y_t", 9.0, S_rv + S_2_rv, observed=np.random.poisson(9.0, size=10) ) @@ -106,14 +106,14 @@ def test_FFBSStep(): ffbs = FFBSStep([S_rv, S_2_rv]) with pm.Model(), pytest.raises(TypeError): - S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], shape=10) + S_rv = pm.Categorical("S_t", np.r_[1.0, 0.0], size=10) PoissonZeroProcess("Y_t", 9.0, S_rv, observed=np.random.poisson(9.0, size=10)) # Only `DiscreteMarkovChains` can be sampled with this step method ffbs = FFBSStep([S_rv]) with pm.Model(), pytest.raises(TypeError): - P_rv = np.eye(2)[None, ...] - S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0], shape=10) + P_rv = np.broadcast_to(np.eye(2), (10, 2, 2)) + S_rv = DiscreteMarkovChain("S_t", P_rv, np.r_[1.0, 0.0]) pm.Poisson("Y_t", S_rv, observed=np.random.poisson(9.0, size=10)) # Only `SwitchingProcess`es can used as dependent variables ffbs = FFBSStep([S_rv]) @@ -124,15 +124,18 @@ def test_FFBSStep(): y_test = poiszero_sim["Y_t"] with pm.Model() as test_model: - p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) P_tt = at.stack([p_0_rv, p_1_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) - pi_0_tt = compute_steady_state(P_rv) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape)) + ) + + pi_0_tt = compute_steady_state(P_tt) - S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0]) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt) PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test) @@ -162,11 +165,13 @@ def test_FFBSStep_extreme(): p_1_rv = poiszero_sim["p_1"] P_tt = at.stack([p_0_rv, p_1_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape)) + ) pi_0_tt = poiszero_sim["pi_0"] - S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0]) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt) S_rv.tag.test_value = (y_test > 0).astype(int) # This prior is very far from the true value... @@ -213,7 +218,7 @@ def test_FFBSStep_extreme(): def test_TransMatConjugateStep(): with pm.Model() as test_model, pytest.raises(ValueError): - p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) + p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) transmat = TransMatConjugateStep(p_0_rv) np.random.seed(2032) @@ -222,15 +227,17 @@ def test_TransMatConjugateStep(): y_test = poiszero_sim["Y_t"] with pm.Model() as test_model: - p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + p_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + p_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) P_tt = at.stack([p_0_rv, p_1_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (y_test.shape[0],) + tuple(P_tt.shape)) + ) - pi_0_tt = compute_steady_state(P_rv) + pi_0_tt = compute_steady_state(P_tt) - S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt, shape=y_test.shape[0]) + S_rv = DiscreteMarkovChain("S_t", P_rv, pi_0_tt) PoissonZeroProcess("Y_t", 9.0, S_rv, observed=y_test) @@ -265,8 +272,8 @@ def test_TransMatConjugateStep_subtensors(): # Confirm that Dirichlet/non-Dirichlet mixed rows can be # parsed with pm.Model(): - d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) p_0_rv = at.as_tensor([0, 0, 1]) p_1_rv = at.zeros(3) @@ -275,8 +282,10 @@ def test_TransMatConjugateStep_subtensors(): p_2_rv = at.set_subtensor(p_1_rv[[1, 2]], d_1_rv) P_tt = at.stack([p_0_rv, p_1_rv, p_2_rv]) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt)) - DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], shape=(10,)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt, (10,) + tuple(P_tt.shape)) + ) + DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0]) transmat = TransMatConjugateStep(P_rv) @@ -289,8 +298,8 @@ def test_TransMatConjugateStep_subtensors(): # Same thing, just with some manipulations of the transition matrix with pm.Model(): - d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) p_0_rv = at.as_tensor([0, 0, 1]) p_1_rv = at.zeros(3) @@ -301,8 +310,10 @@ def test_TransMatConjugateStep_subtensors(): P_tt = at.horizontal_stack( p_0_rv[..., None], p_1_rv[..., None], p_2_rv[..., None] ) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt.T)) - DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], shape=(10,)) + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt.T, (10,) + tuple(P_tt.T.shape)) + ) + DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0]) transmat = TransMatConjugateStep(P_rv) @@ -315,8 +326,8 @@ def test_TransMatConjugateStep_subtensors(): # Use an observed `DiscreteMarkovChain` and check the conjugate results with pm.Model(): - d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1], shape=2) - d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1], shape=2) + d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1]) + d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1]) p_0_rv = at.as_tensor([0, 0, 1]) p_1_rv = at.zeros(3) @@ -327,9 +338,9 @@ def test_TransMatConjugateStep_subtensors(): P_tt = at.horizontal_stack( p_0_rv[..., None], p_1_rv[..., None], p_2_rv[..., None] ) - P_rv = pm.Deterministic("P_tt", at.shape_padleft(P_tt.T)) - DiscreteMarkovChain( - "S_t", P_rv, np.r_[1, 0, 0], shape=(4,), observed=np.r_[0, 1, 0, 2] + P_rv = pm.Deterministic( + "P_tt", at.broadcast_to(P_tt.T, (4,) + tuple(P_tt.T.shape)) ) + DiscreteMarkovChain("S_t", P_rv, np.r_[1, 0, 0], observed=np.r_[0, 1, 0, 2]) transmat = TransMatConjugateStep(P_rv)