Skip to content

Commit

Permalink
Bump PyTensor dependency to 2.13
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 18, 2023
1 parent 7c2e829 commit 3cc3522
Show file tree
Hide file tree
Showing 17 changed files with 47 additions and 94 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- scipy>=1.4.1
- typing-extensions>=3.7.4
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
17 changes: 10 additions & 7 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.var import TensorVariable
from typing_extensions import TypeAlias

Expand All @@ -54,7 +55,12 @@
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import new_or_existing_block_model_access
from pymc.printing import str_for_dist
from pymc.pytensorf import collect_default_updates, convert_observed_data, floatX
from pymc.pytensorf import (
collect_default_updates,
constant_fold,
convert_observed_data,
floatX,
)
from pymc.util import UNSET, _add_future_warning_tag
from pymc.vartypes import continuous_types, string_types

Expand Down Expand Up @@ -1229,15 +1235,12 @@ def create_partial_observed_rv(
can_rewrite = True

if can_rewrite:
# Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329
mask, antimask = mask.nonzero(), antimask.nonzero()

masked_rv = rv[mask]
fgraph = FunctionGraph(outputs=[masked_rv], clone=False)
fgraph = FunctionGraph(outputs=[masked_rv], clone=False, features=[ShapeFeature()])
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

antimasked_rv = rv[antimask]
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False)
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False, features=[ShapeFeature()])
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

# Make a clone of the observedRV, with a distinct rng so that observed and
Expand Down Expand Up @@ -1270,7 +1273,7 @@ def partial_observed_rv_logprob(op, values, dist, mask, **kwargs):
# For the logp, simply join the values
[obs_value, unobs_value] = values
antimask = ~mask
joined_value = pt.empty_like(dist)
joined_value = pt.empty(constant_fold([dist.shape])[0])
joined_value = pt.set_subtensor(joined_value[mask], unobs_value)
joined_value = pt.set_subtensor(joined_value[antimask], obs_value)
joined_logp = logp(dist, joined_value)
Expand Down
5 changes: 4 additions & 1 deletion pymc/gp/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,10 @@ def n_dims(self):
def _slice(self, X, Xs=None):
xdims = X.shape[-1]
if isinstance(xdims, Variable):
xdims = xdims.eval()
# Circular dependency
from pymc.pytensorf import constant_fold

xdims = constant_fold([xdims])[0]
if self.input_dim != xdims:
warnings.warn(
f"Only {self.input_dim} column(s) out of {xdims} are"
Expand Down
28 changes: 12 additions & 16 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, compute_test_value
from pytensor.graph.rewriting.basic import node_rewriter, pre_greedy_node_rewriter
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, node_rewriter
from pytensor.ifelse import IfElse, ifelse
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
Expand All @@ -52,6 +52,7 @@
local_rv_size_lift,
local_subtensor_rv_lift,
)
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
Expand All @@ -77,7 +78,6 @@
measurable_ir_rewrites_db,
subtensor_ops,
)
from pymc.logprob.tensor import naive_bcast_rv_lift
from pymc.logprob.utils import check_potential_measurability


Expand Down Expand Up @@ -203,21 +203,17 @@ def expand_indices(
return cast(Tuple[TensorVariable], tuple(pt.broadcast_arrays(*adv_indices)))


def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
def rv_pull_down(x: TensorVariable) -> TensorVariable:
"""Pull a ``RandomVariable`` ``Op`` down through a graph, when possible."""
fgraph = FunctionGraph(outputs=dont_touch_vars or [], clone=False)

return pre_greedy_node_rewriter(
fgraph,
[
local_rv_size_lift,
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
naive_bcast_rv_lift,
local_lift_DiracDelta,
],
x,
)
fgraph = FunctionGraph(outputs=[x], clone=False, features=[ShapeFeature()])
rewrites = [
local_rv_size_lift,
local_dimshuffle_rv_lift,
local_subtensor_rv_lift,
local_lift_DiracDelta,
]
EquilibriumGraphRewriter(rewrites, max_use_ratio=100).rewrite(fgraph)
return fgraph.outputs[0]


class MixtureRV(Op):
Expand Down
29 changes: 3 additions & 26 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@
where,
zeros_like,
)
from pytensor.tensor.special import log_softmax, softmax

try:
from pytensor.tensor.basic import extract_diag
except ImportError:
from pytensor.tensor.nlinalg import extract_diag

from pytensor.tensor.nlinalg import det, matrix_dot, matrix_inverse, trace
from pytensor.tensor.nlinalg import matrix_inverse
from scipy.linalg import block_diag as scipy_block_diag

from pymc.pytensorf import floatX, ix_, largest_common_dtype
Expand Down Expand Up @@ -267,31 +268,7 @@ def logdiffexp_numpy(a, b):
return a + log1mexp_numpy(b - a, negative_input=True)


def invlogit(x, eps=None):
"""The inverse of the logit function, 1 / (1 + exp(-x))."""
if eps is not None:
warnings.warn(
"pymc.math.invlogit no longer supports the ``eps`` argument and it will be ignored.",
FutureWarning,
stacklevel=2,
)
return pt.sigmoid(x)


