Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change theano imports to aesara #83

Merged
merged 4 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
strategy:
matrix:
python-version: [3.7]
pymc3-version: [stable, dev]

steps:
- uses: actions/checkout@v2
Expand All @@ -60,6 +61,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [[ "${{ matrix.pymc3-version }}" != "stable" ]]; then pip install "pymc3 @ git+https://github.com/pymc-devs/pymc3.git@master"; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Test with pytest
run: |
Expand Down
76 changes: 43 additions & 33 deletions pymc3_hmm/distributions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import warnings

import numpy as np

try: # pragma: no cover
import aesara
import aesara.tensor as at
from aesara.graph.op import get_test_value
from aesara.graph.utils import TestValueError
from aesara.scalar import upcast
from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
except ImportError: # pragma: no cover
import theano as aesara
import theano.tensor as at
from theano.graph.op import get_test_value
from theano.graph.utils import TestValueError
from theano.scalar import upcast
from theano.tensor.extra_ops import broadcast_to as at_broadcast_to

import pymc3 as pm
import theano
import theano.tensor as tt
from pymc3.distributions.distribution import _DrawValuesContext, draw_values
from pymc3.distributions.mixture import _conversion_map, all_discrete
from theano.graph.op import get_test_value
from theano.graph.utils import TestValueError
from theano.scalar import upcast
from theano.tensor.extra_ops import broadcast_to as tt_broadcast_to

from pymc3_hmm.utils import tt_broadcast_arrays, tt_expand_dims, vsearchsorted

Expand Down Expand Up @@ -73,7 +83,7 @@ def distribution_subset_args(dist, shape, idx, point=None):
else:
x = point[param]

bcast_res = tt_broadcast_to(x, shape)
bcast_res = at_broadcast_to(x, shape)

res.append(bcast_res[idx])

Expand Down Expand Up @@ -121,7 +131,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
equal to the size of `comp_dists`.

"""
self.states = tt.as_tensor_variable(pm.intX(states))
self.states = at.as_tensor_variable(pm.intX(states))

if len(comp_dists) > 31:
warnings.warn(
Expand All @@ -147,7 +157,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
bcast_means = tt_broadcast_arrays(
*([self.states] + [d.mean.astype(dtype) for d in self.comp_dists])
)
self.mean = tt.choose(self.states, bcast_means[1:])
self.mean = at.choose(self.states, bcast_means[1:])

if "mean" not in defaults:
defaults.append("mean")
Expand All @@ -159,7 +169,7 @@ def __init__(self, comp_dists, states, *args, **kwargs):
bcast_modes = tt_broadcast_arrays(
*([self.states] + [d.mode.astype(dtype) for d in self.comp_dists])
)
self.mode = tt.choose(self.states, bcast_modes[1:])
self.mode = at.choose(self.states, bcast_modes[1:])

if "mode" not in defaults:
defaults.append("mode")
Expand All @@ -172,15 +182,15 @@ def __init__(self, comp_dists, states, *args, **kwargs):
def logp(self, obs):
"""Return the scalar Theano log-likelihood at a point."""

obs_tt = tt.as_tensor_variable(obs)
obs_tt = at.as_tensor_variable(obs)

logp_val = tt.alloc(-np.inf, *obs.shape)
logp_val = at.alloc(-np.inf, *obs.shape)

for i, dist in enumerate(self.comp_dists):
i_mask = tt.eq(self.states, i)
i_mask = at.eq(self.states, i)
obs_i = obs_tt[i_mask]
subset_dist = dist.dist(*distribution_subset_args(dist, obs.shape, i_mask))
logp_val = tt.set_subtensor(logp_val[i_mask], subset_dist.logp(obs_i))
logp_val = at.set_subtensor(logp_val[i_mask], subset_dist.logp(obs_i))

return logp_val

Expand Down Expand Up @@ -265,8 +275,8 @@ def __init__(self, mu=None, states=None, **kwargs):
A vector of integer 0-1 states that indicate which component of
the mixture is active at each point/time.
"""
self.mu = tt.as_tensor_variable(pm.floatX(mu))
self.states = tt.as_tensor_variable(states)
self.mu = at.as_tensor_variable(pm.floatX(mu))
self.states = at.as_tensor_variable(states)

super().__init__([pm.Constant.dist(0), pm.Poisson.dist(mu)], states, **kwargs)

