diff --git a/pymc/data.py b/pymc/data.py
index 7e306f19e3..66a78ab28f 100644
--- a/pymc/data.py
+++ b/pymc/data.py
@@ -26,18 +26,19 @@
 import pytensor.tensor as pt
 import xarray as xr
 
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.sharedvalue import SharedVariable
+from pytensor.graph.basic import Variable
 from pytensor.raise_op import Assert
 from pytensor.scalar import Cast
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.basic import IntegersRV
-from pytensor.tensor.subtensor import AdvancedSubtensor
 from pytensor.tensor.type import TensorType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 import pymc as pm
 
-from pymc.pytensorf import convert_data, smarttypeX
+from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
 from pymc.vartypes import isgenerator
 
 __all__ = [
@@ -129,46 +130,47 @@ def __hash__(self):
 class MinibatchIndexRV(IntegersRV):
     _print_name = ("minibatch_index", r"\operatorname{minibatch\_index}")
 
-    # Work-around for https://github.com/pymc-devs/pytensor/issues/97
-    def make_node(self, rng, *args, **kwargs):
-        if rng is None:
-            rng = pytensor.shared(np.random.default_rng())
-        return super().make_node(rng, *args, **kwargs)
-
 
 minibatch_index = MinibatchIndexRV()
 
 
-def is_minibatch(v: TensorVariable) -> bool:
-    return (
-        isinstance(v.owner.op, AdvancedSubtensor)
-        and isinstance(v.owner.inputs[1].owner.op, MinibatchIndexRV)
-        and valid_for_minibatch(v.owner.inputs[0])
-    )
+class MinibatchOp(OpFromGraph):
+    """Encapsulate Minibatch random draws in an opaque OFG"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs, inline=True)
+
+    def __str__(self):
+        return "Minibatch"
 
 
-def valid_for_minibatch(v: TensorVariable) -> bool:
+def is_valid_observed(v) -> bool:
+    if not isinstance(v, Variable):
+        # Non-symbolic constant
+        return True
+
+    if v.owner is None:
+        # Symbolic root variable (constant or not)
+        return True
+
     return (
-        v.owner is None
         # The only PyTensor operation we allow on observed data is type casting
         # Although we could allow for any graph that does not depend on other RVs
-        or (
+        (
             isinstance(v.owner.op, Elemwise)
-            and v.owner.inputs[0].owner is None
             and isinstance(v.owner.op.scalar_op, Cast)
+            and is_valid_observed(v.owner.inputs[0])
+        )
+        # Or Minibatch
+        or (
+            isinstance(v.owner.op, MinibatchOp)
+            and all(is_valid_observed(inp) for inp in v.owner.inputs)
         )
+        # Or Generator
+        or isinstance(v.owner.op, GeneratorOp)
     )
 
 
-def assert_all_scalars_equal(scalar, *scalars):
-    if len(scalars) == 0:
-        return scalar
-    else:
-        return Assert(
-            "All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
-        )(scalar, pt.all([pt.eq(scalar, s) for s in scalars]))
-
-
 def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):
     """Get random slices from variables from the leading dimension.
@@ -188,18 +190,29 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
     if not isinstance(batch_size, int):
         raise TypeError("batch_size must be an integer")
 
-    tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
-    upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
-    slc = minibatch_index(0, upper, size=batch_size)
-    for i, v in enumerate((tensor, *tensors)):
-        if not valid_for_minibatch(v):
+    tensors = tuple(map(pt.as_tensor, (variable, *variables)))
+    for i, v in enumerate(tensors):
+        if not is_valid_observed(v):
             raise ValueError(
                 f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
             )
-    result = tuple([v[slc] for v in (tensor, *tensors)])
-    for i, r in enumerate(result):
+
+    upper = tensors[0].shape[0]
+    if len(tensors) > 1:
+        upper = Assert(
+            "All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
+        )(upper, pt.all([pt.eq(upper, other_tensor.shape[0]) for other_tensor in tensors[1:]]))
+
+    rng = pytensor.shared(np.random.default_rng())
+    rng_update, mb_indices = minibatch_index(0, upper, size=batch_size, rng=rng).owner.outputs
+    mb_tensors = [tensor[mb_indices] for tensor in tensors]
+
+    # Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
+    *mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)
+    for i, r in enumerate(mb_tensors):
         r.name = f"minibatch.{i}"
-    return result if tensors else result[0]
+
+    return mb_tensors if len(variables) else mb_tensors[0]
 
 
 def determine_coords(
diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py
index c945baa751..1f07b9545c 100644
--- a/pymc/logprob/basic.py
+++ b/pymc/logprob/basic.py
@@ -308,7 +308,7 @@ def normal_logcdf(value, mu, sigma):
         return _logcdf_helper(rv, value, **kwargs)
     except NotImplementedError:
         # Try to rewrite rv
-        fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
+        fgraph, _, _ = construct_ir_fgraph({rv: value})
         [ir_rv] = fgraph.outputs
         expr = _logcdf_helper(ir_rv, value, **kwargs)
         cleanup_ir([expr])
@@ -390,7 +390,7 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
         return _icdf_helper(rv, value, **kwargs)
     except NotImplementedError:
         # Try to rewrite rv
-        fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
+        fgraph, _, _ = construct_ir_fgraph({rv: value})
         [ir_rv] = fgraph.outputs
         expr = _icdf_helper(ir_rv, value, **kwargs)
         cleanup_ir([expr])
diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index aa3586c21e..eacadfc5d3 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -41,9 +41,7 @@
 from pytensor import config
 from pytensor.compile.mode import optdb
 from pytensor.graph.basic import (
-    Constant,
     Variable,
-    ancestors,
     io_toposort,
     truncated_graph_inputs,
 )
@@ -400,8 +398,8 @@ def construct_ir_fgraph(
     # the old nodes to the new ones; otherwise, we won't be able to use
     # `rv_values`.
     # We start the `dict` with mappings from the value variables to themselves,
-    # to prevent them from being cloned. This also includes ancestors
-    memo = {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
+    # to prevent them from being cloned.
+    memo = {v: v for v in rv_values.values()}
 
     # We add `ShapeFeature` because it will get rid of references to the old
     # `RandomVariable`s that have been lifted; otherwise, it will be difficult
diff --git a/pymc/model/core.py b/pymc/model/core.py
index 3a27417661..eac1629407 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -39,15 +39,13 @@
 from pytensor.compile import DeepCopyOp, Function, get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant, Variable, graph_inputs
-from pytensor.scalar import Cast
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 from typing_extensions import Self
 
 from pymc.blocking import DictToArrayBijection, RaveledVars
-from pymc.data import GenTensorVariable, is_minibatch
+from pymc.data import is_valid_observed
 from pymc.exceptions import (
     BlockModelAccessError,
     ImputationWarning,
@@ -1294,18 +1292,7 @@ def register_rv(
             self.add_named_variable(rv_var, dims)
             self.set_initval(rv_var, initval)
         else:
-            if (
-                isinstance(observed, Variable)
-                and not isinstance(observed, GenTensorVariable)
-                and observed.owner is not None
-                # The only PyTensor operation we allow on observed data is type casting
-                # Although we could allow for any graph that does not depend on other RVs
-                and not (
-                    isinstance(observed.owner.op, Elemwise)
-                    and isinstance(observed.owner.op.scalar_op, Cast)
-                )
-                and not is_minibatch(observed)
-            ):
+            if not is_valid_observed(observed):
                 raise TypeError(
                     "Variables that depend on other nodes cannot be used for observed data."
                     f"The data variable was: {observed}"
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index cc7204c28a..9d6ecd25a2 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -156,19 +156,25 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
     TypeError
 
     """
+    # TODO: These data functions should be in data.py or model/core.py
+    from pymc.data import MinibatchOp
+
     if isinstance(x, Constant):
         return x.data
     if isinstance(x, SharedVariable):
         return x.get_value()
-    if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
-        array_data = extract_obs_data(x.owner.inputs[0])
-        return array_data.astype(x.type.dtype)
-    if x.owner and isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
-        array_data = extract_obs_data(x.owner.inputs[0])
-        mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
-        mask = np.zeros_like(array_data)
-        mask[mask_idx] = 1
-        return np.ma.MaskedArray(array_data, mask)
+    if x.owner is not None:
+        if isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
+            array_data = extract_obs_data(x.owner.inputs[0])
+            return array_data.astype(x.type.dtype)
+        if isinstance(x.owner.op, MinibatchOp):
+            return extract_obs_data(x.owner.inputs[x.owner.outputs.index(x)])
+        if isinstance(x.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1):
+            array_data = extract_obs_data(x.owner.inputs[0])
+            mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])
+            mask = np.zeros_like(array_data)
+            mask[mask_idx] = 1
+            return np.ma.MaskedArray(array_data, mask)
 
     raise TypeError(f"Data cannot be extracted from {x}")
 
@@ -666,6 +672,9 @@ class GeneratorOp(Op):
     __props__ = ("generator",)
 
     def __init__(self, gen, default=None):
+        warnings.warn(
+            "generator data is deprecated and will be removed in a future release", FutureWarning
+        )
         from pymc.data import GeneratorAdapter
 
         super().__init__()
diff --git a/pymc/variational/minibatch_rv.py b/pymc/variational/minibatch_rv.py
index 864825910d..be71a358c9 100644
--- a/pymc/variational/minibatch_rv.py
+++ b/pymc/variational/minibatch_rv.py
@@ -20,7 +20,8 @@
 from pytensor.graph import Apply, Op
 from pytensor.tensor import NoneConst, TensorVariable, as_tensor_variable
 
-from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
+from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.basic import logp
 
 
 class MinibatchRandomVariable(MeasurableOp, Op):
@@ -99,4 +100,4 @@ def get_scaling(total_size: Sequence[Variable], shape: TensorVariable) -> Tensor
 def minibatch_rv_logprob(op, values, *inputs, **kwargs):
     [value] = values
     rv, *total_size = inputs
-    return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape)
+    return logp(rv, value, **kwargs) * get_scaling(total_size, value.shape)
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index e1165a874c..c593987105 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -116,13 +116,13 @@ class GroupError(VariationalInferenceError, TypeError):
 
 
 def _known_scan_ignored_inputs(terms):
     # TODO: remove when scan issue with grads is fixed
-    from pymc.data import MinibatchIndexRV
+    from pymc.data import MinibatchOp
     from pymc.distributions.simulator import SimulatorRV
 
     return [
         n.owner.inputs[0]
         for n in pytensor.graph.ancestors(terms)
-        if n.owner is not None and isinstance(n.owner.op, MinibatchIndexRV | SimulatorRV)
+        if n.owner is not None and isinstance(n.owner.op, MinibatchOp | SimulatorRV)
     ]
diff --git a/tests/test_data.py b/tests/test_data.py
index c8472359f1..2ba66dc744 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -14,7 +14,6 @@
 
 import io
 import itertools as it
-import re
 
 from os import path
@@ -29,7 +28,7 @@
 
 import pymc as pm
 
-from pymc.data import is_minibatch
+from pymc.data import MinibatchOp
 from pymc.pytensorf import GeneratorOp, floatX
@@ -593,44 +592,34 @@ class TestMinibatch:
 
     def test_1d(self):
         mb = pm.Minibatch(self.data, batch_size=20)
-        assert is_minibatch(mb)
-        assert mb.eval().shape == (20, 10)
+        assert isinstance(mb.owner.op, MinibatchOp)
+        draw1, draw2 = pm.draw(mb, draws=2)
+        assert draw1.shape == (20, 10)
+        assert draw2.shape == (20, 10)
+        assert not np.all(draw1 == draw2)
 
     def test_allowed(self):
         mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
-        assert is_minibatch(mb)
+        assert isinstance(mb.owner.op, MinibatchOp)
 
-    def test_not_allowed(self):
         with pytest.raises(ValueError, match="not valid for Minibatch"):
-            mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
+            pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
 
-    def test_not_allowed2(self):
         with pytest.raises(ValueError, match="not valid for Minibatch"):
-            mb = pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
+            pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
 
     def test_assert(self):
+        d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
         with pytest.raises(
             AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
         ):
-            d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
             d1.eval()
 
     def test_multiple_vars(self):
         A = np.arange(1000)
-        B = np.arange(1000)
+        B = -np.arange(1000)
         mA, mB = pm.Minibatch(A, B, batch_size=10)
 
         [draw_mA, draw_mB] = pm.draw([mA, mB])
         assert draw_mA.shape == (10,)
-        np.testing.assert_allclose(draw_mA, draw_mB)
-
-        # Check invalid dims
-        A = np.arange(1000)
-        C = np.arange(999)
-        mA, mC = pm.Minibatch(A, C, batch_size=10)
-
-        with pytest.raises(
-            AssertionError,
-            match=re.escape("All variables shape[0] in Minibatch should be equal"),
-        ):
-            pm.draw([mA, mC])
+        np.testing.assert_allclose(draw_mA, -draw_mB)
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index e8881451bf..de1b14c36d 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -26,11 +26,12 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.graph.basic import Variable, equal_computations
 from pytensor.tensor.random.basic import normal, uniform
-from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
+from pytensor.tensor.subtensor import AdvancedIncSubtensor
 from pytensor.tensor.variable import TensorVariable
 
 import pymc as pm
 
+from pymc.data import Minibatch, MinibatchOp
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicRandomVariable
 from pymc.exceptions import NotConstantValueError
@@ -135,57 +136,68 @@ def _make_along_axis_idx(arr_shape, indices, axis):
     return tuple(fancy_index)
 
 
-def test_extract_obs_data():
-    with pytest.raises(TypeError):
-        extract_obs_data(pt.matrix())
+class TestExtractObsData:
+    def test_root_variable(self):
+        with pytest.raises(TypeError):
+            extract_obs_data(pt.matrix())
 
-    data = np.random.normal(size=(2, 3))
-    data_at = pt.as_tensor(data)
-    mask = np.random.binomial(1, 0.5, size=(2, 3)).astype(bool)
-
-    for val_at in (data_at, pytensor.shared(data)):
-        res = extract_obs_data(val_at)
+    def test_constant_variable(self):
+        data = np.random.normal(size=(2, 3))
+        data_pt = pt.as_tensor(data)
+        res = extract_obs_data(data_pt)
         assert isinstance(res, np.ndarray)
-        assert np.array_equal(res, data)
-
-    # AdvancedIncSubtensor check
-    data_m = np.ma.MaskedArray(data, mask)
-    missing_values = data_at.type()[mask]
-    constant = pt.as_tensor(data_m.filled())
-    z_at = pt.set_subtensor(constant[mask.nonzero()], missing_values)
-
-    assert isinstance(z_at.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1)
+        np.testing.assert_array_equal(res, data)
 
-    res = extract_obs_data(z_at)
+    def test_shared_variable(self):
+        data = np.random.normal(size=(2, 3))
+        data_pt = shared(data)
 
-    assert isinstance(res, np.ndarray)
-    assert np.ma.allequal(res, data_m)
-
-    # AdvancedIncSubtensor1 check
-    data = np.random.normal(size=(3,))
-    data_at = pt.as_tensor(data)
-    mask = np.random.binomial(1, 0.5, size=(3,)).astype(bool)
+        res = extract_obs_data(data_pt)
+        assert isinstance(res, np.ndarray)
+        np.testing.assert_array_equal(res, data)
+
+    def test_masked_variable(self):
+        # Extract data from auto-imputation graph
+        data = np.random.normal(size=(2, 3))
+        data_pt = pt.as_tensor(data)
+        mask = np.random.binomial(1, 0.5, size=(2, 3)).astype(bool)
+
+        # AdvancedIncSubtensor check
+        data_m = np.ma.MaskedArray(data, mask)
+        missing_values = data_pt.type()[mask]
+        constant = pt.as_tensor(data_m.filled())
+        z_at = pt.set_subtensor(constant[mask.nonzero()], missing_values)
+        assert isinstance(z_at.owner.op, AdvancedIncSubtensor)
+
+        res = extract_obs_data(z_at)
+        assert isinstance(res, np.ndarray)
+        assert np.ma.allequal(res, data_m)
 
-    data_m = np.ma.MaskedArray(data, mask)
-    missing_values = data_at.type()[mask]
-    constant = pt.as_tensor(data_m.filled())
-    z_at = pt.set_subtensor(constant[mask.nonzero()], missing_values)
+    def test_cast_variable(self):
+        # Cast check
+        data = np.array(5)
+        data_pt = pt.cast(pt.as_tensor(5.0), np.int64)
 
-    assert isinstance(z_at.owner.op, AdvancedIncSubtensor | AdvancedIncSubtensor1)
+        res = extract_obs_data(data_pt)
+        assert isinstance(res, np.ndarray)
+        np.testing.assert_array_equal(res, data)
 
-    res = extract_obs_data(z_at)
+    def test_minibatch_variable(self):
+        x = np.arange(5)
+        y = x * 2
 
-    assert isinstance(res, np.ndarray)
-    assert np.ma.allequal(res, data_m)
+        x_mb, y_mb = Minibatch(x, y, batch_size=2)
+        assert isinstance(x_mb.owner.op, MinibatchOp)
+        assert isinstance(y_mb.owner.op, MinibatchOp)
 
-    # Cast check
-    data = np.array(5)
-    t = pt.cast(pt.as_tensor(5.0), np.int64)
-    res = extract_obs_data(t)
+        res = extract_obs_data(x_mb)
+        assert isinstance(res, np.ndarray)
+        np.testing.assert_array_equal(res, x)
 
-    assert isinstance(res, np.ndarray)
-    assert np.array_equal(res, data)
+        res = extract_obs_data(y_mb)
+        assert isinstance(res, np.ndarray)
+        np.testing.assert_array_equal(res, y)
 
 
 @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py
index 5bbc95693c..df0208633a 100644
--- a/tests/variational/test_inference.py
+++ b/tests/variational/test_inference.py
@@ -14,9 +14,6 @@
 import io
 import operator
-import warnings
-
-from contextlib import nullcontext
 
 import cloudpickle
 import numpy as np
@@ -162,22 +159,7 @@ def fit_kwargs(inference, use_minibatch):
 
 
 def test_fit_oo(inference, fit_kwargs, simple_model_data):
-    # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
-    if getattr(simple_model_data["data"], "name", "").startswith("minibatch"):
-        warn_ctxt = pytest.warns(
-            UserWarning, match="Could not extract data from symbolic observation"
-        )
-    else:
-        warn_ctxt = nullcontext()
-
-    with warn_ctxt:
-        with warnings.catch_warnings():
-            # Related to https://github.com/arviz-devs/arviz/issues/2327
-            warnings.filterwarnings(
-                "ignore", message="datetime.datetime.utcnow()", category=DeprecationWarning
-            )
-
-            trace = inference.fit(**fit_kwargs).sample(10000)
+    trace = inference.fit(**fit_kwargs).sample(10000)
     mu_post = simple_model_data["mu_post"]
     d = simple_model_data["d"]
     np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_post, rtol=0.05)
@@ -203,33 +185,10 @@ def test_fit_start(inference_spec, simple_model):
     with simple_model:
         inference = inference_spec(**kw)
 
-    # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
-    [observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs]
-
-    # We can`t use pytest.warns here because after version 8.0 it`s still check for warning when
-    # exception raised and test failed instead being skipped
-    warning_raised = False
-    expected_warning = observed_value.name.startswith("minibatch")
-    with warnings.catch_warnings(record=True) as record:
-        warnings.simplefilter("always")
-        with warnings.catch_warnings():
-            # Related to https://github.com/arviz-devs/arviz/issues/2327
-            warnings.filterwarnings(
-                "ignore", message="datetime.datetime.utcnow()", category=DeprecationWarning
-            )
-
-            try:
-                trace = inference.fit(n=0).sample(10000)
-            except NotImplementedInference as e:
-                pytest.skip(str(e))
-
-    if expected_warning:
-        assert len(record) > 0
-        for item in record:
-            assert issubclass(item.category, UserWarning)
-            assert "Could not extract data from symbolic observation" in str(item.message)
-    if not expected_warning:
-        assert not record
+    try:
+        trace = inference.fit(n=0).sample(10000)
+    except NotImplementedInference as e:
+        pytest.skip(str(e))
 
     np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
     if has_start_sigma:
diff --git a/tests/variational/test_minibatch_rv.py b/tests/variational/test_minibatch_rv.py
index 10ab0914fc..6f3e715af7 100644
--- a/tests/variational/test_minibatch_rv.py
+++ b/tests/variational/test_minibatch_rv.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import numpy as np
 import pytensor
+import pytensor.tensor as pt
 import pytest
 
 from scipy import stats as st
@@ -20,7 +21,7 @@
 import pymc as pm
 
 from pymc import Normal, draw
-from pymc.data import minibatch_index
+from pymc.data import Minibatch
 from pymc.testing import select_by_precision
 from pymc.variational.minibatch_rv import create_minibatch_rv
 from tests.test_data import gen1, gen2
@@ -165,10 +166,7 @@ def test_minibatch_parameter_and_value(self):
         with pm.Model(check_bounds=False) as m:
             AD = pm.Data("AD", np.arange(total_size, dtype="float64"))
             TD = pm.Data("TD", np.arange(total_size, dtype="float64"))
-
-            minibatch_idx = minibatch_index(0, 10, size=(9,))
-            AD_mt = AD[minibatch_idx]
-            TD_mt = TD[minibatch_idx]
+            AD_mt, TD_mt = Minibatch(AD, TD, batch_size=9)
 
             pm.Normal(
                 "AD_predicted",
@@ -189,3 +187,12 @@ def test_minibatch_parameter_and_value(self):
         with m:
             pm.set_data({"AD": rng.normal(size=1000)})
             assert logp_fn(ip) != logp_fn(ip)
+
+    def test_derived_rv(self):
+        """Test we can obtain a minibatch logp out of a derived RV."""
+        dist = pt.clip(pm.Normal.dist(0, 1, size=(1,)), -1, 1)
+        mb_dist = create_minibatch_rv(dist, total_size=(2,))
+        np.testing.assert_allclose(
+            pm.logp(mb_dist, -1).eval(),
+            pm.logp(dist, -1).eval() * 2,
+        )
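
Usage sketch (not part of the patch): after this change, pm.Minibatch returns outputs of a single opaque MinibatchOp (an inlined OpFromGraph), so callers identify minibatch graphs by op type instead of pattern-matching an AdvancedSubtensor over a MinibatchIndexRV, and extract_obs_data can see through the op back to the full data. The snippet below illustrates the resulting behavior assuming the patch is applied; variable names are illustrative only.

import numpy as np
import pymc as pm
from pymc.data import MinibatchOp
from pymc.pytensorf import extract_obs_data

data_x = np.arange(1000, dtype="float64")
data_y = 2 * data_x

# Both tensors are sliced with the same random indices, drawn inside one
# MinibatchOp, so graph rewrites cannot separate the paired draws.
mb_x, mb_y = pm.Minibatch(data_x, data_y, batch_size=10)
assert isinstance(mb_x.owner.op, MinibatchOp)

# Every evaluation draws fresh indices; the two outputs stay aligned.
draw_x, draw_y = pm.draw([mb_x, mb_y])
assert draw_x.shape == (10,)
np.testing.assert_allclose(2 * draw_x, draw_y)

# extract_obs_data recovers the full underlying data from a minibatch output,
# e.g. when populating the observed_data group of an InferenceData.
np.testing.assert_array_equal(extract_obs_data(mb_x), data_x)

# Typical modelling use: rescale the minibatch logp to the full dataset size.
with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu, 1, observed=mb_x, total_size=1000)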