def softmax(x, axis=None):
# Ignore vector case UserWarning issued by PyTensor. This can be removed once PyTensor
# drops that warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
return pt.special.softmax(x, axis=axis)


def log_softmax(x, axis=None):
# Ignore vector case UserWarning issued by PyTensor. This can be removed once PyTensor
# drops that warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
return pt.special.log_softmax(x, axis=axis)
invlogit = sigmoid


def logbern(log_p):
Expand Down
4 changes: 4 additions & 0 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytensor.scalar.basic import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.shape import Shape
from pytensor.tensor.var import TensorConstant, TensorVariable

import pymc as pm
Expand Down Expand Up @@ -55,6 +56,9 @@ def get_parent_names(self, var: TensorVariable) -> Set[VarName]:

def _filter_non_parameter_inputs(var):
node = var.owner
if isinstance(node.op, Shape):
# Don't show shape-related dependencies
return []

Check warning on line 61 in pymc/model_graph.py

View check run for this annotation

Codecov / codecov/patch

pymc/model_graph.py#L61

Added line #L61 was not covered by tests
if isinstance(node.op, RandomVariable):
# Filter out rng, dtype and size parameters or RandomVariable nodes
return node.inputs[3:]
Expand Down
1 change: 0 additions & 1 deletion pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
"generator",
"convert_observed_data",
"compile_pymc",
"constant_fold",
]


Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor>=2.12.0,<2.13
pytensor>=2.13.0,<2.14
pytest-cov>=2.5
pytest>=3.0
scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ cloudpickle
fastprogress>=0.2.0
numpy>=1.15.0
pandas>=0.24.0
pytensor>=2.12.0,<2.13
pytensor>=2.13.0,<2.14
scipy>=1.4.1
typing-extensions>=3.7.4
2 changes: 1 addition & 1 deletion tests/distributions/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def check_vals(fn1, fn2, *args):


def test_multigamma():
x = pt.vector("x")
x = pt.vector("x", shape=(1,))
p = pt.scalar("p")

xvals = [np.array([v], dtype=config.floatX) for v in [0.1, 2, 5, 10, 50, 100]]
Expand Down
5 changes: 5 additions & 0 deletions tests/logprob/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,7 @@ def test_ifelse_mixture_shared_component():
)


@pytest.mark.xfail(reason="Relied on rewrite-case that is no longer supported by PyTensor")
def test_joint_logprob_subtensor():
"""Make sure we can compute a joint log-probability for ``Y[I]`` where ``Y`` and ``I`` are random variables."""

Expand All @@ -1137,6 +1138,10 @@ def test_joint_logprob_subtensor():
I_rv = pt.random.bernoulli(p, size=size, rng=rng)
I_rv.name = "I"

# The rewrite for lifting subtensored RVs refuses to work with advanced
# indexing as it could lead to repeated draws.
# TODO: Re-enable rewrite for cases where this is not a concern
# (e.g., at least one of the advanced indexes has non-repeating values)
A_idx = A_rv[I_rv, pt.ogrid[A_rv.shape[-1] :]]

assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1))
Expand Down
34 changes: 0 additions & 34 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,37 +262,3 @@ def test_expand_packed_triangular():
assert np.all(expand_upper.eval({packed: upper_packed}) == upper)
assert np.all(expand_diag_lower.eval({packed: lower_packed}) == floatX(np.diag(vals)))
assert np.all(expand_diag_upper.eval({packed: upper_packed}) == floatX(np.diag(vals)))


def test_invlogit_deprecation_warning():
with pytest.warns(
FutureWarning,
match="pymc.math.invlogit no longer supports the",
):
res = invlogit(np.array(-750.0), 1e-5).eval()

with warnings.catch_warnings():
warnings.simplefilter("error")
res_zero_eps = invlogit(np.array(-750.0)).eval()

assert np.isclose(res, res_zero_eps)


@pytest.mark.parametrize(
"pytensor_function, pymc_wrapper",
[
(pt.special.softmax, softmax),
(pt.special.log_softmax, log_softmax),
],
)
def test_softmax_logsoftmax_no_warnings(pytensor_function, pymc_wrapper):
"""Test that wrappers for pytensor functions do not issue Warnings"""

vector = pt.vector("vector")
with pytest.warns(Warning) as record:
pytensor_function(vector)
assert {w.category for w in record.list} == {UserWarning, FutureWarning}

with warnings.catch_warnings():
warnings.simplefilter("error")
pymc_wrapper(vector)
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ def test_set_data_constant_shape_error():
pmodel.add_coord("weekday", length=x.shape[0])
pm.MutableData("y", np.arange(7), dims="weekday")

msg = "because the dimension length is tied to a TensorConstant"
msg = "because the dimension was initialized from 'x'"
with pytest.raises(ShapeError, match=msg):
pmodel.set_data("y", np.arange(10))

Expand Down

0 comments on commit 3cc3522

Please sign in to comment.