Expand Down Expand Up @@ -298,16 +308,16 @@ def __init__(self, Gammas, gamma_0, shape, **kwargs):
Shape of the state sequence. The last dimension is `N`, i.e. the
length of the state sequence(s).
"""
self.gamma_0 = tt.as_tensor_variable(pm.floatX(gamma_0))
self.gamma_0 = at.as_tensor_variable(pm.floatX(gamma_0))

assert Gammas.ndim >= 3

self.Gammas = tt.as_tensor_variable(pm.floatX(Gammas))
self.Gammas = at.as_tensor_variable(pm.floatX(Gammas))

shape = np.atleast_1d(shape)

dtype = _conversion_map[theano.config.floatX]
self.mode = tt.zeros(tuple(shape), dtype=dtype)
dtype = _conversion_map[aesara.config.floatX]
self.mode = np.zeros(tuple(shape), dtype=dtype)

super().__init__(shape=shape, **kwargs)

Expand All @@ -330,44 +340,44 @@ def logp(self, states):

""" # noqa: E501

Gammas = tt.shape_padleft(self.Gammas, states.ndim - (self.Gammas.ndim - 2))
Gammas = at.shape_padleft(self.Gammas, states.ndim - (self.Gammas.ndim - 2))

# Multiply the initial state probabilities by the first transition
# matrix by to get the marginal probability for state `S_1`.
# The integral that produces the marginal is essentially
# `gamma_0.dot(Gammas[0])`
Gamma_1 = Gammas[..., 0:1, :, :]
gamma_0 = tt_expand_dims(self.gamma_0, (-3, -1))
P_S_1 = tt.sum(gamma_0 * Gamma_1, axis=-2)
P_S_1 = at.sum(gamma_0 * Gamma_1, axis=-2)

