Skip to content

Commit

Permalink
Convert custom Distributions to PyMC v4 format
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 26, 2021
1 parent ca7f869 commit 36589db
Show file tree
Hide file tree
Showing 9 changed files with 589 additions and 672 deletions.
698 changes: 371 additions & 327 deletions pymc3_hmm/distributions.py

Large diffs are not rendered by default.

41 changes: 13 additions & 28 deletions pymc3_hmm/step_methods.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,18 @@
from itertools import chain

import aesara.scalar as aes
import aesara.tensor as at
import numpy as np

try: # pragma: no cover
import aesara.scalar as aes
import aesara.tensor as at
from aesara.compile import optdb
from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant
except ImportError: # pragma: no cover
import theano.scalar as aes
import theano.tensor as at
from theano.compile import optdb
from theano.graph.basic import Variable, graph_inputs
from theano.graph.fg import FunctionGraph
from theano.graph.op import get_test_value as test_value
from theano.graph.opt import OpRemove, pre_greedy_local_optimizer
from theano.graph.optdb import Query
from theano.tensor.elemwise import DimShuffle, Elemwise
from theano.tensor.subtensor import AdvancedIncSubtensor1
from theano.tensor.var import TensorConstant

import pymc3 as pm
from pymc3.distributions.distribution import draw_values
from aesara.compile import optdb
from aesara.graph.basic import Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value as test_value
from aesara.graph.opt import OpRemove, pre_greedy_local_optimizer
from aesara.graph.optdb import Query
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.subtensor import AdvancedIncSubtensor1
from aesara.tensor.var import TensorConstant
from pymc3.step_methods.arraystep import ArrayStep, BlockedStep, Competence
from pymc3.util import get_untransformed_name

Expand Down Expand Up @@ -184,7 +168,8 @@ def __init__(self, vars, values=None, model=None):

comp_logp_stacked = at.sum(dep_comps_logp_stacked, axis=0)

(M,) = draw_values([var.distribution.gamma_0.shape[-1]], point=model.test_point)
# XXX: This isn't correct.
M = var.owner.inputs[2].eval(model.test_point)
N = model.test_point[var.name].shape[-1]
self.alphas = np.empty((M, N), dtype=float)

Expand Down
26 changes: 1 addition & 25 deletions pymc3_hmm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import aesara.tensor as at
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -8,18 +9,6 @@
from matplotlib.colors import Colormap
from scipy.special import logsumexp

try: # pragma: no cover
import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.extra_ops import broadcast_to as at_broadcast_to
from aesara.tensor.var import TensorVariable
except ImportError: # pragma: no cover
import theano.tensor as at
from theano.tensor.extra_ops import broadcast_shape
from theano.tensor.extra_ops import broadcast_to as at_broadcast_to
from theano.tensor.var import TensorVariable


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")


Expand Down Expand Up @@ -177,19 +166,6 @@ def tt_expand_dims(x, dims):
return x.dimshuffle(dim_range)


def tt_broadcast_arrays(*args: TensorVariable):
"""Broadcast any number of arrays against each other.
Parameters
----------
`*args` : array_likes
The arrays to broadcast.
"""
bcast_shape = broadcast_shape(*args)
return tuple(at_broadcast_to(a, bcast_shape) for a in args)


def multilogit_inv(ys):
"""Compute the multilogit-inverse function for both NumPy and Theano arrays.
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
install_requires=[
"numpy>=1.18.1",
"scipy>=1.4.0",
"pymc3>=3.11.1,<4.0.0",
"pymc3>=4.0.0",
"aesara>=2.0.10",
],
tests_require=["pytest"],
long_description=open("README.md").read() if exists("README.md") else "",
Expand Down
Loading

0 comments on commit 36589db

Please sign in to comment.