Allow Minibatch of derived RVs and deprecate generators as data #7480

Merged
merged 4 commits on Sep 7, 2024
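For context, a minimal usage sketch of the behaviour this PR targets (the data, shapes, and model below are illustrative, not taken from the PR):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

data = np.random.default_rng(0).normal(size=(1000, 10))

# Casted constants remain valid Minibatch inputs; arbitrary derived graphs
# such as pt.as_tensor(data) * 2 are still rejected.
mb = pm.Minibatch(pt.as_tensor(data).astype("float32"), batch_size=20)

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=mb, total_size=1000)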
83 changes: 48 additions & 35 deletions pymc/data.py
@@ -26,18 +26,19 @@
import pytensor.tensor as pt
import xarray as xr

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.pytensorf import convert_data, smarttypeX
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
from pymc.vartypes import isgenerator

__all__ = [
@@ -129,46 +130,47 @@
class MinibatchIndexRV(IntegersRV):
_print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")

# Work-around for https://github.com/pymc-devs/pytensor/issues/97
def make_node(self, rng, *args, **kwargs):
if rng is None:
rng = pytensor.shared(np.random.default_rng())
return super().make_node(rng, *args, **kwargs)


minibatch_index = MinibatchIndexRV()


def is_minibatch(v: TensorVariable) -> bool:
return (
isinstance(v.owner.op, AdvancedSubtensor)
and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
and valid_for_minibatch(v.owner.inputs[0])
)
class MinibatchOp(OpFromGraph):
"""Encapsulate Minibatch random draws in an opaque OFG"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, inline=True)

def __str__(self):
return "Minibatch"

Codecov / codecov/patch check warning on pymc/data.py#L144: Added line #L144 was not covered by tests.


def valid_for_minibatch(v: TensorVariable) -> bool:
def is_valid_observed(v) -> bool:
if not isinstance(v, Variable):
# Non-symbolic constant
return True

if v.owner is None:
# Symbolic root variable (constant or not)
return True

return (
v.owner is None
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
or (
(
isinstance(v.owner.op, Elemwise)
and v.owner.inputs[0].owner is None
and isinstance(v.owner.op.scalar_op, Cast)
and is_valid_observed(v.owner.inputs[0])
)
# Or Minibatch
or (
isinstance(v.owner.op, MinibatchOp)
and all(is_valid_observed(inp) for inp in v.owner.inputs)
)
# Or Generator
or isinstance(v.owner.op, GeneratorOp)
)


def assert_all_scalars_equal(scalar, *scalars):
if len(scalars) == 0:
return scalar
else:
return Assert(
"All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
)(scalar, pt.all([pt.eq(scalar, s) for s in scalars]))


def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
"""Get random slices from variables from the leading dimension.

@@ -188,18 +190,29 @@
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")

tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
slc = minibatch_index(0, upper, size=batch_size)
for i, v in enumerate((tensor, *tensors)):
if not valid_for_minibatch(v):
tensors = tuple(map(pt.as_tensor, (variable, *variables)))
for i, v in enumerate(tensors):
if not is_valid_observed(v):
raise ValueError(
f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
)
result = tuple([v[slc] for v in (tensor, *tensors)])
for i, r in enumerate(result):

upper = tensors[0].shape[0]
if len(tensors) > 1:
upper = Assert(
"All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
)(upper, pt.all([pt.eq(upper, other_tensor.shape[0]) for other_tensor in tensors[1:]]))

rng = pytensor.shared(np.random.default_rng())
rng_update, mb_indices = minibatch_index(0, upper, size=batch_size, rng=rng).owner.outputs
mb_tensors = [tensor[mb_indices] for tensor in tensors]

# Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
*mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)
Member: nice trick, did not know that

for i, r in enumerate(mb_tensors[:-1]):
r.name = f"minibatch.{i}"
return result if tensors else result[0]

return mb_tensors if len(variables) else mb_tensors[0]


def determine_coords(
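As a hedged illustration of the new is_valid_observed helper added above (assuming it is importable from pymc.data as in this diff; the data is illustrative):

import numpy as np
import pytensor.tensor as pt
from pymc.data import is_valid_observed

data = np.arange(12).reshape(3, 4)
assert is_valid_observed(data)                                  # non-symbolic constants
assert is_valid_observed(pt.as_tensor(data))                    # symbolic root variables
assert is_valid_observed(pt.as_tensor(data).astype("float64"))  # pure type casts
assert not is_valid_observed(pt.as_tensor(data) * 2)            # other derived graphs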
4 changes: 2 additions & 2 deletions pymc/logprob/basic.py
@@ -308,7 +308,7 @@ def normal_logcdf(value, mu, sigma):
return _logcdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _logcdf_helper(ir_rv, value, **kwargs)
cleanup_ir([expr])
@@ -390,7 +390,7 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
return _icdf_helper(rv, value, **kwargs)
except NotImplementedError:
# Try to rewrite rv
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
fgraph, _, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _icdf_helper(ir_rv, value, **kwargs)
cleanup_ir([expr])
6 changes: 2 additions & 4 deletions pymc/logprob/rewriting.py
@@ -41,9 +41,7 @@
from pytensor import config
from pytensor.compile.mode import optdb
from pytensor.graph.basic import (
Constant,
Variable,
ancestors,
io_toposort,
truncated_graph_inputs,
)
@@ -400,8 +398,8 @@ def construct_ir_fgraph(
# the old nodes to the new ones; otherwise, we won't be able to use
# `rv_values`.
# We start the `dict` with mappings from the value variables to themselves,
# to prevent them from being cloned. This also includes ancestors
memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
# to prevent them from being cloned.
memo = {v: v for v in rv_values.values()}

# We add `ShapeFeature` because it will get rid of references to the old
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
17 changes: 2 additions & 15 deletions pymc/model/core.py
@@ -39,15 +39,13 @@
from pytensor.compile import DeepCopyOp, Function, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from typing_extensions import Self

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import GenTensorVariable, is_minibatch
from pymc.data import is_valid_observed
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
@@ -1294,18 +1292,7 @@ def register_rv(
self.add_named_variable(rv_var, dims)
self.set_initval(rv_var, initval)
else:
if (
isinstance(observed, Variable)
and not isinstance(observed, GenTensorVariable)
and observed.owner is not None
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
and not (
isinstance(observed.owner.op, Elemwise)
and isinstance(observed.owner.op.scalar_op, Cast)
)
and not is_minibatch(observed)
):
if not is_valid_observed(observed):
raise TypeError(
"Variables that depend on other nodes cannot be used for observed data."
f"The data variable was: {observed}"
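A hedged sketch of the user-facing behaviour that the consolidated is_valid_observed check preserves (the model is illustrative; the error text is from the diff):

import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    try:
        # observed data that depends on other model variables is still rejected
        pm.Normal("y", mu=0, sigma=1, observed=x + 1)
    except TypeError as err:
        print(err)  # Variables that depend on other nodes cannot be used for observed data.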
27 changes: 18 additions & 9 deletions pymc/pytensorf.py
@@ -156,19 +156,25 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
TypeError

"""
# TODO: These data functions should be in data.py or model/core.py
from pymc.data import MinibatchOp

if isinstance(x, Constant):
return x.data
if isinstance(x, SharedVariable):
return x.get_value()
if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
array_data = extract_obs_data(x.owner.inputs[0])
return array_data.astype(x.type.dtype)
if x.owner and isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
array_data = extract_obs_data(x.owner.inputs[0])
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
mask = np.zeros_like(array_data)
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)
if x.owner is not None:
if isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
array_data = extract_obs_data(x.owner.inputs[0])
return array_data.astype(x.type.dtype)
if isinstance(x.owner.op, MinibatchOp):
return extract_obs_data(x.owner.inputs[x.owner.outputs.index(x)])
if isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
array_data = extract_obs_data(x.owner.inputs[0])
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
mask = np.zeros_like(array_data)
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)

raise TypeError(f"Data cannot be extracted from {x}")
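A hedged sketch of the new MinibatchOp branch (data illustrative): for a minibatched output, extract_obs_data walks back to the corresponding full-size input.

import numpy as np
import pymc as pm
from pymc.pytensorf import extract_obs_data

data = np.arange(100).reshape(50, 2)
mb = pm.Minibatch(data, batch_size=10)
np.testing.assert_array_equal(extract_obs_data(mb), data)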

@@ -666,6 +672,9 @@ class GeneratorOp(Op):
__props__ = ("generator",)

def __init__(self, gen, default=None):
warnings.warn(
Member: lgtm
"generator data is deprecated and will be removed in a future release", FutureWarning
)
from pymc.data import GeneratorAdapter

super().__init__()
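A hedged sketch of the deprecation path, assuming GeneratorAdapter still wraps a generator of numpy arrays as in the existing pymc.data API (names below are illustrative):

import warnings

import numpy as np

from pymc.data import GeneratorAdapter
from pymc.pytensorf import GeneratorOp

def batches():
    while True:
        yield np.random.rand(10).astype("float32")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    GeneratorOp(GeneratorAdapter(batches()))  # now emits the FutureWarning added here
assert any(issubclass(w.category, FutureWarning) for w in caught)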
Expand Down
5 changes: 3 additions & 2 deletions pymc/variational/minibatch_rv.py
@@ -20,7 +20,8 @@
from pytensor.graph import Apply, Op
from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable

from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.abstract import MeasurableOp, _logprob
from pymc.logprob.basic import logp


class MinibatchRandomVariable(MeasurableOp, Op):
@@ -99,4 +100,4 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor
def minibatch_rv_logprob(op, values, *inputs, **kwargs):
[value] = values
rv, *total_size = inputs
return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape)
return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape)
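For intuition, a hedged sketch of the scaling applied above (numbers illustrative): the observed term's logp is multiplied by total_size / batch_size, so the minibatch logp is an unbiased stand-in for the full-data logp. Routing through logp instead of _logprob_helper is presumably what lets observations of derived RVs be minibatched, per the PR title.

import numpy as np
import pymc as pm

total_size, batch_size = 1000, 20
value = np.zeros(batch_size)

with pm.Model() as m:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=value, total_size=total_size)

# The "obs" term in m.compile_logp()({"mu": 0.0}) is scaled by
# total_size / batch_size = 50 relative to the unscaled logp of `value`.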
4 changes: 2 additions & 2 deletions pymc/variational/opvi.py
@@ -116,13 +116,13 @@ class GroupError(VariationalInferenceError, TypeError):

def _known_scan_ignored_inputs(terms):
# TODO: remove when scan issue with grads is fixed
from pymc.data import MinibatchIndexRV
from pymc.data import MinibatchOp
from pymc.distributions.simulator import SimulatorRV

return [
n.owner.inputs[0]
for n in pytensor.graph.ancestors(terms)
if n.owner is not None and isinstance(n.owner.op, MinibatchIndexRV | SimulatorRV)
if n.owner is not None and isinstance(n.owner.op, MinibatchOp | SimulatorRV)
]


35 changes: 12 additions & 23 deletions tests/test_data.py
@@ -14,7 +14,6 @@

import io
import itertools as it
import re

from os import path

Expand All @@ -29,7 +28,7 @@

import pymc as pm

from pymc.data import is_minibatch
from pymc.data import MinibatchOp
from pymc.pytensorf import GeneratorOp, floatX


@@ -593,44 +592,34 @@ class TestMinibatch:

def test_1d(self):
mb = pm.Minibatch(self.data, batch_size=20)
assert is_minibatch(mb)
assert mb.eval().shape == (20, 10)
assert isinstance(mb.owner.op, MinibatchOp)
draw1, draw2 = pm.draw(mb, draws=2)
assert draw1.shape == (20, 10)
assert draw2.shape == (20, 10)
assert not np.all(draw1 == draw2)

def test_allowed(self):
mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
assert is_minibatch(mb)
assert isinstance(mb.owner.op, MinibatchOp)

def test_not_allowed(self):
with pytest.raises(ValueError, match="not valid for Minibatch"):
mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)

def test_not_allowed2(self):
with pytest.raises(ValueError, match="not valid for Minibatch"):
mb = pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)

def test_assert(self):
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
with pytest.raises(
AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
):
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
d1.eval()

def test_multiple_vars(self):
A = np.arange(1000)
B = np.arange(1000)
B = -np.arange(1000)
mA, mB = pm.Minibatch(A, B, batch_size=10)

[draw_mA, draw_mB] = pm.draw([mA, mB])
assert draw_mA.shape == (10,)
np.testing.assert_allclose(draw_mA, draw_mB)

# Check invalid dims
Member Author: This was already checked in the test above
Member: ok

A = np.arange(1000)
C = np.arange(999)
mA, mC = pm.Minibatch(A, C, batch_size=10)

with pytest.raises(
AssertionError,
match=re.escape("All variables shape[0] in Minibatch should be equal"),
):
pm.draw([mA, mC])
np.testing.assert_allclose(draw_mA, -draw_mB)