Skip to content

Commit

Permalink
Bump PyTensor dependency to 2.13
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 17, 2023
1 parent 6bfbff5 commit 2e57d53
Show file tree
Hide file tree
Showing 12 changed files with 16 additions and 74 deletions.
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- scipy>=1.4.1
- typing-extensions>=3.7.4
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- pip
- pytensor>=2.12.0,<2.13
- pytensor>=2.13.0,<2.14
- python-graphviz
- networkx
- scipy>=1.4.1
Expand Down
8 changes: 3 additions & 5 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.var import TensorVariable
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -1229,15 +1230,12 @@ def create_partial_observed_rv(
can_rewrite = True

if can_rewrite:
# Rewrite doesn't work with boolean masks. Should be fixed after https://github.com/pymc-devs/pytensor/pull/329
mask, antimask = mask.nonzero(), antimask.nonzero()

masked_rv = rv[mask]
fgraph = FunctionGraph(outputs=[masked_rv], clone=False)
fgraph = FunctionGraph(outputs=[masked_rv], clone=False, features=[ShapeFeature()])
[unobserved_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

antimasked_rv = rv[antimask]
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False)
fgraph = FunctionGraph(outputs=[antimasked_rv], clone=False, features=[ShapeFeature()])
[observed_rv] = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)

# Make a clone of the observedRV, with a distinct rng so that observed and
Expand Down
3 changes: 2 additions & 1 deletion pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
local_rv_size_lift,
local_subtensor_rv_lift,
)
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
Expand Down Expand Up @@ -205,7 +206,7 @@ def expand_indices(

def rv_pull_down(x: TensorVariable, dont_touch_vars=None) -> TensorVariable:
"""Pull a ``RandomVariable`` ``Op`` down through a graph, when possible."""
fgraph = FunctionGraph(outputs=dont_touch_vars or [], clone=False)
fgraph = FunctionGraph(outputs=dont_touch_vars or [], clone=False, features=[ShapeFeature()])

Check warning on line 209 in pymc/logprob/mixture.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/mixture.py#L209

Added line #L209 was not covered by tests

return pre_greedy_node_rewriter(
fgraph,
Expand Down
29 changes: 3 additions & 26 deletions pymc/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@
where,
zeros_like,
)
from pytensor.tensor.special import log_softmax, softmax

try:
from pytensor.tensor.basic import extract_diag
except ImportError:
from pytensor.tensor.nlinalg import extract_diag

from pytensor.tensor.nlinalg import det, matrix_dot, matrix_inverse, trace
from pytensor.tensor.nlinalg import matrix_inverse
from scipy.linalg import block_diag as scipy_block_diag

from pymc.pytensorf import floatX, ix_, largest_common_dtype
Expand Down Expand Up @@ -267,31 +268,7 @@ def logdiffexp_numpy(a, b):
return a + log1mexp_numpy(b - a, negative_input=True)


def invlogit(x, eps=None):
"""The inverse of the logit function, 1 / (1 + exp(-x))."""
if eps is not None:
warnings.warn(
"pymc.math.invlogit no longer supports the ``eps`` argument and it will be ignored.",
FutureWarning,
stacklevel=2,
)
return pt.sigmoid(x)


def softmax(x, axis=None):
# Ignore vector case UserWarning issued by PyTensor. This can be removed once PyTensor
# drops that warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
return pt.special.softmax(x, axis=axis)


def log_softmax(x, axis=None):
# Ignore vector case UserWarning issued by PyTensor. This can be removed once PyTensor
# drops that warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
return pt.special.log_softmax(x, axis=axis)
invlogit = sigmoid


def logbern(log_p):
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ numpydoc
pandas>=0.24.0
polyagamma
pre-commit>=2.8.0
pytensor>=2.12.0,<2.13
pytensor>=2.13.0,<2.14
pytest-cov>=2.5
pytest>=3.0
scipy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ cloudpickle
fastprogress>=0.2.0
numpy>=1.15.0
pandas>=0.24.0
pytensor>=2.12.0,<2.13
pytensor>=2.13.0,<2.14
scipy>=1.4.1
typing-extensions>=3.7.4
2 changes: 1 addition & 1 deletion tests/distributions/test_dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def check_vals(fn1, fn2, *args):


def test_multigamma():
x = pt.vector("x")
x = pt.vector("x", shape=(1,))
p = pt.scalar("p")

xvals = [np.array([v], dtype=config.floatX) for v in [0.1, 2, 5, 10, 50, 100]]
Expand Down
34 changes: 0 additions & 34 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,37 +262,3 @@ def test_expand_packed_triangular():
assert np.all(expand_upper.eval({packed: upper_packed}) == upper)
assert np.all(expand_diag_lower.eval({packed: lower_packed}) == floatX(np.diag(vals)))
assert np.all(expand_diag_upper.eval({packed: upper_packed}) == floatX(np.diag(vals)))


def test_invlogit_deprecation_warning():
with pytest.warns(
FutureWarning,
match="pymc.math.invlogit no longer supports the",
):
res = invlogit(np.array(-750.0), 1e-5).eval()

with warnings.catch_warnings():
warnings.simplefilter("error")
res_zero_eps = invlogit(np.array(-750.0)).eval()

assert np.isclose(res, res_zero_eps)


@pytest.mark.parametrize(
"pytensor_function, pymc_wrapper",
[
(pt.special.softmax, softmax),
(pt.special.log_softmax, log_softmax),
],
)
def test_softmax_logsoftmax_no_warnings(pytensor_function, pymc_wrapper):
"""Test that wrappers for pytensor functions do not issue Warnings"""

vector = pt.vector("vector")
with pytest.warns(Warning) as record:
pytensor_function(vector)
assert {w.category for w in record.list} == {UserWarning, FutureWarning}

with warnings.catch_warnings():
warnings.simplefilter("error")
pymc_wrapper(vector)

0 comments on commit 2e57d53

Please sign in to comment.