# The `tt.switch`s allow us to broadcast the indexing operation when
# the replication dimensions of `states` and `Gammas` don't match
# (e.g. `states.shape[0] > Gammas.shape[0]`)
S_1_slices = tuple(
slice(
tt.switch(tt.eq(P_S_1.shape[i], 1), 0, 0),
tt.switch(tt.eq(P_S_1.shape[i], 1), 1, d),
at.switch(at.eq(P_S_1.shape[i], 1), 0, 0),
at.switch(at.eq(P_S_1.shape[i], 1), 1, d),
)
for i, d in enumerate(states.shape)
)
S_1_slices = (tuple(tt.ogrid[S_1_slices]) if S_1_slices else tuple()) + (
S_1_slices = (tuple(at.ogrid[S_1_slices]) if S_1_slices else tuple()) + (
states[..., 0:1],
)
logp_S_1 = tt.log(P_S_1[S_1_slices]).sum(axis=-1)
logp_S_1 = at.log(P_S_1[S_1_slices]).sum(axis=-1)

# These are slices for the extra dimensions--including the state
# sequence dimension (e.g. "time")--along which which we need to index
# the transition matrix rows using the "observed" `states`.
trans_slices = tuple(
slice(
tt.switch(
tt.eq(Gammas.shape[i], 1), 0, 1 if i == states.ndim - 1 else 0
at.switch(
at.eq(Gammas.shape[i], 1), 0, 1 if i == states.ndim - 1 else 0
),
tt.switch(tt.eq(Gammas.shape[i], 1), 1, d),
at.switch(at.eq(Gammas.shape[i], 1), 1, d),
)
for i, d in enumerate(states.shape)
)
trans_slices = (tuple(tt.ogrid[trans_slices]) if trans_slices else tuple()) + (
trans_slices = (tuple(at.ogrid[trans_slices]) if trans_slices else tuple()) + (
states[..., :-1],
)

Expand All @@ -376,12 +386,12 @@ def logp(self, states):
P_S_2T = Gammas[trans_slices]

obs_slices = tuple(slice(None, d) for d in P_S_2T.shape[:-1])
obs_slices = (tuple(tt.ogrid[obs_slices]) if obs_slices else tuple()) + (
obs_slices = (tuple(at.ogrid[obs_slices]) if obs_slices else tuple()) + (
states[..., 1:],
)
logp_S_1T = tt.log(P_S_2T[obs_slices])
logp_S_1T = at.log(P_S_2T[obs_slices])

res = logp_S_1 + tt.sum(logp_S_1T, axis=-1)
res = logp_S_1 + at.sum(logp_S_1T, axis=-1)
res.name = "DiscreteMarkovChain_logp"

return res
Expand Down
49 changes: 32 additions & 17 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
from itertools import chain

import numpy as np

try: # pragma: no cover
import aesara.scalar as aes
import aesara.tensor as at
from aesara.compile import optdb
from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant
except ImportError: # pragma: no cover
import theano.scalar as aes
import theano.tensor as at
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
from theano.graph.fg import FunctionGraph
from theano.graph.op import get_test_value as test_value
from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
from theano.graph.optdb import Query
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

import pymc3 as pm
import theano.scalar as ts
import theano.tensor as tt
from pymc3.distributions.distribution import draw_values
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
from theano.graph.fg import FunctionGraph
from theano.graph.op import get_test_value as test_value
from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
from theano.graph.optdb import Query
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

from pymc3_hmm.distributions import DiscreteMarkovChain, SwitchingProcess
from pymc3_hmm.utils import compute_trans_freqs
Expand Down Expand Up @@ -159,15 +174,15 @@ def __init__(self, vars, values=None, model=None):
for comp_dist in dependent_rv.distribution.comp_dists:
comp_logps.append(comp_dist.logp(dependent_rv))

comp_logp_stacked = tt.stack(comp_logps)
comp_logp_stacked = at.stack(comp_logps)
else:
raise TypeError(
"This sampler only supports `SwitchingProcess` observations"
)

dep_comps_logp_stacked.append(comp_logp_stacked)

comp_logp_stacked = tt.sum(dep_comps_logp_stacked, axis=0)
comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0)

(M,) = draw_values([var.distribution.gamma_0.shape[-1]], point=model.test_point)
N = model.test_point[var.name].shape[-1]
Expand Down Expand Up @@ -326,9 +341,9 @@ def _set_row_mappings(self, Gamma, dir_priors, model):
Gamma = pre_greedy_local_optimizer(
FunctionGraph([], []),
[
OpRemove(Elemwise(ts.Cast(ts.float32))),
OpRemove(Elemwise(ts.Cast(ts.float64))),
OpRemove(Elemwise(ts.identity)),
OpRemove(Elemwise(aes.Cast(aes.float32))),
OpRemove(Elemwise(aes.Cast(aes.float64))),
OpRemove(Elemwise(aes.identity)),
],
Gamma,
)
Expand All @@ -352,7 +367,7 @@ def _set_row_mappings(self, Gamma, dir_priors, model):

Gamma_Join = Gamma_DimShuffle.inputs[0].owner

if not (isinstance(Gamma_Join.op, tt.basic.Join)):
if not (isinstance(Gamma_Join.op, at.basic.Join)):
raise TypeError(
"The transition matrix should be comprised of stacked row vectors"
)
Expand Down
36 changes: 22 additions & 14 deletions pymc3_hmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import theano.tensor as tt
from matplotlib import cm
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from scipy.special import logsumexp
from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.extra_ops import broadcast_to as tt_broadcast_to
from theano.tensor.var import TensorVariable

try: # pragma: no cover
import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
from aesara.tensor.var import TensorVariable
except ImportError: # pragma: no cover
import theano.tensor as at
from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.extra_ops import broadcast_to as at_broadcast_to
from theano.tensor.var import TensorVariable


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")

Expand All @@ -30,8 +38,8 @@ def compute_steady_state(P):

P = P[0]
N_states = P.shape[-1]
Lam = (tt.eye(N_states) - P + tt.ones((N_states, N_states))).T
u = tt.slinalg.solve(Lam, tt.ones((N_states,)))
Lam = (at.eye(N_states) - P + at.ones((N_states, N_states))).T
u = at.slinalg.solve(Lam, at.ones((N_states,)))
return u


Expand Down Expand Up @@ -81,15 +89,15 @@ def compute_trans_freqs(states, N_states, counts_only=False):

def tt_logsumexp(x, axis=None, keepdims=False):
"""Construct a Theano graph for a log-sum-exp calculation."""
x_max_ = tt.max(x, axis=axis, keepdims=True)
x_max_ = at.max(x, axis=axis, keepdims=True)

if x_max_.ndim > 0:
x_max_ = tt.set_subtensor(x_max_[tt.isinf(x_max_)], 0.0)
elif tt.isinf(x_max_):
x_max_ = tt.as_tensor(0.0)
x_max_ = at.set_subtensor(x_max_[at.isinf(x_max_)], 0.0)
elif at.isinf(x_max_):
x_max_ = at.as_tensor(0.0)

res = tt.sum(tt.exp(x - x_max_), axis=axis, keepdims=keepdims)
res = tt.log(res)
res = at.sum(at.exp(x - x_max_), axis=axis, keepdims=keepdims)
res = at.log(res)

if not keepdims:
# SciPy uses the `axis` keyword here, but Theano doesn't support that.
Expand Down Expand Up @@ -179,7 +187,7 @@ def tt_broadcast_arrays(*args: TensorVariable):

"""
bcast_shape = broadcast_shape(*args)
return tuple(tt_broadcast_to(a, bcast_shape) for a in args)
return tuple(at_broadcast_to(a, bcast_shape) for a in args)


def multilogit_inv(ys):
Expand All @@ -202,7 +210,7 @@ def multilogit_inv(ys):
lib = np
lib_logsumexp = logsumexp
else:
lib = tt
lib = at
lib_logsumexp = tt_logsumexp

# exp_ys = lib.exp(ys)
Expand Down
5 changes: 0 additions & 5 deletions pytest.ini

This file was deleted.

4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ add-ignore = D100,D101,D102,D103,D104,D105,D106,D107,D202
convention = numpy

[tool:pytest]
python_files=test*.py
testpaths=tests
python_files = test_*.py
testpaths = tests

[coverage:run]
omit =
Expand Down
Loading