Change theano imports to aesara

fanshi118 authored and brandonwillard committed May 11, 2021
1 parent 06ce40e commit e9b3b90
Showing 8 changed files with 190 additions and 124 deletions.
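The commit swaps hard `theano` imports for a try/except fallback, so the package works whether an environment still provides Theano-PyMC or the renamed Aesara library. A minimal sketch of the compatibility pattern applied throughout this diff (module paths are taken from the changes below; the final usage line is only illustrative):

```python
# Prefer the legacy Theano modules; fall back to Aesara if Theano is absent.
# The alias `at` lets the rest of the code base address one tensor API.
try:
    import theano.tensor as at
    from theano.graph.op import get_test_value
except ImportError:  # pragma: no cover
    import aesara.tensor as at
    from aesara.graph.op import get_test_value

# Downstream code is then written once against the alias:
x = at.as_tensor_variable([0.0, 1.0, 2.0])
```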
74 changes: 42 additions & 32 deletions pymc3_hmm/distributions.py
@@ -1,15 +1,25 @@
import warnings

import numpy as np

+try:
+    import theano as aesara
+    import theano.tensor as at
+    from theano.graph.op import get_test_value
+    from theano.graph.utils import TestValueError
+    from theano.scalar import upcast
+    from theano.tensor.extra_ops import broadcast_to as at_broadcast_to
+except ImportError:  # pragma: no cover
+    import aesara
+    import aesara.tensor as at
+    from aesara.graph.op import get_test_value
+    from aesara.graph.utils import TestValueError
+    from aesara.scalar import upcast
+    from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
+
import pymc3 as pm
-import theano
-import theano.tensor as tt
from pymc3.distributions.distribution import _DrawValuesContext, draw_values
from pymc3.distributions.mixture import _conversion_map, all_discrete
-from theano.graph.op import get_test_value
-from theano.graph.utils import TestValueError
-from theano.scalar import upcast
-from theano.tensor.extra_ops import broadcast_to as tt_broadcast_to

from pymc3_hmm.utils import tt_broadcast_arrays, tt_expand_dims, vsearchsorted

@@ -73,7 +83,7 @@ def distribution_subset_args(dist, shape, idx, point=None):
else:
x = point[param]

-bcast_res = tt_broadcast_to(x, shape)
+bcast_res = at_broadcast_to(x, shape)

res.append(bcast_res[idx])

@@ -121,7 +131,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
equal to the size of `comp_dists`.
"""
-self.states = tt.as_tensor_variable(pm.intX(states))
+self.states = at.as_tensor_variable(pm.intX(states))

if len(comp_dists) > 31:
warnings.warn(
@@ -147,7 +157,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
bcast_means = tt_broadcast_arrays(
*([self.states] + [d.mean.astype(dtype) for d in self.comp_dists])
)
-self.mean = tt.choose(self.states, bcast_means[1:])
+self.mean = at.choose(self.states, bcast_means[1:])

if "mean" not in defaults:
defaults.append("mean")
@@ -159,7 +169,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
bcast_modes = tt_broadcast_arrays(
*([self.states] + [d.mode.astype(dtype) for d in self.comp_dists])
)
-self.mode = tt.choose(self.states, bcast_modes[1:])
+self.mode = at.choose(self.states, bcast_modes[1:])

if "mode" not in defaults:
defaults.append("mode")
@@ -172,15 +182,15 @@ def __init__(self, comp_dists, states, *args, **kwargs):
def logp(self, obs):
"""Return the scalar Theano log-likelihood at a point."""

-obs_tt = tt.as_tensor_variable(obs)
+obs_tt = at.as_tensor_variable(obs)

-logp_val = tt.alloc(-np.inf, *obs.shape)
+logp_val = at.alloc(-np.inf, *obs.shape)

for i, dist in enumerate(self.comp_dists):
-i_mask = tt.eq(self.states, i)
+i_mask = at.eq(self.states, i)
obs_i = obs_tt[i_mask]
subset_dist = dist.dist(*distribution_subset_args(dist, obs.shape, i_mask))
-logp_val = tt.set_subtensor(logp_val[i_mask], subset_dist.logp(obs_i))
+logp_val = at.set_subtensor(logp_val[i_mask], subset_dist.logp(obs_i))

return logp_val
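For intuition, the masking above assigns each observation the log-density of whichever mixture component its state selects. A rough NumPy analogue of that logic (the `comp_logps` callables are made-up stand-ins, not part of this commit):

```python
import numpy as np

states = np.array([0, 1, 1, 0])            # active component at each point
obs = np.array([0.0, 3.0, 2.0, 0.0])       # observed series

# One toy log-density per component: a point mass at zero and a dummy density.
comp_logps = [
    lambda x: np.where(x == 0.0, 0.0, -np.inf),
    lambda x: -x,
]

logp_val = np.full(obs.shape, -np.inf)
for i, logp_fn in enumerate(comp_logps):
    i_mask = states == i                     # mirrors `at.eq(self.states, i)`
    logp_val[i_mask] = logp_fn(obs[i_mask])  # mirrors `at.set_subtensor`
```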

@@ -265,8 +275,8 @@ def __init__(self, mu=None, states=None, **kwargs):
A vector of integer 0-1 states that indicate which component of
the mixture is active at each point/time.
"""
-self.mu = tt.as_tensor_variable(pm.floatX(mu))
-self.states = tt.as_tensor_variable(states)
+self.mu = at.as_tensor_variable(pm.floatX(mu))
+self.states = at.as_tensor_variable(states)

super().__init__([pm.Constant.dist(0), pm.Poisson.dist(mu)], states, **kwargs)
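The two components above give the zero-inflated process its simple generative reading: state 0 emits an exact zero, state 1 emits a Poisson count. A NumPy sketch of draws under that rule (toy values, not from the commit):

```python
import numpy as np

rng = np.random.default_rng(0)
mu = 5.0
states = rng.integers(0, 2, size=10)   # 0/1 indicator of the active component

# State 0 -> constant zero, state 1 -> Poisson(mu), matching the two components.
draws = np.where(states == 0, 0, rng.poisson(mu, size=states.shape))
```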

@@ -298,15 +308,15 @@ def __init__(self, Gammas, gamma_0, shape, **kwargs):
Shape of the state sequence. The last dimension is `N`, i.e. the
length of the state sequence(s).
"""
-self.gamma_0 = tt.as_tensor_variable(pm.floatX(gamma_0))
+self.gamma_0 = at.as_tensor_variable(pm.floatX(gamma_0))

assert Gammas.ndim >= 3

-self.Gammas = tt.as_tensor_variable(pm.floatX(Gammas))
+self.Gammas = at.as_tensor_variable(pm.floatX(Gammas))

shape = np.atleast_1d(shape)

-dtype = _conversion_map[theano.config.floatX]
+dtype = _conversion_map[aesara.config.floatX]
self.mode = np.zeros(tuple(shape), dtype=dtype)

super().__init__(shape=shape, **kwargs)
@@ -330,44 +340,44 @@ def logp(self, states):
""" # noqa: E501

-Gammas = tt.shape_padleft(self.Gammas, states.ndim - (self.Gammas.ndim - 2))
+Gammas = at.shape_padleft(self.Gammas, states.ndim - (self.Gammas.ndim - 2))

# Multiply the initial state probabilities by the first transition
# matrix to get the marginal probability for state `S_1`.
# The integral that produces the marginal is essentially
# `gamma_0.dot(Gammas[0])`
Gamma_1 = Gammas[..., 0:1, :, :]
gamma_0 = tt_expand_dims(self.gamma_0, (-3, -1))
-P_S_1 = tt.sum(gamma_0 * Gamma_1, axis=-2)
+P_S_1 = at.sum(gamma_0 * Gamma_1, axis=-2)

# The `tt.switch` calls allow us to broadcast the indexing operation when
# the replication dimensions of `states` and `Gammas` don't match
# (e.g. `states.shape[0] > Gammas.shape[0]`)
S_1_slices = tuple(
slice(
-tt.switch(tt.eq(P_S_1.shape[i], 1), 0, 0),
-tt.switch(tt.eq(P_S_1.shape[i], 1), 1, d),
+at.switch(at.eq(P_S_1.shape[i], 1), 0, 0),
+at.switch(at.eq(P_S_1.shape[i], 1), 1, d),
)
for i, d in enumerate(states.shape)
)
-S_1_slices = (tuple(tt.ogrid[S_1_slices]) if S_1_slices else tuple()) + (
+S_1_slices = (tuple(at.ogrid[S_1_slices]) if S_1_slices else tuple()) + (
states[..., 0:1],
)
-logp_S_1 = tt.log(P_S_1[S_1_slices]).sum(axis=-1)
+logp_S_1 = at.log(P_S_1[S_1_slices]).sum(axis=-1)

# These are slices for the extra dimensions--including the state
# sequence dimension (e.g. "time")--along which we need to index
# the transition matrix rows using the "observed" `states`.
trans_slices = tuple(
slice(
-tt.switch(
-tt.eq(Gammas.shape[i], 1), 0, 1 if i == states.ndim - 1 else 0
+at.switch(
+at.eq(Gammas.shape[i], 1), 0, 1 if i == states.ndim - 1 else 0
),
-tt.switch(tt.eq(Gammas.shape[i], 1), 1, d),
+at.switch(at.eq(Gammas.shape[i], 1), 1, d),
)
for i, d in enumerate(states.shape)
)
-trans_slices = (tuple(tt.ogrid[trans_slices]) if trans_slices else tuple()) + (
+trans_slices = (tuple(at.ogrid[trans_slices]) if trans_slices else tuple()) + (
states[..., :-1],
)

Expand All @@ -376,12 +386,12 @@ def logp(self, states):
P_S_2T = Gammas[trans_slices]

obs_slices = tuple(slice(None, d) for d in P_S_2T.shape[:-1])
-obs_slices = (tuple(tt.ogrid[obs_slices]) if obs_slices else tuple()) + (
+obs_slices = (tuple(at.ogrid[obs_slices]) if obs_slices else tuple()) + (
states[..., 1:],
)
-logp_S_1T = tt.log(P_S_2T[obs_slices])
+logp_S_1T = at.log(P_S_2T[obs_slices])

-res = logp_S_1 + tt.sum(logp_S_1T, axis=-1)
+res = logp_S_1 + at.sum(logp_S_1T, axis=-1)
res.name = "DiscreteMarkovChain_logp"

return res
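As the comments in `logp` note, the first term marginalizes the initial distribution through the first transition matrix (essentially `gamma_0.dot(Gammas[0])`), and the remaining terms pick transition probabilities along the observed state path. A small NumPy check of that decomposition for a single, time-homogeneous chain (toy numbers, not from the commit):

```python
import numpy as np

gamma_0 = np.array([0.6, 0.4])          # initial state probabilities
Gamma = np.array([[0.9, 0.1],
                  [0.2, 0.8]])          # one shared transition matrix
states = np.array([0, 0, 1, 1])         # an "observed" state sequence

# Marginal of the first state: gamma_0 @ Gamma, indexed at S_1.
logp = np.log(gamma_0 @ Gamma)[states[0]]

# Every later term indexes the transition matrix row by the previous state.
for prev, curr in zip(states[:-1], states[1:]):
    logp += np.log(Gamma[prev, curr])
```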
49 changes: 32 additions & 17 deletions pymc3_hmm/step_methods.py
@@ -1,21 +1,36 @@
from itertools import chain

import numpy as np

+try:
+    import theano.scalar as aes
+    import theano.tensor as at
+    from theano.compile import optdb
+    from theano.graph.basic import Variable, graph_inputs
+    from theano.graph.fg import FunctionGraph
+    from theano.graph.op import get_test_value as test_value
+    from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
+    from theano.graph.optdb import Query
+    from theano.tensor.elemwise import DimShuffle, Elemwise
+    from theano.tensor.subtensor import AdvancedIncSubtensor1
+    from theano.tensor.var import TensorConstant
+except ImportError:  # pragma: no cover
+    import aesara.scalar as aes
+    import aesara.tensor as at
+    from aesara.compile import optdb
+    from aesara.graph.basic import Variable, graph_inputs
+    from aesara.graph.fg import FunctionGraph
+    from aesara.graph.op import get_test_value as test_value
+    from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
+    from aesara.graph.optdb import Query
+    from aesara.tensor.elemwise import DimShuffle, Elemwise
+    from aesara.tensor.subtensor import AdvancedIncSubtensor1
+    from aesara.tensor.var import TensorConstant
+
import pymc3 as pm
-import theano.scalar as ts
-import theano.tensor as tt
from pymc3.distributions.distribution import draw_values
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name
-from theano.compile import optdb
-from theano.graph.basic import Variable, graph_inputs
-from theano.graph.fg import FunctionGraph
-from theano.graph.op import get_test_value as test_value
-from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
-from theano.graph.optdb import Query
-from theano.tensor.elemwise import DimShuffle, Elemwise
-from theano.tensor.subtensor import AdvancedIncSubtensor1
-from theano.tensor.var import TensorConstant

from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.utils import compute_trans_freqs
@@ -159,15 +174,15 @@ def __init__(self, vars, values=None, model=None):
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))

-comp_logp_stacked = tt.stack(comp_logps)
+comp_logp_stacked = at.stack(comp_logps)
else:
raise TypeError(
"This sampler only supports `SwitchingProcess` observations"
)

dep_comps_logp_stacked.append(comp_logp_stacked)

-comp_logp_stacked = tt.sum(dep_comps_logp_stacked, axis=0)
+comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0)

(M,) = draw_values([var.distribution.gamma_0.shape[-1]], point=model.test_point)
N = model.test_point[var.name].shape[-1]
@@ -326,9 +341,9 @@ def _set_row_mappings(self, Gamma, dir_priors, model):
Gamma = pre_greedy_local_optimizer(
FunctionGraph([], []),
[
-OpRemove(Elemwise(ts.Cast(ts.float32))),
-OpRemove(Elemwise(ts.Cast(ts.float64))),
-OpRemove(Elemwise(ts.identity)),
+OpRemove(Elemwise(aes.Cast(aes.float32))),
+OpRemove(Elemwise(aes.Cast(aes.float64))),
+OpRemove(Elemwise(aes.identity)),
],
Gamma,
)
@@ -352,7 +367,7 @@ def _set_row_mappings(self, Gamma, dir_priors, model):

Gamma_Join = Gamma_DimShuffle.inputs[0].owner

-if not (isinstance(Gamma_Join.op, tt.basic.Join)):
+if not (isinstance(Gamma_Join.op, at.basic.Join)):
raise TypeError(
"The transition matrix should be comprised of stacked row vectors"
)
36 changes: 22 additions & 14 deletions pymc3_hmm/utils.py
@@ -3,14 +3,22 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import theano.tensor as tt
from matplotlib import cm
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from scipy.special import logsumexp
-from theano.tensor.extra_ops import broadcast_shape
-from theano.tensor.extra_ops import broadcast_to as tt_broadcast_to
-from theano.tensor.var import TensorVariable

+try:
+    import theano.tensor as at
+    from theano.tensor.extra_ops import broadcast_shape
+    from theano.tensor.extra_ops import broadcast_to as at_broadcast_to
+    from theano.tensor.var import TensorVariable
+except ImportError:  # pragma: no cover
+    import aesara.tensor as at
+    from aesara.tensor.extra_ops import broadcast_shape
+    from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
+    from aesara.tensor.var import TensorVariable


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")

@@ -30,8 +38,8 @@ def compute_steady_state(P):

P = P[0]
N_states = P.shape[-1]
-Lam = (tt.eye(N_states) - P + tt.ones((N_states, N_states))).T
-u = tt.slinalg.solve(Lam, tt.ones((N_states,)))
+Lam = (at.eye(N_states) - P + at.ones((N_states, N_states))).T
+u = at.slinalg.solve(Lam, at.ones((N_states,)))
return u
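`compute_steady_state` solves for the stationary distribution of a row-stochastic `P`: the conditions `pi = pi @ P` and `pi.sum() == 1` are folded into the single non-singular system `(I - P + J).T u = 1`, with `J` the all-ones matrix. A NumPy version of the same computation, useful for checking the symbolic graph (toy matrix, not from the commit):

```python
import numpy as np

P = np.array([[0.9, 0.1],
              [0.3, 0.7]])              # row-stochastic transition matrix
N = P.shape[-1]

Lam = (np.eye(N) - P + np.ones((N, N))).T
pi = np.linalg.solve(Lam, np.ones(N))   # same solve as the Theano/Aesara graph

assert np.allclose(pi, pi @ P)          # stationarity: pi = pi P
assert np.isclose(pi.sum(), 1.0)        # proper probability vector
```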


@@ -81,15 +89,15 @@ def compute_trans_freqs(states, N_states, counts_only=False):

def tt_logsumexp(x, axis=None, keepdims=False):
"""Construct a Theano graph for a log-sum-exp calculation."""
-x_max_ = tt.max(x, axis=axis, keepdims=True)
+x_max_ = at.max(x, axis=axis, keepdims=True)

if x_max_.ndim > 0:
-x_max_ = tt.set_subtensor(x_max_[tt.isinf(x_max_)], 0.0)
-elif tt.isinf(x_max_):
-x_max_ = tt.as_tensor(0.0)
+x_max_ = at.set_subtensor(x_max_[at.isinf(x_max_)], 0.0)
+elif at.isinf(x_max_):
+x_max_ = at.as_tensor(0.0)

-res = tt.sum(tt.exp(x - x_max_), axis=axis, keepdims=keepdims)
-res = tt.log(res)
+res = at.sum(at.exp(x - x_max_), axis=axis, keepdims=keepdims)
+res = at.log(res)

if not keepdims:
# SciPy uses the `axis` keyword here, but Theano doesn't support that.
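The hunk above shows the max-shift part of `tt_logsumexp`; the rest of the function, outside this hunk, is where the subtracted maximum is presumably added back. The underlying identity can be checked with NumPy against the `scipy.special.logsumexp` already imported in this module (toy values only):

```python
import numpy as np
from scipy.special import logsumexp

x = np.array([1000.0, 1000.5, 999.0])   # naive np.exp(x) overflows

x_max = x.max()
stable = x_max + np.log(np.sum(np.exp(x - x_max)))  # shift by the max, add it back

assert np.isclose(stable, logsumexp(x))  # agrees with SciPy's reference result
```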
@@ -179,7 +187,7 @@ def tt_broadcast_arrays(*args: TensorVariable):
"""
bcast_shape = broadcast_shape(*args)
-return tuple(tt_broadcast_to(a, bcast_shape) for a in args)
+return tuple(at_broadcast_to(a, bcast_shape) for a in args)


def multilogit_inv(ys):
@@ -202,7 +210,7 @@ def multilogit_inv(ys):
lib = np
lib_logsumexp = logsumexp
else:
-lib = tt
+lib = at
lib_logsumexp = tt_logsumexp

# exp_ys = lib.exp(ys)