diff --git a/pymc/logprob/abstract.py b/pymc/logprob/abstract.py
index c12c6b9b78..097bf84afa 100644
--- a/pymc/logprob/abstract.py
+++ b/pymc/logprob/abstract.py
@@ -40,7 +40,7 @@
 from collections.abc import Sequence
 from functools import singledispatch
 
-from pytensor.graph.op import Op
+from pytensor.graph import Apply, Op, Variable
 from pytensor.graph.utils import MetaType
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -165,3 +165,129 @@ def __init__(self, scalar_op, *args, **kwargs):
 
     def __str__(self):
         return f"Measurable{super().__str__()}"
+
+
+class ValuedRV(Op):
+    r"""Represents the association of a measurable variable and its value.
+
+    A `ValuedRV` node represents the pair :math:`(Y, y)`, where `y` is the value at which :math:`Y`'s density
+    or probability mass function is evaluated.
+
+    The log-probability function takes such pairs as input, which makes these nodes an intermediate form
+    through which a log-probability graph is constructed from a model graph.
+
+
+    Notes
+    -----
+    The introduction of these operations achieves two goals:
+    1. Identify the conditioning points between multiple, potentially interdependent measurable variables,
+    and introduce the respective value variables in the IR graph.
+    2. Prevent automatic rewrites across conditioning points.
+
+    Regarding point 2: in the current framework, an RV's logp cannot depend on a transformation of the value variable
+    of a second RV it depends on. While this is mathematically trivial, we don't have the machinery to achieve it.
+
+    The only case where we do something like this is in the ad-hoc transform_value rewrite, but there we are
+    told explicitly which value variables must be transformed before being used in the density of dependent RVs.
+
+    For example, the following is not supported:
+
+    ```python
+    x_log = pt.random.normal()
+    x = pt.exp(x_log)
+    y = pt.random.normal(loc=x_log)
+
+    x_value = pt.scalar()
+    y_value = pt.scalar()
+    conditional_logprob({x: x_value, y: y_value})
+    ```
+
+    Our framework doesn't know that the density of `y` should depend on a (log) transform of `x_value`.
+
+    Importantly, we need to prevent our IR rewrites from automatically introducing this limitation.
+    For example, given the following:
+
+    ```python
+    a_base = pm.Normal.dist()
+    a = a_base * 5
+    b = pm.Normal.dist(a * 8)
+
+    a_value = pt.scalar()
+    b_value = pt.scalar()
+    conditional_logp({a: a_value, b: b_value})
+    ```
+
+    We do not want `b` to be rewritten as `pm.Normal.dist(a_base * 40)`, as it would then be disconnected from the
+    valued `a` associated with `pm.Normal.dist(a_base * 5)`. By introducing `ValuedRV` nodes, the graph looks like:
+
+    ```python
+    a_base = pm.Normal.dist()
+    a = valued_rv(a_base * 5, a_value)
+    b = valued_rv(a * 8, b_value)
+    ```
+
+    Since PyTensor doesn't know what to do with `ValuedRV` nodes, there is no risk of rewriting across them
+    and breaking the dependency of `b` on `a`. The new nodes isolate the graphs between conditioning points.
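+
+    `ValuedRV` nodes are strictly an intermediate representation: once the log-probability graphs
+    have been constructed, they are removed again by the cleanup rewrites (`local_remove_valued_rv`).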
+ """ + + def make_node(self, rv, value): + assert isinstance(rv, Variable) + assert isinstance(value, Variable) + return Apply(self, [rv, value], [rv.type(name=rv.name)]) + + def perform(self, node, inputs, out): + raise NotImplementedError("ValuedVar should not be present in the final graph!") + + def infer_shape(self, fgraph, node, input_shapes): + return [input_shapes[0]] + + +valued_rv = ValuedRV() + + +class PromisedValuedRV(Op): + r"""Marks a variable as being promised a valued variable that will only be assigned by the logprob method. + + Some measurable RVs like Join/MakeVector can combine multiple, potentially interdependent, RVs into a single + composite valued node. Only in the logp function is this value split and sent to each component, + but we still want to achieve the same goals that ValuedRVs achieve during the IR rewrites. + + Here is an example analogous to the one described in the docstrings of ValuedRV: + + ```python + a_base = pt.random.normal() + a = a_base * 5 + b = pt.random.normal(a * 8) + ab = pt.stack([a, b]) + ab_value = pt.vector(shape=(2,)) + + logp(ab, ab_value) + ``` + + The density of `ab[2]` (that is `b`) depends on `ab_value[1]` and `ab_value[0] * 8`, but this is not apparent + in the IR representation because the values of `a` and `b` are merged together, and will only be split by the logp + function (see why next). For the time being we introduce a PromisedValue to isolate the graphs of a and b, and + freezing the dependency of `b` on `a` (not `a_base`). + + Now why use a new Op and not just ValuedRV? Just for convenience! In the end we still want a function from + `ab_value` to `stack([logp(a), logp(b | a)])`, and if we split the values ahead of time we wouldn't know how to + stack them later (or even know that we were supposed to). + + One final point, while this achieves the same goal as introducing ValuedRVs, it already constitutes a form of inference + (knowing how/when to measure Join/MakeVectors), so we have to do it as an IR rewrite. However, we have to do it + before any other rewrites, so you'll see that the related rewrites are registered in `early_measurable_ir_rewrites_db`. 
+ + """ + + def make_node(self, rv): + assert isinstance(rv, Variable) + return Apply(self, [rv], [rv.type(name=rv.name)]) + + def perform(self, node, inputs, out): + raise NotImplementedError("PromisedValuedRV should not be present in the final graph!") + + def infer_shape(self, fgraph, node, input_shapes): + return [input_shapes[0]] + + +promised_valued_rv = PromisedValuedRV() diff --git a/pymc/logprob/basic.py b/pymc/logprob/basic.py index 33fd1f5838..e1a5d2911a 100644 --- a/pymc/logprob/basic.py +++ b/pymc/logprob/basic.py @@ -36,22 +36,17 @@ import warnings -from collections import deque from collections.abc import Sequence from typing import TypeAlias import numpy as np import pytensor.tensor as pt -from pytensor import config from pytensor.graph.basic import ( Constant, Variable, ancestors, - graph_inputs, - io_toposort, ) -from pytensor.graph.op import compute_test_value from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter from pytensor.tensor.variable import TensorVariable @@ -65,7 +60,7 @@ from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph from pymc.logprob.transform_value import TransformValuesRewrite from pymc.logprob.transforms import Transform -from pymc.logprob.utils import rvs_in_graph +from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph from pymc.pytensorf import replace_vars_in_graphs TensorLike: TypeAlias = Variable | float | np.ndarray @@ -210,8 +205,9 @@ def normal_logp(value, mu, sigma): try: return _logprob_helper(rv, value, **kwargs) except NotImplementedError: - fgraph, _, _ = construct_ir_fgraph({rv: value}) - [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() + fgraph = construct_ir_fgraph({rv: value}) + [ir_valued_var] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_var.owner.inputs expr = _logprob_helper(ir_rv, ir_value, **kwargs) cleanup_ir([expr]) if warn_rvs: @@ -308,9 +304,10 @@ def normal_logcdf(value, mu, sigma): return _logcdf_helper(rv, value, **kwargs) except NotImplementedError: # Try to rewrite rv - fgraph, _, _ = construct_ir_fgraph({rv: value}) - [ir_rv] = fgraph.outputs - expr = _logcdf_helper(ir_rv, value, **kwargs) + fgraph = construct_ir_fgraph({rv: value}) + [ir_valued_rv] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_rv.owner.inputs + expr = _logcdf_helper(ir_rv, ir_value, **kwargs) cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) @@ -390,9 +387,10 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens return _icdf_helper(rv, value, **kwargs) except NotImplementedError: # Try to rewrite rv - fgraph, _, _ = construct_ir_fgraph({rv: value}) - [ir_rv] = fgraph.outputs - expr = _icdf_helper(ir_rv, value, **kwargs) + fgraph = construct_ir_fgraph({rv: value}) + [ir_valued_rv] = fgraph.outputs + [ir_rv, ir_value] = ir_valued_rv.owner.inputs + expr = _icdf_helper(ir_rv, ir_value, **kwargs) cleanup_ir([expr]) if warn_rvs: _warn_rvs_in_inferred_graph(expr) @@ -476,111 +474,96 @@ def conditional_logp( """ warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs) - fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) + fgraph = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) if extra_rewrites is not None: extra_rewrites.rewrite(fgraph) - rv_remapper = fgraph.preserve_rv_mappings - - # This is the updated random-to-value-vars map with the lifted/rewritten - # variables. The rewrites are supposed to produce new - # `MeasurableOp`s whose variables are amenable to `_logprob`. 
-    updated_rv_values = rv_remapper.rv_values
-
-    # Some rewrites also transform the original value variables. This is the
-    # updated map from the new value variables to the original ones, which
-    # we want to use as the keys in the final dictionary output
-    original_values = rv_remapper.original_values
-
-    # When a `_logprob` has been produced for a `MeasurableOp` node, all
-    # other references to it need to be replaced with its value-variable all
-    # throughout the `_logprob`-produced graphs. The following `dict`
-    # cumulatively maintains remappings for all the variables/nodes that needed
-    # to be recreated after replacing `MeasurableOp` variables with their
-    # value-variables. Since these replacements work in topological order, all
-    # the necessary value-variable replacements should be present for each
-    # node.
-    replacements = updated_rv_values.copy()
+    # Cumulative map of replacements: as valued RVs are processed below, they are
+    # mapped to their value variables so that the graphs of downstream nodes
+    # reference the values rather than the RVs
+    replacements = {}

     # To avoid cloning the value variables (or ancestors of value variables),
     # we map them to themselves in the `replacements` `dict`
     # (i.e. entries already existing in `replacements` aren't cloned)
     replacements.update(
-        {
-            v: v
-            for v in ancestors(rv_values.values())
-            if (not isinstance(v, Constant) and v not in replacements)
-        }
+        {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
     )

     # Walk the graph from its inputs to its outputs and construct the
     # log-probability
-    q = deque(fgraph.toposort())
-    logprob_vars = {}
-
-    while q:
-        node = q.popleft()
+    values_to_logprobs = {}
+    original_values = tuple(rv_values.values())
+    # TODO: This seems too convoluted, can we just replace all RVs by their values,
+    # except for the fgraph outputs (on which we want to call `_logprob`)?
+    for node in fgraph.toposort():
         if not isinstance(node.op, MeasurableOp):
             continue

-        q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
+        valued_nodes = get_related_valued_nodes(node, fgraph)

-        if not q_values:
+        if not valued_nodes:
             continue

+        node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
+        node_values = [valued_var.inputs[1] for valued_var in valued_nodes]
+        node_output_idxs = [
+            fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes
+        ]
+
+        # Replace `RandomVariable`s in the inputs with value variables.
+        # Also, store the results in the `replacements` map for the nodes that follow.
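+        # (e.g. the logp of a valued RV that depends on another valued RV must be written
+        # in terms of the latter's value variable, not of the RV itself)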
+ for node_rv, node_value in zip(node_rvs, node_values): + replacements[node_rv] = node_value + remapped_vars = replace_vars_in_graphs( - graphs=q_values + list(node.inputs), + graphs=node_values + list(node.inputs), replacements=replacements, ) - q_values = remapped_vars[: len(q_values)] - q_rv_inputs = remapped_vars[len(q_values) :] + node_values = remapped_vars[: len(node_values)] + node_inputs = remapped_vars[len(node_values) :] - q_logprob_vars = _logprob( + node_logprobs = _logprob( node.op, - q_values, - *q_rv_inputs, + node_values, + *node_inputs, **kwargs, ) - if not isinstance(q_logprob_vars, list | tuple): - q_logprob_vars = [q_logprob_vars] + if not isinstance(node_logprobs, list | tuple): + node_logprobs = [node_logprobs] - for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars): - q_value_var = original_values[q_value_var] + for node_output_idx, node_value, node_logprob in zip( + node_output_idxs, node_values, node_logprobs + ): + original_value = original_values[node_output_idx] - if q_value_var.name: - q_logprob_var.name = f"{q_value_var.name}_logprob" + if original_value.name: + node_logprob.name = f"{original_value.name}_logprob" - if q_value_var in logprob_vars: + if original_value in values_to_logprobs: raise ValueError( - f"More than one logprob term was assigned to the value var {q_value_var}" + f"More than one logprob term was assigned to the value var {original_value}" ) - logprob_vars[q_value_var] = q_logprob_var - - # Recompute test values for the changes introduced by the replacements above. - if config.compute_test_value != "off": - for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars): - compute_test_value(node) + values_to_logprobs[original_value] = node_logprob - missing_value_terms = set(original_values.values()) - set(logprob_vars.keys()) + missing_value_terms = set(original_values) - set(values_to_logprobs) if missing_value_terms: raise RuntimeError( f"The logprob terms of the following value variables could not be derived: {missing_value_terms}" ) - logprob_expressions = list(logprob_vars.values()) - cleanup_ir(logprob_expressions) + logprobs = list(values_to_logprobs.values()) + cleanup_ir(logprobs) if warn_rvs: - rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions) + rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs) if rvs_in_logp_expressions: warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning) - return logprob_vars + return values_to_logprobs def transformed_conditional_logp( diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index df37e33782..0767d25f8f 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -28,8 +28,8 @@ _logprob, _logprob_helper, ) -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import check_potential_measurability +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables class MeasurableComparison(MeasurableElemwise): @@ -40,11 +40,7 @@ class MeasurableComparison(MeasurableElemwise): @node_rewriter(tracks=[gt, lt, ge, le]) def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None: - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - measurable_inputs = rv_map_feature.request_measurable(node.inputs) + measurable_inputs = 
filter_measurable_variables(node.inputs) if len(measurable_inputs) != 1: return None @@ -62,7 +58,7 @@ def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[Tenso const = node.inputs[(measurable_var_idx + 1) % 2] # check for potential measurability of const - if check_potential_measurability([const], rv_map_feature.rv_values.keys()): + if check_potential_measurability([const]): return None node_scalar_op = node.op.scalar_op @@ -132,16 +128,12 @@ class MeasurableBitwise(MeasurableElemwise): @node_rewriter(tracks=[invert]) def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None: - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - base_var = node.inputs[0] if not base_var.dtype.startswith("bool"): raise None - if not rv_map_feature.request_measurable([base_var]): + if not filter_measurable_variables([base_var]): return None node_scalar_op = node.op.scalar_op diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index d582da0799..248c285ba5 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -48,8 +48,8 @@ from pytensor.tensor.variable import TensorConstant from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import CheckParameterValue +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables class MeasurableClip(MeasurableElemwise): @@ -65,11 +65,7 @@ class MeasurableClip(MeasurableElemwise): def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None: # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None base_var, lower_bound, upper_bound = node.inputs @@ -158,11 +154,7 @@ class MeasurableRound(MeasurableElemwise): @node_rewriter(tracks=[ceil, floor, round_half_to_even]) def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None: - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - if rv_map_feature is None: - return None # pragma: no cover - - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None [base_var] = node.inputs diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 93056ea435..c9c60bb0fb 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -33,7 +33,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
- +from typing import cast import pytensor.tensor as pt @@ -43,8 +43,8 @@ from pytensor.tensor.shape import SpecifyShape from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import replace_rvs_by_values +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import filter_measurable_variables, replace_rvs_by_values class MeasurableSpecifyShape(MeasurableOp, SpecifyShape): @@ -66,24 +66,12 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None: if isinstance(node.op, MeasurableSpecifyShape): return None # pragma: no cover - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - rv = node.outputs[0] - base_rv, *shape = node.inputs - if not ( - base_rv.owner - and isinstance(base_rv.owner.op, MeasurableOp) - and base_rv not in rv_map_feature.rv_values - ): - return None # pragma: no cover + if not filter_measurable_variables([base_rv]): + return None - new_op = MeasurableSpecifyShape() - new_rv = new_op.make_node(base_rv, *shape).default_output() + new_rv = cast(TensorVariable, MeasurableSpecifyShape()(base_rv, *shape)) return [new_rv] @@ -116,13 +104,9 @@ def find_measurable_check_and_raise(fgraph, node) -> list[TensorVariable] | None if isinstance(node.op, MeasurableCheckAndRaise): return None # pragma: no cover - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - base_rv, *conds = node.inputs - if not rv_map_feature.request_measurable([base_rv]): + + if not filter_measurable_variables([base_rv]): return None op = node.op diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index dc452b0011..af7f73888c 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -42,7 +42,8 @@ from pytensor.tensor.extra_ops import CumOp from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db +from pymc.logprob.rewriting import measurable_ir_rewrites_db +from pymc.logprob.utils import filter_measurable_variables class MeasurableCumsum(MeasurableOp, CumOp): @@ -78,15 +79,10 @@ def find_measurable_cumsums(fgraph, node) -> list[TensorVariable] | None: r"""Finds `Cumsums`\s for which a `logprob` can be computed.""" if not (isinstance(node.op, CumOp) and node.op.mode == "add"): - return None # pragma: no cover + return None if isinstance(node.op, MeasurableCumsum): - return None # pragma: no cover - - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover + return None base_rv = node.inputs[0] @@ -94,7 +90,7 @@ def find_measurable_cumsums(fgraph, node) -> list[TensorVariable] | None: if base_rv.ndim > 1 and node.op.axis is None: return None - if not rv_map_feature.request_measurable(node.inputs): + if not filter_measurable_variables(node.inputs): return None new_op = MeasurableCumsum(axis=node.op.axis or 0, mode="add") diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 00a59533c6..0970fe21ce 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -39,7 +39,7 @@ import pytensor import pytensor.tensor as pt -from pytensor.graph.basic import Apply, Constant, Variable +from 
pytensor.graph.basic import Apply, Constant, Variable, ancestors from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op, compute_test_value from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter, node_rewriter @@ -68,17 +68,22 @@ from pymc.logprob.abstract import ( MeasurableElemwise, MeasurableOp, + PromisedValuedRV, _logprob, _logprob_helper, + valued_rv, ) from pymc.logprob.rewriting import ( - PreserveRVMappings, - assume_measured_ir_outputs, + early_measurable_ir_rewrites_db, local_lift_DiracDelta, measurable_ir_rewrites_db, subtensor_ops, ) -from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values +from pymc.logprob.utils import ( + check_potential_measurability, + filter_measurable_variables, + get_related_valued_nodes, +) from pymc.pytensorf import constant_fold @@ -260,6 +265,11 @@ def get_stack_mixture_vars( mixture_rvs = joined_rvs.owner.inputs[1:] + # Join and MakeVector can introduce PromisedValuedRV to prevent losing interdependencies + mixture_rvs = [ + rv.owner.inputs[0] if rv.owner and isinstance(rv.owner.op, PromisedValuedRV) else rv + for rv in mixture_rvs + ] return mixture_rvs, join_axis @@ -273,11 +283,6 @@ def find_measurable_index_mixture(fgraph, node): From these terms, new terms ``Z_rv[i] = mixture_comps[i][i == I_rv]`` are created for each ``i`` in ``enumerate(mixture_comps)``. """ - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - mixing_indices = node.inputs[1:] # TODO: Add check / test case for Advanced Boolean indexing @@ -298,7 +303,7 @@ def find_measurable_index_mixture(fgraph, node): if mixture_rvs is None or not isinstance(join_axis, NoneTypeT | Constant): return None - if rv_map_feature.request_measurable(mixture_rvs) != mixture_rvs: + if set(filter_measurable_variables(mixture_rvs)) != set(mixture_rvs): return None # Replace this sub-graph with a `MixtureRV` @@ -403,10 +408,8 @@ class MeasurableSwitchMixture(MeasurableElemwise): @node_rewriter([switch]) def find_measurable_switch_mixture(fgraph, node): - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover + if isinstance(node.op, MeasurableOp): + return None switch_cond, *components = node.inputs @@ -417,12 +420,11 @@ def find_measurable_switch_mixture(fgraph, node): if any(comp.type.broadcastable != out_bcast for comp in components): return None - # Check that `switch_cond` is not potentially measurable - valued_rvs = rv_map_feature.rv_values.keys() - if check_potential_measurability([switch_cond], valued_rvs): + if set(filter_measurable_variables(components)) != set(components): return None - if rv_map_feature.request_measurable(components) != components: + # Check that `switch_cond` is not potentially measurable + if check_potential_measurability([switch_cond]): return None return [measurable_switch_mixture(switch_cond, *components)] @@ -459,67 +461,106 @@ class MeasurableIfElse(MeasurableOp, IfElse): @node_rewriter([IfElse]) -def useless_ifelse_outputs(fgraph, node): - """Remove outputs that are shared across the IfElse branches.""" - # TODO: This should be a PyTensor canonicalization +def split_valued_ifelse(fgraph, node): + """Split valued variables in multi-output ifelse into their own ifelse.""" op = node.op - if_var, *inputs = node.inputs - shared_inputs = set(inputs[op.n_outs :]).intersection(inputs[: op.n_outs]) - if not 
shared_inputs:
+
+    if op.n_outs == 1:
+        # Single-output IfElse; nothing to split
+        return None
+
+    valued_output_nodes = get_related_valued_nodes(node, fgraph)
+    if not valued_output_nodes:
         return None

-    replacements = {}
-    for shared_inp in shared_inputs:
-        idx = inputs.index(shared_inp)
-        replacements[node.outputs[idx]] = shared_inp
+    cond, *all_outputs = node.inputs
+    then_outputs = all_outputs[: op.n_outs]
+    else_outputs = all_outputs[op.n_outs :]
+
+    # Pair each valued output with its then/else branches and its value variable
+    then_else_valued_outputs = []
+    for valued_output_node in valued_output_nodes:
+        rv, value = valued_output_node.inputs
+        [valued_out] = valued_output_node.outputs
+        rv_idx = node.outputs.index(rv)
+        then_else_valued_outputs.append(
+            (
+                then_outputs[rv_idx],
+                else_outputs[rv_idx],
+                value,
+                valued_out,
+            )
+        )

-    # IfElse isn't needed at all
-    if len(shared_inputs) == op.n_outs:
-        return replacements
+    # Split off the first valued output (in topological order) into its own ifelse
+    toposort = fgraph.toposort()
+    then_else_valued_outputs = sorted(
+        then_else_valued_outputs,
+        key=lambda x: max(toposort.index(x[0].owner), toposort.index(x[1].owner)),
+    )

-    # Create subset IfElse with remaining nodes
-    remaining_inputs = [inp for inp in inputs if inp not in shared_inputs]
-    new_outs = (
-        IfElse(n_outs=len(remaining_inputs) // 2).make_node(if_var, *remaining_inputs).outputs
+    (first_then, first_else, first_value_var, first_valued_out), *remaining_vars = (
+        then_else_valued_outputs
     )
-    for inp, new_out in zip(remaining_inputs, new_outs):
-        idx = inputs.index(inp)
-        replacements[node.outputs[idx]] = new_out
+    first_ifelse = ifelse(cond, first_then, first_else)
+    first_valued_ifelse = valued_rv(first_ifelse, first_value_var)
+    replacements = {first_valued_out: first_valued_ifelse}
+
+    if remaining_vars:
+        first_ifelse_ancestors = set(a for a in ancestors((first_then, first_else)) if a.owner)
+        remaining_thens = [then_out for (then_out, _, _, _) in remaining_vars]
+        remaining_elses = [else_out for (_, else_out, _, _) in remaining_vars]
+        if set(remaining_thens + remaining_elses) & first_ifelse_ancestors:
+            # The IfElse graph cannot be split, because some remaining variables are inputs to the first ifelse
+            return None
+
+        remaining_ifelses = ifelse(cond, remaining_thens, remaining_elses)
+        # Replace potential dependencies on first_then, first_else in the remaining ifelse by first_valued_ifelse
+        dummy_first_valued_ifelse = first_valued_ifelse.type()
+        temp_fgraph = FunctionGraph(
+            outputs=[*remaining_ifelses, dummy_first_valued_ifelse], clone=False
+        )
+        temp_fgraph.replace(first_then, dummy_first_valued_ifelse)
+        temp_fgraph.replace(first_else, dummy_first_valued_ifelse)
+        temp_fgraph.replace(dummy_first_valued_ifelse, first_valued_ifelse, import_missing=True)
+        for remaining_ifelse, (_, _, remaining_value_var, remaining_valued_out) in zip(
+            remaining_ifelses, remaining_vars
+        ):
+            remaining_valued_ifelse = valued_rv(remaining_ifelse, remaining_value_var)
+            replacements[remaining_valued_out] = remaining_valued_ifelse

     return replacements


 @node_rewriter([IfElse])
 def find_measurable_ifelse_mixture(fgraph, node):
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
+    """Find `IfElse` nodes that can be replaced by `MeasurableIfElse`."""
     op = node.op

-    if_var, *base_rvs = node.inputs
-
-    valued_rvs = rv_map_feature.rv_values.keys()
-    if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_rvs):
+    if isinstance(op, MeasurableOp):
+        return None
+
+    if op.n_outs > 1:
+        # The rewrite split_valued_ifelse should take care of this
+        return None

-    base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs)
-    if len(base_rvs) != op.n_outs * 2:
+    if_var, then_rv, else_rv = node.inputs
+
+    if check_potential_measurability([if_var]):
         return None
-    if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_rvs):
+
+    if len(filter_measurable_variables([then_rv, else_rv])) != 2:
         return None

-    return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs
+    return MeasurableIfElse(n_outs=op.n_outs)(if_var, then_rv, else_rv, return_list=True)


-measurable_ir_rewrites_db.register(
-    "useless_ifelse_outputs",
-    useless_ifelse_outputs,
+early_measurable_ir_rewrites_db.register(
+    "split_valued_ifelse",
+    split_valued_ifelse,
     "basic",
     "mixture",
 )

-
 measurable_ir_rewrites_db.register(
     "find_measurable_ifelse_mixture",
     find_measurable_ifelse_mixture,
@@ -529,27 +570,9 @@ def find_measurable_ifelse_mixture(fgraph, node):


 @_logprob.register(MeasurableIfElse)
-def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
+def logprob_ifelse(op, values, if_var, rv_then, rv_else, **kwargs):
     """Compute the log-likelihood graph for an `IfElse`."""
-
-    assert len(values) * 2 == len(base_rvs)
-
-    rvs_to_values_then = {then_rv: value for then_rv, value in zip(base_rvs[: len(values)], values)}
-    rvs_to_values_else = {else_rv: value for else_rv, value in zip(base_rvs[len(values) :], values)}
-
-    logps_then = [
-        _logprob_helper(rv_then, value, **kwargs) for rv_then, value in rvs_to_values_then.items()
-    ]
-    logps_else = [
-        _logprob_helper(rv_else, value, **kwargs) for rv_else, value in rvs_to_values_else.items()
-    ]
-
-    # If the multiple variables depend on each other, we have to replace them
-    # by the respective values
-    logps_then = replace_rvs_by_values(logps_then, rvs_to_values=rvs_to_values_then)
-    logps_else = replace_rvs_by_values(logps_else, rvs_to_values=rvs_to_values_else)
-
-    logps = ifelse(if_var, logps_then, logps_else)
-    if len(logps) == 1:
-        return logps[0]
-    return logps
+    [value] = values
+    logp_then = _logprob_helper(rv_then, value, **kwargs)
+    logp_else = _logprob_helper(rv_else, value, **kwargs)
+    return ifelse(if_var, logp_then, logp_else)
diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index f7322fbca7..51833a128b 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -51,6 +51,7 @@
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.utils import filter_measurable_variables
 from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
@@ -65,10 +66,6 @@ class MeasurableMaxDiscrete(MeasurableOp, Max):

 @node_rewriter([Max])
 def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
-    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
-    if rv_map_feature is None:
-        return None  # pragma: no cover
-
     if isinstance(node.op, MeasurableMax | MeasurableMaxDiscrete):
         return None

@@ -77,7 +74,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
     if base_var.owner is None:
         return None

-    if not rv_map_feature.request_measurable(node.inputs):
+    if not filter_measurable_variables(node.inputs):
         return None

     # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
diff --git a/pymc/logprob/rewriting.py b/pymc/logprob/rewriting.py
index 3e6cc8bac7..bd171441ac 100644
--- a/pymc/logprob/rewriting.py
+++ b/pymc/logprob/rewriting.py
@@ -33,24 +33,18 @@
 # LIABILITY, WHETHER IN
AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import warnings -from collections import deque -from collections.abc import Collection, Sequence +from collections.abc import Sequence -from pytensor import config from pytensor.compile.mode import optdb from pytensor.graph.basic import ( Variable, - io_toposort, + ancestors, truncated_graph_inputs, ) -from pytensor.graph.features import Feature from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import ( - ChangeTracker, - EquilibriumGraphRewriter, GraphRewriter, node_rewriter, out2in, @@ -58,7 +52,6 @@ from pytensor.graph.rewriting.db import ( EquilibriumDB, LocalGroupDB, - RewriteDatabase, RewriteDatabaseQuery, SequenceDB, TopoDB, @@ -79,204 +72,34 @@ ) from pytensor.tensor.variable import TensorVariable -from pymc.logprob.abstract import MeasurableOp +from pymc.logprob.abstract import PromisedValuedRV, ValuedRV, valued_rv from pymc.logprob.utils import DiracDelta +from pymc.pytensorf import toposort_replace inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1) subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor) -class MeasurableEquilibriumGraphRewriter(EquilibriumGraphRewriter): - """EquilibriumGraphRewriter focused on IR measurable rewrites. +@node_rewriter([ValuedRV]) +def local_remove_valued_rv(fgraph, node): + rv = node.inputs[0] + return [rv] - This is a stripped down version of the EquilibriumGraphRewriter, - which specifically targets nodes in `PreserveRVMAppings.needs_measuring` - that are not yet measurable. - """ - - def apply(self, fgraph): - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - if not rv_map_feature: - return None - - change_tracker = ChangeTracker() - fgraph.attach_feature(change_tracker) - - changed = True - max_use_abort = False - rewriter_name = None - global_process_count = {} - - for rewriter in self.global_rewriters + list(self.get_node_rewriters()): - global_process_count.setdefault(rewriter, 0) - - while changed and not max_use_abort: - changed = False - max_nb_nodes = len(fgraph.apply_nodes) - max_use = max_nb_nodes * self.max_use_ratio - - # Apply global rewriters - for grewrite in self.global_rewriters: - change_tracker.reset() - grewrite.apply(fgraph) - if change_tracker.changed: - global_process_count[grewrite] += 1 - changed = True - if global_process_count[grewrite] > max_use: - max_use_abort = True - rewriter_name = getattr(grewrite, "name", None) or getattr( - grewrite, "__name__", "" - ) - - # Apply local node rewriters - q = deque(io_toposort(fgraph.inputs, fgraph.outputs)) - while q: - node = q.pop() - if node not in fgraph.apply_nodes: - continue - # This is where we filter only those nodes we care about: - # Nodes that have variables that we want to measure and are not yet measurable - if isinstance(node.op, MeasurableOp): - continue - if not any(out in rv_map_feature.needs_measuring for out in node.outputs): - continue - for node_rewriter in self.node_tracker.get_trackers(node.op): # noqa F402 - node_rewriter_change = self.process_node(fgraph, node, node_rewriter) - if not node_rewriter_change: - continue - global_process_count[node_rewriter] += 1 - changed = True - if global_process_count[node_rewriter] > max_use: - max_use_abort = True - rewriter_name = getattr(node_rewriter, "name", None) or getattr( - node_rewriter, "__name__", "" 
- ) - # If we converted to a MeasurableOp we're done here! - if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableOp): - # go to next node - break - - if max_use_abort: - msg = ( - f"{type(self).__name__} max'ed out by {rewriter_name}." - "You can safely raise the current threshold of " - f"{config.optdb__max_use_ratio} with the option `optdb__max_use_ratio`." - ) - if config.on_opt_error == "raise": - raise AssertionError(msg) - else: - warnings.warn(msg) - fgraph.remove_feature(change_tracker) - - -class MeasurableEquilibriumDB(RewriteDatabase): - """A database of rewrites that should be applied until equilibrium is reached. - - This will return a MeasurableEquilibriumGraphRewriter when queried. - - """ - - def query(self, *tags, **kwtags): - rewriters = super().query(*tags, **kwtags) - return MeasurableEquilibriumGraphRewriter( - rewriters, - max_use_ratio=config.optdb__max_use_ratio, - ) - - -class PreserveRVMappings(Feature): - r"""Keeps track of random variables and their respective value variables during - graph rewrites in `rv_values` - - When a random variable is replaced in a rewrite, this `Feature` automatically - updates the `rv_values` mapping, so that the new variable is linked to the - original value variable. +remove_valued_rvs = out2in(local_remove_valued_rv) - In addition this `Feature` provides functionality to manually update a random - and/or value variable. A mapping from the transformed value variables to the - the original value variables is kept in `original_values`. - Likewise, a `measurable_conversions` map is maintained, which holds - information about un-valued and un-measurable variables that were replaced - with measurable variables. This information can be used to revert these - rewrites. +@node_rewriter([PromisedValuedRV]) +def local_remove_promised_value_rv(fgraph, node): + rv = node.inputs[0] + return [rv] - """ - def __init__(self, rv_values: dict[TensorVariable, TensorVariable]): - """ - Parameters - ---------- - rv_values - Mappings between random variables and their value variables. - The keys of this map are what this `Feature` keeps updated. - The ``dict`` is updated in-place. - """ - self.rv_values = rv_values - self.original_values = {v: v for v in rv_values.values()} - self.needs_measuring = set(rv_values.keys()) - - def on_attach(self, fgraph): - if hasattr(fgraph, "preserve_rv_mappings"): - raise ValueError(f"{fgraph} already has the `PreserveRVMappings` feature attached.") - - fgraph.preserve_rv_mappings = self - - def update_rv_maps( - self, - old_rv: TensorVariable, - new_value: TensorVariable, - new_rv: TensorVariable | None = None, - ): - """Update mappings for a random variable. - - It also creates/updates a map from new value variables to their - original value variables. - - Parameters - ---------- - old_rv - The random variable whose mappings will be updated. - new_value - The new value variable that will replace the current one assigned - to `old_rv`. - new_rv - When non-``None``, `old_rv` will also be replaced with `new_rv` in - the mappings, as well. - """ - old_value = self.rv_values.pop(old_rv) - original_value = self.original_values.pop(old_value) - - if new_rv is None: - new_rv = old_rv - - self.rv_values[new_rv] = new_value - self.original_values[new_value] = original_value - - def on_change_input(self, fgraph, node, i, r, new_r, reason=None): - """ - Whenever a node is replaced during rewrite, we check if it had a value - variable associated with it and map it to the new node. 
- """ - r_value_var = self.rv_values.pop(r, None) - if r_value_var is not None: - self.rv_values[new_r] = r_value_var - self.needs_measuring.add(new_r) - if new_r.name is None: - new_r.name = r.name - - def request_measurable(self, vars: Sequence[Variable]) -> list[Variable]: - measurable = [] - for var in vars: - # Input vars or valued vars can't be measured for derived expressions - if not var.owner or var in self.rv_values: - continue - if isinstance(var.owner.op, MeasurableOp): - measurable.append(var) - else: - self.needs_measuring.add(var) - return measurable +def remove_promised_valued_rvs(outputs): + fgraph = FunctionGraph(outputs=outputs, clone=False) + rewrite = out2in(local_remove_promised_value_rv) + rewrite.apply(fgraph) + return fgraph.outputs @register_canonicalize @@ -312,6 +135,20 @@ def remove_DiracDelta(fgraph, node): logprob_rewrites_db = SequenceDB() logprob_rewrites_db.name = "logprob_rewrites_db" + +early_measurable_ir_rewrites_db = LocalGroupDB() +early_measurable_ir_rewrites_db.name = "early_measurable_rewrites_db" +logprob_rewrites_db.register( + "early_ir_rewrites", + TopoDB( + early_measurable_ir_rewrites_db, + order="in_to_out", + ignore_newtrees=False, + failure_callback=None, + ), + "basic", +) + # Introduce sigmoid. We do it before canonicalization so that useless mul are removed next logprob_rewrites_db.register( "local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic" @@ -321,7 +158,7 @@ def remove_DiracDelta(fgraph, node): # These rewrites convert un-measurable variables into their measurable forms, # but they need to be reapplied, because some of the measurable forms require # their inputs to be measurable. -measurable_ir_rewrites_db = MeasurableEquilibriumDB() +measurable_ir_rewrites_db = EquilibriumDB() measurable_ir_rewrites_db.name = "measurable_ir_rewrites_db" logprob_rewrites_db.register("measurable_ir_rewrites", measurable_ir_rewrites_db, "basic") @@ -351,12 +188,13 @@ def remove_DiracDelta(fgraph, node): ) cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup") +cleanup_ir_rewrites_db.register("local_remove_valued_rv", local_remove_valued_rv, "cleanup") def construct_ir_fgraph( rv_values: dict[Variable, Variable], ir_rewriter: GraphRewriter | None = None, -) -> tuple[FunctionGraph, dict[Variable, Variable], dict[Variable, Variable]]: +) -> FunctionGraph: r"""Construct a `FunctionGraph` in measurable IR form for the keys in `rv_values`. A custom IR rewriter can be specified. By default, @@ -383,46 +221,37 @@ def construct_ir_fgraph( Returns ------- - A `FunctionGraph` of the measurable IR, a copy of `rv_values` containing - the new, cloned versions of the original variables in `rv_values`, and - a ``dict`` mapping all the original variables to their cloned values in - `FunctionGraph`. + A `FunctionGraph` of the measurable IR. """ - # Since we're going to clone the entire graph, we need to keep a map from - # the old nodes to the new ones; otherwise, we won't be able to use - # `rv_values`. - # We start the `dict` with mappings from the value variables to themselves, - # to prevent them from being cloned. 
-    memo = {v: v for v in rv_values.values()}
-
     # We add `ShapeFeature` because it will get rid of references to the old
     # `RandomVariable`s that have been lifted; otherwise, it will be difficult
-    # to give good warnings when an unaccounted for `RandomVariable` is
-    # encountered
+    # to give good warnings when an unaccounted for `RandomVariable` is encountered
     fgraph = FunctionGraph(
         outputs=list(rv_values.keys()),
         clone=True,
-        memo=memo,
         copy_orphans=False,
         copy_inputs=False,
         features=[ShapeFeature()],
     )

-    # Update `rv_values` so that it uses the new cloned variables
-    rv_values = {memo[k]: v for k, v in rv_values.items()}
+    # Replace valued RVs by ValuedRV Ops so that rewrites are aware of conditioning points.
+    # We use clones of the value variables so that they are not affected by rewrites
+    cloned_values = tuple(v.clone() for v in rv_values.values())
+    ir_rv_values = {rv: value for rv, value in zip(fgraph.outputs, cloned_values)}

-    # This `Feature` preserves the relationships between the original
-    # random variables (i.e. keys in `rv_values`) and the new ones
-    # produced when `Op`s are lifted through them.
-    rv_remapper = PreserveRVMappings(rv_values)
-    fgraph.attach_feature(rv_remapper)
+    replacements = tuple((rv, valued_rv(rv, value)) for rv, value in ir_rv_values.items())
+    toposort_replace(fgraph, replacements, reverse=True)

     if ir_rewriter is None:
         ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"]))
     ir_rewriter.rewrite(fgraph)

-    return fgraph, rv_values, memo
+    # Reintroduce the original value variables
+    replacements = tuple((cloned_v, v) for v, cloned_v in zip(rv_values.values(), cloned_values))
+    toposort_replace(fgraph, replacements=replacements, reverse=True)
+
+    return fgraph


 def cleanup_ir(vars: Sequence[Variable]) -> None:
@@ -431,9 +260,7 @@ def cleanup_ir(vars: Sequence[Variable]) -> None:
     ir_rewriter.rewrite(fgraph)


-def assume_measured_ir_outputs(
-    inputs: Collection[TensorVariable], outputs: Sequence[TensorVariable]
-) -> Sequence[TensorVariable]:
+def assume_valued_outputs(outputs: Sequence[TensorVariable]) -> Sequence[TensorVariable]:
     """Run IR rewrite assuming each output is measured.

     IR variables could depend on each other in a way that looks unmeasurable without a value variable assigned to each.
@@ -442,7 +269,12 @@ def assume_measured_ir_outputs(
     This helper runs an inner ir rewrite after giving each output a dummy value variable.
     We replace inputs by dummies and then undo it so that any dependency on outer variables is preserved.
""" - # Replace inputs by dummy variables + # Replace inputs by dummy variables (so they are not affected) + inputs = [ + valued_var + for valued_var in ancestors(outputs) + if (valued_var.owner and isinstance(valued_var.owner.op, ValuedRV)) + ] replaced_inputs = { var: var.type() for var in truncated_graph_inputs(outputs, ancestors_to_include=inputs) @@ -451,9 +283,10 @@ def assume_measured_ir_outputs( cloned_outputs = clone_replace(outputs, replace=replaced_inputs) dummy_rv_values = {base_var: base_var.type() for base_var in cloned_outputs} - fgraph, *_ = construct_ir_fgraph(dummy_rv_values) + fgraph = construct_ir_fgraph(dummy_rv_values) + remove_valued_rvs.apply(fgraph) - # Replace dummy variables by inputs + # Replace dummy variables by original inputs fgraph.replace_all( tuple((repl, orig) for orig, repl in replaced_inputs.items()), import_missing=True, diff --git a/pymc/logprob/scan.py b/pymc/logprob/scan.py index 775fc4e2e7..4b643b7302 100644 --- a/pymc/logprob/scan.py +++ b/pymc/logprob/scan.py @@ -39,31 +39,29 @@ from typing import cast import numpy as np -import pytensor import pytensor.tensor as pt -from pytensor.graph.basic import Variable -from pytensor.graph.op import compute_test_value +from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.scan.op import Scan from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2 from pytensor.scan.utils import ScanArgs +from pytensor.tensor.basic import AllocEmpty from pytensor.tensor.random.type import RandomType -from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor +from pytensor.tensor.subtensor import IncSubtensor, Subtensor from pytensor.tensor.variable import TensorVariable from pytensor.updates import OrderedUpdates from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import conditional_logp from pymc.logprob.rewriting import ( - PreserveRVMappings, construct_ir_fgraph, - inc_subtensor_ops, logprob_rewrites_db, measurable_ir_rewrites_db, + remove_valued_rvs, ) -from pymc.logprob.utils import replace_rvs_by_values +from pymc.logprob.utils import get_related_valued_nodes, replace_rvs_by_values +from pymc.pytensorf import toposort_replace class MeasurableScan(MeasurableOp, Scan): @@ -297,14 +295,59 @@ def construct_scan(scan_args: ScanArgs, **kwargs) -> tuple[list[TensorVariable], return node.outputs, updates +def get_initval_from_scan_tap_input(inp) -> TensorVariable: + """Get initval from the buffer allocated to tap (recurring) inputs. + + Raises ValueError, if input does not correspond to expected graph. 
+ """ + if not isinstance(inp.owner.op, IncSubtensor) and inp.owner.op.set_instead_of_inc: + raise ValueError + + idx_list = inp.owner.op.idx_list + if not len(idx_list) == 1: + raise ValueError + + [idx_slice] = idx_list + if not ( + isinstance(idx_slice, slice) + and idx_slice.start is None + and idx_slice.stop is not None + and idx_slice.step is None + ): + raise ValueError + + empty, initval, _ = inp.owner.inputs + if not isinstance(empty.owner.op, AllocEmpty): + raise ValueError + + return initval + + @_logprob.register(MeasurableScan) -def logprob_ScanRV(op, values, *inputs, name=None, **kwargs): +def logprob_scan(op, values, *inputs, name=None, **kwargs): new_node = op.make_node(*inputs) scan_args = ScanArgs.from_node(new_node) rv_outer_outs = get_random_outer_outputs(scan_args) - var_indices, rv_vars, io_vars = zip(*rv_outer_outs) - value_map = {_rv: _val for _rv, _val in zip(rv_vars, values)} + # values = (pt.zeros(11)[1:].set(values[0]),) + # For random variable sequences with taps, we need to place the value variable in the + # input tensor that contains the initial state and the empty buffer for the output + values = list(values) + var_indices, outer_rvs, inner_rvs = zip(*rv_outer_outs) + for inp, out in zip( + scan_args.outer_in_sit_sot + scan_args.outer_in_mit_sot, + scan_args.outer_out_sit_sot + scan_args.outer_out_mit_sot, + ): + if out not in outer_rvs: + continue + + # Tap inputs should be a SetSubtensor(empty()[:start], initial_value) + # We will replace it by Join(axis=0, initial_value, value) + initval = get_initval_from_scan_tap_input(inp) + idx = outer_rvs.index(out) + values[idx] = pt.join(0, initval, values[idx]) + + value_map = dict(zip(outer_rvs, values)) def create_inner_out_logp(value_map: dict[TensorVariable, TensorVariable]) -> TensorVariable: """Create a log-likelihood inner-output for a `Scan`.""" @@ -313,7 +356,7 @@ def create_inner_out_logp(value_map: dict[TensorVariable, TensorVariable]) -> Te logp_scan_args = convert_outer_out_to_in( scan_args, - rv_vars, + outer_rvs, value_map, inner_out_fn=create_inner_out_logp, ) @@ -352,171 +395,103 @@ def create_inner_out_logp(value_map: dict[TensorVariable, TensorVariable]) -> Te @node_rewriter([Scan, Subtensor]) def find_measurable_scans(fgraph, node): - r"""Find `Scan`\s for which a `logprob` can be computed. - - This will convert said `Scan`\s into `MeasurableScan`\s. It also updates - random variable and value variable mappings that have been specified for - parts of a `Scan`\s outputs (e.g. everything except the initial values). - """ - - if not hasattr(fgraph, "shape_feature"): - return None # pragma: no cover - - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover + r"""Find `Scan`\s for which a `logprob` can be computed.""" if isinstance(node.op, Subtensor): node = node.inputs[0].owner if not (node and isinstance(node.op, Scan)): return None - if isinstance(node.op, MeasurableScan): - return None - - curr_scanargs = ScanArgs.from_node(node) - - # Find the un-output `MeasurablOp` variables created in the inner-graph - if not any(out in rv_map_feature.rv_values for out in node.outputs): - # TODO: T - # We need to remap user inputs that have been specified in terms of - # `Subtensor`s of this `Scan`'s node's outputs. 
- # - # For example, the output that the user got was something like - # `out[1:]` for `outputs_info = [{"initial": x0, "taps": [-1]}]`, so - # they likely passed `{out[1:]: x_1T_vv}` to `joint_logprob`. - # Since `out[1:]` isn't really the output of a `Scan`, but a - # `Subtensor` of the output `out` of a `Scan`, we need to account for - # that. - - # Get any `Subtensor` outputs that have been applied to outputs of this - # `Scan` (and get the corresponding indices of the outputs from this - # `Scan`) - output_clients: list[tuple[Variable, int]] = [ - # This is expected to work for `Subtensor` `Op`s, - # because they only ever have one output - (cl.default_output(), i) - for i, out in enumerate(node.outputs) - for cl, _ in fgraph.get_clients(out) - if isinstance(cl.op, Subtensor) - ] - - # The second items in these tuples are the value variables mapped to - # the *user-specified* measurable variables (i.e. the first items) that - # are `Subtensor`s of the outputs of this `Scan`. The second items are - # the index of the corresponding output of this `Scan` node. - indirect_rv_vars = [ - (out, rv_map_feature.rv_values[out], out_idx) - for out, out_idx in output_clients - if out in rv_map_feature.rv_values - ] - - if not indirect_rv_vars: - return None - # We need this for the `clone` in the loop that follows - if pytensor.config.compute_test_value != "off": - compute_test_value(node) - - # We're going to replace the user's random variable/value variable mappings - # with ones that map directly to outputs of this `Scan`. - for rv_var, val_var, out_idx in indirect_rv_vars: - # The full/un-`Subtensor`ed `Scan` output that we need to use - full_out = node.outputs[out_idx] - - assert rv_var.owner.inputs[0] == full_out - - # A new value variable that spans the full output. - # We don't want the old graph to appear in the new log-probability - # graph, so we use the shape feature to (hopefully) get the shape - # without the entire `Scan` itself. - full_out_shape = tuple( - fgraph.shape_feature.get_shape(full_out, i) for i in range(full_out.ndim) - ) - new_val_var = pt.empty(full_out_shape, dtype=full_out.dtype) + if isinstance(node.op, MeasurableScan): + return None - # Set the parts of this new value variable that applied to the - # user-specified value variable to the user's value variable - subtensor_indices = indices_from_subtensor( - rv_var.owner.inputs[1:], rv_var.owner.op.idx_list - ) - # E.g. for a single `-1` TAPS, `s_0T[1:] = s_1T` where `s_0T` is - # `new_val_var` and `s_1T` is the user-specified value variable - # that only spans times `t=1` to `t=T`. - new_val_var = pt.set_subtensor(new_val_var[subtensor_indices], val_var) - - # This is the outer-input that sets `s_0T[i] = taps[i]` where `i` - # is a TAP index (e.g. a TAP of `-1` maps to index `0` in a vector - # of the entire series). - var_info = curr_scanargs.find_among_fields(full_out) - alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :] - outer_input_var = getattr(curr_scanargs, f"outer_in_{alt_type}")[var_info.index] - - # These outer-inputs are using by `pytensor.scan.utils.expand_empty`, and - # are expected to consist of only a single `set_subtensor` call. - # That's why we can simply replace the first argument of the node. - assert isinstance(outer_input_var.owner.op, inc_subtensor_ops) - - # We're going to set those values on our `new_val_var` so that it can - # serve as a complete replacement for the old input `outer_input_var`. 
- new_val_var = outer_input_var.owner.clone_with_new_inputs( - [new_val_var] + outer_input_var.owner.inputs[1:] - ).default_output() - - # Replace the mapping - rv_map_feature.update_rv_maps(rv_var, new_val_var, full_out) - - op = MeasurableScan( - curr_scanargs.inner_inputs, - curr_scanargs.inner_outputs, - curr_scanargs.info, - mode=node.op.mode, - ) - new_node = op.make_node(*curr_scanargs.outer_inputs) + if node.op.info.as_while: # May work but we haven't tested it + return None - return dict(zip(node.outputs, new_node.outputs)) + if node.op.info.n_mit_mot > 0: + return None + scan_args = ScanArgs.from_node(node) -@node_rewriter([Scan, Subtensor]) -def add_opts_to_inner_graphs(fgraph, node): - """Update the `Mode`(s) used to compile the inner-graph of a `Scan` `Op`. + # TODO: Check what outputs are actually needed for ValuedRVs more than one node deep - This is how we add the measurable IR rewrites to the "body" - (i.e. inner-graph) of a `Scan` loop. - """ + # To make the inner graph measurable, we need to know which inner outputs we are conditioning on from the outside + # If there is only one output, we could always try to make it measurable, but with more outputs it would be ambiguous. + # For example, if we have out1 = normal() and out2 = out1 + const, it's valid to condition on either (but not both). - if isinstance(node.op, Subtensor): - node = node.inputs[0].owner - if not (node and isinstance(node.op, Scan)): - return None + # Find outputs of scan that are directly valued. + # These must be mapping outputs, such as `outputs_info = [None]` (i.e, no recurrence nit_sot outputs) + direct_valued_outputs = [ + valued_node.inputs[0] for valued_node in get_related_valued_nodes(node, fgraph) + ] + if not all(valued_out in scan_args.outer_out_nit_sot for valued_out in direct_valued_outputs): + return None - # TODO: This might not be needed now that we only target relevant nodes - # Avoid unnecessarily re-applying this rewrite - if getattr(node.op.mode, "had_logprob_rewrites", False): + # Find indirect (sliced) outputs of scan that are valued. + # These must be recurring outputs, such as `outputs_info = [{"initial": x0, "taps": [-1]}]` (i.e, recurring sit-sot or mit-sot outputs) + # For these outputs, the scan helper returns `out[abs(min(taps)):]` (out[:abs(min(taps))] includes the initial values) + # This means that it's a Subtensor output, not a direct Scan output, that the user requests the logp of. + sliced_valued_outputs = [ + client.outputs[0] + for out in node.outputs + for client, _ in fgraph.clients[out] + if (isinstance(client.op, Subtensor) and get_related_valued_nodes(client, fgraph)) + ] + indirect_valued_outputs = [out.owner.inputs[0] for out in sliced_valued_outputs] + if not all( + (valued_out in scan_args.outer_out_sit_sot or valued_out in scan_args.outer_out_mit_sot) + for valued_out in indirect_valued_outputs + ): return None - inner_rv_values = {out: out.type() for out in node.op.inner_outputs} - ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])) - inner_fgraph, rv_values, _ = construct_ir_fgraph(inner_rv_values, ir_rewriter=ir_rewriter) + valued_outputs = direct_valued_outputs + indirect_valued_outputs - new_outputs = list(inner_fgraph.outputs) + if not valued_outputs: + return None - # TODO FIXME: This is pretty hackish. 
-    # TODO FIXME: This is pretty hackish.
-    new_mode = copy(node.op.mode)
-    new_mode.had_logprob_rewrites = True
+    valued_output_idxs = [node.outputs.index(out) for out in valued_outputs]

-    op = Scan(node.op.inner_inputs, new_outputs, node.op.info, mode=new_mode)
-    new_node = op.make_node(*node.inputs)
+    # Make inner graph measurable
+    mapping = node.op.get_oinp_iinp_iout_oout_mappings()["inner_out_from_outer_out"]
+    inner_rvs = [node.op.inner_outputs[mapping[idx][-1]] for idx in valued_output_idxs]
+    inner_fgraph = construct_ir_fgraph({rv: rv.type() for rv in inner_rvs})
+    remove_valued_rvs(inner_fgraph)
+    inner_rvs = list(inner_fgraph.outputs)
+    if not all(isinstance(new_out.owner.op, MeasurableOp) for new_out in inner_rvs):
+        return None

-    return dict(zip(node.outputs, new_node.outputs))
+    # Create MeasurableScan with the new inner outputs.
+    # We must also replace any lingering references to the old RVs with the new measurable RVs.
+    # For example, if we had measurable out1 = exp(normal()) and out2 = out1 - x,
+    # we need to replace references to the original out1 with the new MeasurableExp(normal()).
+    inner_outs = node.op.inner_outputs.copy()
+    inner_rvs_replacements = []
+    for idx, new_inner_rv in zip(valued_output_idxs, inner_rvs, strict=True):
+        old_inner_rv = inner_outs[idx]
+        inner_outs[idx] = new_inner_rv
+        inner_rvs_replacements.append((old_inner_rv, new_inner_rv))
+    temp_fgraph = FunctionGraph(
+        outputs=inner_outs + [a for a, _ in inner_rvs_replacements],
+        clone=False,
+    )
+    toposort_replace(temp_fgraph, inner_rvs_replacements)
+    inner_outs = temp_fgraph.outputs[: len(inner_outs)]
+    op = MeasurableScan(node.op.inner_inputs, inner_outs, node.op.info, mode=copy(node.op.mode))
+    new_outs = op.make_node(*node.inputs).outputs
+
+    old_outs = node.outputs
+    replacements = {}
+    for old_out, new_out in zip(old_outs, new_outs):
+        if old_out in indirect_valued_outputs:
+            # We sidestep the Subtensor operation, which is not relevant for the logp
+            sliced_idx = indirect_valued_outputs.index(old_out)
+            old_out = sliced_valued_outputs[sliced_idx]
+            replacements[old_out] = new_out
+        else:
+            replacements[old_out] = new_out
+    return replacements

-measurable_ir_rewrites_db.register(
-    "add_opts_to_inner_graphs",
-    add_opts_to_inner_graphs,
-    "basic",
-    "scan",
-)
 measurable_ir_rewrites_db.register(
     "find_measurable_scans",
diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py
index 7aa3c559c9..a4f1324c99 100644
--- a/pymc/logprob/tensor.py
+++ b/pymc/logprob/tensor.py
@@ -38,6 +38,7 @@
 from pathlib import Path
 
 from pytensor import tensor as pt
+from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.basic import Join, MakeVector
@@ -47,13 +48,18 @@
     local_dimshuffle_rv_lift,
 )
 
-from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
+from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper, promised_valued_rv
 from pymc.logprob.rewriting import (
-    PreserveRVMappings,
-    assume_measured_ir_outputs,
+    assume_valued_outputs,
+    early_measurable_ir_rewrites_db,
     measurable_ir_rewrites_db,
+    remove_promised_valued_rvs,
+)
+from pymc.logprob.utils import (
+    check_potential_measurability,
+    filter_measurable_variables,
+    replace_rvs_by_values,
 )
-from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values
 from pymc.pytensorf import constant_fold
@@ -68,6 +74,8 @@ def logprob_make_vector(op, values, *base_rvs, **kwargs):
 
     (value,) = values
 
+    base_rvs = remove_promised_valued_rvs(base_rvs)
+
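The `remove_promised_valued_rvs` call above strips the promised-value markers right before the single stacked value is split into per-component values. A sketch of the user-facing behaviour this machinery supports, assuming two independent normals stacked into one valued vector (illustrative numbers, not from the test suite):

```python
import numpy as np
import pytensor.tensor as pt
from scipy import stats

from pymc.logprob.basic import conditional_logp

x = pt.random.normal(0.0, 1.0)
y = pt.random.normal(1.0, 2.0)
xy = pt.stack([x, y])  # a MakeVector of two independent normals

xy_vv = xy.type()
xy_logp = conditional_logp({xy: xy_vv})[xy_vv]

value = np.array([0.3, -0.5])
# The joint density should factor into the two component densities
np.testing.assert_allclose(
    xy_logp.eval({xy_vv: value}).sum(),
    stats.norm.logpdf(value[0], 0.0, 1.0) + stats.norm.logpdf(value[1], 1.0, 2.0),
)
```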
     base_rvs_to_values = {base_rv: value[i] for i, base_rv in enumerate(base_rvs)}
     for i, (base_rv, value) in enumerate(base_rvs_to_values.items()):
         base_rv.name = f"base_rv[{i}]"
@@ -90,6 +98,8 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):
     """Compute the log-likelihood graph for a `Join`."""
     (value,) = values
 
+    base_rvs = remove_promised_valued_rvs(base_rvs)
+
     base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs]
 
     # We don't need the graph to be constant, just to have RandomVariables removed
@@ -131,11 +141,10 @@
 def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `Join`\s and `MakeVector`\s for which a `logprob` can be computed."""
+    from pymc.pytensorf import toposort_replace
 
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
+    if isinstance(node.op, MeasurableOp):
+        return None
 
     is_join = isinstance(node.op, Join)
 
@@ -144,18 +153,25 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
     else:
         base_vars = node.inputs
 
-    valued_rvs = rv_map_feature.rv_values.keys()
-    if not all(check_potential_measurability([base_var], valued_rvs) for base_var in base_vars):
+    if not all(check_potential_measurability([base_var]) for base_var in base_vars):
         return None
 
-    base_vars = assume_measured_ir_outputs(valued_rvs, base_vars)
+    base_vars = assume_valued_outputs(base_vars)
     if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_vars):
         return None
 
+    # Each base var will be "valued" by the logprob method, so other rewrites shouldn't mess with it
+    # and potentially break interdependencies. For this reason, this rewrite should be applied early
+    # in the IR construction.
+    replacements = [(base_var, promised_valued_rv(base_var)) for base_var in base_vars]
+    temp_fgraph = FunctionGraph(outputs=base_vars, clone=False)
+    toposort_replace(temp_fgraph, replacements)  # type: ignore
+    new_base_vars = temp_fgraph.outputs
+
     if is_join:
-        measurable_stack = MeasurableJoin()(axis, *base_vars)
+        measurable_stack = MeasurableJoin()(axis, *new_base_vars)
     else:
-        measurable_stack = MeasurableMakeVector(node.op.dtype)(*base_vars)
+        measurable_stack = MeasurableMakeVector(node.op.dtype)(*new_base_vars)
 
     assert isinstance(measurable_stack, TensorVariable)
     return [measurable_stack]
@@ -203,13 +219,12 @@ def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
 @node_rewriter([DimShuffle])
 def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
     r"""Finds `Dimshuffle`\s for which a `logprob` can be computed."""
+    from pymc.distributions.distribution import SymbolicRandomVariable
 
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-
-    if rv_map_feature is None:
-        return None  # pragma: no cover
+    if isinstance(node.op, MeasurableOp):
+        return None
 
-    if not rv_map_feature.request_measurable(node.inputs):
+    if not filter_measurable_variables(node.inputs):
         return None
 
     base_var = node.inputs[0]
@@ -222,7 +237,7 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
     # lifted towards the base RandomVariable.
     # TODO: If we include the support axis as meta information in each
     # intermediate MeasurableVariable, we can lift this restriction.
-    if not isinstance(base_var.owner.op, RandomVariable):
+    if not isinstance(base_var.owner.op, RandomVariable | SymbolicRandomVariable):
         return None  # pragma: no cover
 
     measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
@@ -241,7 +256,7 @@
     "find_measurable_dimshuffles", find_measurable_dimshuffles, "basic", "tensor"
 )
 
-measurable_ir_rewrites_db.register(
+early_measurable_ir_rewrites_db.register(
     "find_measurable_stacks",
     find_measurable_stacks,
     "basic",
diff --git a/pymc/logprob/transform_value.py b/pymc/logprob/transform_value.py
index 2523a9b6db..fa013dbf3d 100644
--- a/pymc/logprob/transform_value.py
+++ b/pymc/logprob/transform_value.py
@@ -19,14 +19,13 @@
 from pytensor.graph import Apply, Op
 from pytensor.graph.features import AlreadyThere, Feature
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
-from pytensor.scan.op import Scan
 from pytensor.tensor.variable import TensorVariable
 
-from pymc.logprob.abstract import MeasurableOp, _logprob
-from pymc.logprob.rewriting import PreserveRVMappings, cleanup_ir_rewrites_db
+from pymc.logprob.abstract import MeasurableOp, ValuedRV, _logprob, valued_rv
+from pymc.logprob.rewriting import cleanup_ir_rewrites_db
 from pymc.logprob.transforms import Transform
+from pymc.logprob.utils import get_related_valued_nodes
 
 
 class TransformedValue(Op):
@@ -56,8 +55,6 @@
     This is introduced by the `TransformValuesRewrite`
     """
 
-    view_map = {0: [0]}
-
     __props__ = ("transforms",)
 
     def __init__(self, transforms: Sequence[Transform]):
@@ -130,7 +127,7 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs)
     return logprobs_jac
 
 
-@node_rewriter(tracks=None)
+@node_rewriter(tracks=[ValuedRV])
 def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
     """Apply transforms to value variables.
 
@@ -143,146 +140,52 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
     ``Y`` on the natural scale.
     """
 
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-    values_to_transforms: TransformValuesMapping | None = getattr(
-        fgraph, "values_to_transforms", None
-    )
-
-    if rv_map_feature is None or values_to_transforms is None:
-        return None  # pragma: no cover
-
-    rv_vars = []
-    value_vars = []
-
-    for out in node.outputs:
-        value = rv_map_feature.rv_values.get(out, None)
-        if value is None:
-            continue
-        rv_vars.append(out)
-        value_vars.append(value)
-
-    if not value_vars:
-        return None
-
-    transforms = [values_to_transforms.get(value_var, None) for value_var in value_vars]
-
-    if all(transform is None for transform in transforms):
-        return None
-
-    transformed_rv_op = TransformedValueRV(transforms)
-    # Clone outputs so that rewrite doesn't reference original variables circularly
-    cloned_outputs = node.clone().outputs
-    transformed_rv_node = transformed_rv_op.make_node(*cloned_outputs)
-
-    # We now assume that the old value variable represents the *transformed space*.
-    # This means that we need to replace all instance of the old value variable
-    # with "inversely/un-" transformed versions of itself.
-    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
-        rv_var_out_idx = node.outputs.index(rv_var)
-
-        if transform is None:
-            continue
-
-        new_value_var = transformed_value(
-            transform.backward(value_var, *node.inputs),
-            value_var,
-        )
-
-        if value_var.name and getattr(transform, "name", None):
-            new_value_var.name = f"{value_var.name}_{transform.name}"
-
-        rv_map_feature.update_rv_maps(
-            rv_var, new_value_var, transformed_rv_node.outputs[rv_var_out_idx]
-        )
-
-    return transformed_rv_node.outputs
-
-
-@node_rewriter(tracks=[Scan])
-def transform_scan_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
-    """Apply transforms to Scan value variables.
-
-    This specialized rewrite is needed because Scan replaces the original value variables
-    by a more complex graph. We want to apply the transform to the original value variable
-    in this subgraph, leaving the rest intact
-    """
-
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
     values_to_transforms: TransformValuesMapping | None = getattr(
         fgraph, "values_to_transforms", None
     )
 
-    if rv_map_feature is None or values_to_transforms is None:
-        return None  # pragma: no cover
-
-    rv_vars = []
-    value_vars = []
-
-    for out in node.outputs:
-        value = rv_map_feature.rv_values.get(out, None)
-        if value is None:
-            continue
-        rv_vars.append(out)
-        value_vars.append(value)
-
-    if not value_vars:
+    if values_to_transforms is None:
         return None
 
-    transforms = [
-        values_to_transforms.get(rv_map_feature.original_values[value_var], None)
-        for value_var in value_vars
-    ]
+    rv_node = node.inputs[0].owner
+    valued_nodes = get_related_valued_nodes(rv_node, fgraph)
+    rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
+    values = [valued_var.inputs[1] for valued_var in valued_nodes]
+    transforms = [values_to_transforms.get(value, None) for value in values]
 
     if all(transform is None for transform in transforms):
         return None
 
     transformed_rv_op = TransformedValueRV(transforms)
-    # Clone outputs so that rewrite doesn't reference original variables circularly
-    cloned_outputs = node.clone().outputs
-    transformed_rv_node = transformed_rv_op.make_node(*cloned_outputs)
+    transformed_rv_node = transformed_rv_op.make_node(*rvs)
 
     # We now assume that the old value variable represents the *transformed space*.
     # This means that we need to replace all instances of the old value variable
     # with "inversely/un-" transformed versions of itself.
-    for rv_var, value_var, transform in zip(rv_vars, value_vars, transforms):
-        rv_var_out_idx = node.outputs.index(rv_var)
+    replacements = {}
+    for valued_node, transformed_rv, transform in zip(
+        valued_nodes, transformed_rv_node.outputs, transforms
+    ):
+        rv, value = valued_node.inputs
+        [val_rv] = valued_node.outputs
 
         if transform is None:
-            continue
-
-        # We access the original value variable and apply the transform to that
-        original_value_var = rv_map_feature.original_values[value_var]
-        trans_original_value_var = transform.backward(
-            original_value_var, *transformed_rv_node.inputs
-        )
-
-        # We then replace the reference to the original value variable in the scan value
-        # variable by the back-transform projection computed above
-
-        # The first input corresponds to the original value variable. We are careful to
-        # only clone_replace that part of the graph, as we don't want to break the
-        # mappings between other rvs that are likely to be present in the rest of the
-        # scan value variable graph
-        # TODO: Is it true that the original value only appears in the first input
-        # and that no other RV can appear there?
-        (trans_original_value_var,) = clone_replace(
-            (value_var.owner.inputs[0],),
-            replace={original_value_var: trans_original_value_var},
-        )
-        transformed_value_var = value_var.owner.clone_with_new_inputs(
-            inputs=[trans_original_value_var] + value_var.owner.inputs[1:]
-        ).default_output()
+            transformed_val = value
 
-        new_value_var = transformed_value(transformed_value_var, original_value_var)
+        else:
+            transformed_val = transformed_value(
+                transform.backward(value, *rv.owner.inputs),
+                value,
+            )
 
-        if value_var.name and getattr(transform, "name", None):
-            new_value_var.name = f"{value_var.name}_{transform.name}"
+            value_name = value.name
+            transform_name = getattr(transform, "name", None)
+            if value_name and transform_name:
+                transformed_val.name = f"{value_name}_{transform_name}"
 
-        rv_map_feature.update_rv_maps(
-            rv_var, new_value_var, transformed_rv_node.outputs[rv_var_out_idx]
-        )
+        replacements[val_rv] = valued_rv(transformed_rv, transformed_val)
 
-    return transformed_rv_node.outputs
+    return replacements
 
 
 class TransformValuesMapping(Feature):
@@ -302,7 +205,6 @@
 class TransformValuesRewrite(GraphRewriter):
     r"""Transforms value variables according to a map."""
 
     transform_rewrite = in2out(transform_values, ignore_newtrees=True)
-    scan_transform_rewrite = in2out(transform_scan_values, ignore_newtrees=True)
 
     def __init__(
         self,
@@ -327,7 +229,6 @@ def add_requirements(self, fgraph):
 
     def apply(self, fgraph: FunctionGraph):
         self.transform_rewrite.rewrite(fgraph)
-        self.scan_transform_rewrite.rewrite(fgraph)
 
 
 @node_rewriter([TransformedValue])
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 825103da87..d6dd0894b1 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -116,10 +116,11 @@
     _logprob,
     _logprob_helper,
 )
-from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
 from pymc.logprob.utils import (
     CheckParameterValue,
     check_potential_measurability,
+    filter_measurable_variables,
     find_negated_var,
 )
@@ -269,6 +270,9 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
         # We don't know if this Op is monotonically increasing/decreasing
         raise NotImplementedError
 
+    if is_discrete:
+        return logcdf
+
     # The jacobian is used to ensure a value in the supported domain was provided
     jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
     return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
@@ -311,6 +315,9 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
 @node_rewriter([reciprocal])
 def measurable_reciprocal_to_power(fgraph, node):
     """Convert reciprocal of `MeasurableVariable`s to power."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     [inp] = node.inputs
     return [pt.pow(inp, -1.0)]
 
@@ -318,6 +325,9 @@
 @node_rewriter([sqr, sqrt])
 def measurable_sqrt_sqr_to_power(fgraph, node):
     """Convert square root or square of `MeasurableVariable`s to power form."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     [inp] = node.inputs
 
     if isinstance(node.op.scalar_op, Sqr):
@@ -330,6 +340,9 @@ def measurable_sqrt_sqr_to_power(fgraph, node):
 @node_rewriter([true_div])
 def measurable_div_to_product(fgraph, node):
     """Convert divisions involving `MeasurableVariable`s to products."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     numerator, denominator = node.inputs
 
     # Check if numerator is 1
@@ -348,13 +361,19 @@ def measurable_div_to_product(fgraph, node):
 @node_rewriter([neg])
 def measurable_neg_to_product(fgraph, node):
     """Convert negation of `MeasurableVariable`s to product with `-1`."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     inp = node.inputs[0]
-    return [pt.mul(inp, -1.0)]
+    return [pt.mul(inp, -1)]
 
 
 @node_rewriter([sub])
 def measurable_sub_to_neg(fgraph, node):
     """Convert subtraction involving `MeasurableVariable`s to addition with neg"""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     minuend, subtrahend = node.inputs
     return [pt.add(minuend, pt.neg(subtrahend))]
 
@@ -362,6 +381,9 @@ def measurable_sub_to_neg(fgraph, node):
 @node_rewriter([log1p, softplus, log1mexp, log2, log10])
 def measurable_special_log_to_log(fgraph, node):
     """Convert log1p, log1mexp, softplus, log2, log10 of `MeasurableVariable`s to log form."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     [inp] = node.inputs
 
     if isinstance(node.op.scalar_op, Log1p):
@@ -379,6 +401,9 @@ def measurable_special_log_to_log(fgraph, node):
 @node_rewriter([expm1, sigmoid, exp2])
 def measurable_special_exp_to_exp(fgraph, node):
     """Convert expm1, sigmoid, and exp2 of `MeasurableVariable`s to exp form."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     [inp] = node.inputs
     if isinstance(node.op.scalar_op, Exp2):
         return [pt.exp(pt.log(2) * inp)]
@@ -391,11 +416,14 @@ def measurable_special_exp_to_exp(fgraph, node):
 @node_rewriter([pow])
 def measurable_power_exponent_to_exp(fgraph, node):
     """Convert power(base, rv) of `MeasurableVariable`s to exp(log(base) * rv) form."""
+    if not filter_measurable_variables(node.inputs):
+        return None
+
     base, inp_exponent = node.inputs
 
     # When the base is measurable we have `power(rv, exponent)`, which should be handled by `PowerTransform` and needs no further rewrite.
     # Here we change only the cases where exponent is measurable `power(base, rv)` which is not supported by the `PowerTransform`
-    if check_potential_measurability([base], fgraph.preserve_rv_mappings.rv_values.keys()):
+    if check_potential_measurability([base]):
         return None
 
     base = CheckParameterValue("base >= 0")(base, pt.all(pt.ge(base, 0.0)))
@@ -427,14 +455,10 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
 
     # Node was already converted
     if isinstance(node.op, MeasurableOp):
-        return None  # pragma: no cover
-
-    rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
-    if rv_map_feature is None:
-        return None  # pragma: no cover
+        return None
 
     # Check that we have a single source of measurement
-    measurable_inputs = rv_map_feature.request_measurable(node.inputs)
+    measurable_inputs = filter_measurable_variables(node.inputs)
 
     if len(measurable_inputs) != 1:
         return None
@@ -454,7 +478,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> list[Node]
     # would be invalid
     other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)
 
-    if check_potential_measurability(other_inputs, rv_map_feature.rv_values.keys()):
+    if check_potential_measurability(other_inputs):
         return None
 
     scalar_op = node.op.scalar_op
diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py
index 63fc409052..adc75b556a 100644
--- a/pymc/logprob/utils.py
+++ b/pymc/logprob/utils.py
@@ -36,15 +36,15 @@
 import typing
 import warnings
 
-from collections.abc import Container, Iterable, Sequence
+from collections.abc import Iterable, Sequence
 
 import numpy as np
 import pytensor
 
-from pytensor import Variable
 from pytensor import tensor as pt
 from pytensor.graph import Apply, Op, node_rewriter
-from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk
+from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk
+from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import HasInnerGraph
 from pytensor.link.c.type import CType
 from pytensor.raise_op import CheckAndRaise
@@ -55,7 +55,7 @@
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
 
-from pymc.logprob.abstract import MeasurableOp, _logprob
+from pymc.logprob.abstract import MeasurableOp, ValuedRV, _logprob
 from pymc.pytensorf import replace_vars_in_graphs
 from pymc.util import makeiter
@@ -172,15 +172,17 @@
     )
 
 
-def check_potential_measurability(
-    inputs: Iterable[TensorVariable], valued_rvs: Container[TensorVariable]
-) -> bool:
-    valued_rvs = set(valued_rvs)
+def filter_measurable_variables(inputs):
+    return [
+        inp for inp in inputs if (inp.owner is not None and isinstance(inp.owner.op, MeasurableOp))
+    ]
 
+
+def check_potential_measurability(inputs: Iterable[TensorVariable]) -> bool:
     def expand_fn(var):
         # expand_fn does not go beyond valued_rvs or any MeasurableOp variables
-        if var.owner and not isinstance(var.owner.op, MeasurableOp) and var not in valued_rvs:
-            return reversed(var.owner.inputs)
+        if var.owner and not isinstance(var.owner.op, MeasurableOp | ValuedRV):
+            return var.owner.inputs
         else:
             return []
 
@@ -190,7 +192,7 @@
         if (
             ancestor_var.owner
             and isinstance(ancestor_var.owner.op, MeasurableOp)
-            and ancestor_var not in valued_rvs
+            and not isinstance(ancestor_var.owner.op, ValuedRV)
         )
     ):
         return True
@@ -301,10 +303,8 @@
 def find_negated_var(var):
     """Return a variable that is being multiplied by -1 or None otherwise."""
-    if (
-        not (var.owner)
-        and isinstance(var.owner.op, Elemwise)
-        and isinstance(var.owner.op.scalar_op, Mul)
+    if not (
+        var.owner and isinstance(var.owner.op, Elemwise) and isinstance(var.owner.op.scalar_op, Mul)
     ):
         return None
     if len(var.owner.inputs) != 2:
@@ -319,3 +319,20 @@ def find_negated_var(var):
         continue
 
     return None
+
+
+def get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]:
+    """Get all `ValuedRV` nodes that are clients of the given RV node.
+
+    Returns
+    -------
+    list of Apply
+        The `ValuedRV` client nodes of the RV node's outputs.
+    """
+    clients = fgraph.clients
+    return [
+        client
+        for out in node.outputs
+        for client, _ in clients[out]
+        if isinstance(client.op, ValuedRV)
+    ]
diff --git a/tests/logprob/test_basic.py b/tests/logprob/test_basic.py
index a0413be660..cfbd70b504 100644
--- a/tests/logprob/test_basic.py
+++ b/tests/logprob/test_basic.py
@@ -44,6 +44,7 @@
 from pytensor.graph.basic import ancestors, equal_computations
 from pytensor.tensor.random.op import RandomVariable
+from scipy import stats
 
 import pymc as pm
@@ -413,3 +414,25 @@
         dist_icdf.eval(),
         sp.geom.ppf(value, p),
     )
+
+
+def test_ir_rewrite_does_not_disconnect_valued_rvs():
+    """Check that we don't lose the dependency across RV values due to automatic rewrites.
+
+    See ValuedRV docstrings for more context.
+
+    Regression test for https://github.com/pymc-devs/pymc/issues/6917
+    """
+    a_base = pm.Normal.dist()
+    a = a_base * 5
+    b = pm.Normal.dist(a * 8)
+
+    a_value = a.type()
+    b_value = b.type()
+    logp_b = conditional_logp({a: a_value, b: b_value})[b_value]
+
+    assert_no_rvs(logp_b)
+    np.testing.assert_allclose(
+        logp_b.eval({a_value: np.pi, b_value: np.e}),
+        stats.norm.logpdf(np.e, np.pi * 8, 1),
+    )
diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py
index de407fd579..ccbbb38bc2 100644
--- a/tests/logprob/test_censoring.py
+++ b/tests/logprob/test_censoring.py
@@ -48,7 +48,6 @@
 from pymc.testing import assert_no_rvs
 
 
-@pytensor.config.change_flags(compute_test_value="raise")
 def test_continuous_rv_clip():
     x_rv = pt.random.normal(0.5, 1)
     cens_x_rv = pt.clip(x_rv, -2, 2)
@@ -195,7 +194,7 @@
     cens_vv1 = cens_rv1.clone()
     cens_vv2 = cens_rv2.clone()
 
-    with pytest.raises(RuntimeError, match="could not be derived: {cens2}"):
+    with pytest.raises(ValueError, match="too many values to unpack"):
         conditional_logp({cens_rv1: cens_vv1, cens_rv2: cens_vv2})
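A sketch of the new `get_related_valued_nodes` helper in isolation (hypothetical standalone usage; inside the rewrites, the `FunctionGraph` is the IR graph built by `construct_ir_fgraph`):

```python
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph

from pymc.logprob.abstract import valued_rv
from pymc.logprob.utils import get_related_valued_nodes

x = pt.random.normal()
x_valued = valued_rv(x, x.type())
fgraph = FunctionGraph(outputs=[x_valued], clone=False)

# The helper walks the clients of each output of the RV node and keeps
# the ValuedRV applications; here there is exactly one.
[valued_node] = get_related_valued_nodes(x.owner, fgraph)
assert valued_node.outputs[0] is x_valued
```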
diff --git a/tests/logprob/test_composite_logprob.py b/tests/logprob/test_composite_logprob.py
index 3653830ef9..b249a167fe 100644
--- a/tests/logprob/test_composite_logprob.py
+++ b/tests/logprob/test_composite_logprob.py
@@ -120,6 +120,7 @@
     assert np.isclose(logp_fn(0, 0, 1, 50), st.norm.logpdf(150) + np.log(0.5) * 3)
 
 
+@pytest.mark.xfail(reason="This is not currently enforced")
 @pytest.mark.parametrize("nested", (False, True))
 def test_unvalued_ir_reversion(nested):
     """Make sure that un-valued IR rewrites are reverted."""
@@ -134,7 +135,7 @@
     # measurable IR.
     rv_values = {z_rv: z_vv}
-    z_fgraph, _, memo = construct_ir_fgraph(rv_values)
+    z_fgraph = construct_ir_fgraph(rv_values)
 
     # assert len(z_fgraph.preserve_rv_mappings.measurable_conversions) == 1
     assert (
diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py
index 1d09e844fd..61a78bf4db 100644
--- a/tests/logprob/test_mixture.py
+++ b/tests/logprob/test_mixture.py
@@ -108,40 +108,6 @@ def create_mix_model(size, axis):
     conditional_logp({M_rv: m_vv, I_rv: i_vv})
 
 
-@pytensor.config.change_flags(compute_test_value="warn")
-@pytest.mark.parametrize(
-    "op_constructor",
-    [
-        lambda _I, _X, _Y: pt.stack([_X, _Y])[_I],
-        lambda _I, _X, _Y: pt.switch(_I, _X, _Y),
-    ],
-)
-def test_compute_test_value(op_constructor):
-    X_rv = pt.random.normal(0, 1, name="X")
-    Y_rv = pt.random.gamma(0.5, scale=2.0, name="Y")
-
-    p_at = pt.scalar("p")
-    p_at.tag.test_value = 0.3
-
-    I_rv = pt.random.bernoulli(p_at, name="I")
-
-    i_vv = I_rv.clone()
-    i_vv.name = "i"
-
-    M_rv = op_constructor(I_rv, X_rv, Y_rv)
-    M_rv.name = "M"
-
-    m_vv = M_rv.clone()
-    m_vv.name = "m"
-
-    del M_rv.tag.test_value
-
-    M_logp = conditional_logp({M_rv: m_vv, I_rv: i_vv})
-    M_logp_combined = pt.add(*M_logp.values())
-
-    assert isinstance(M_logp_combined.tag.test_value, np.ndarray)
-
-
 @pytest.mark.parametrize(
     "p_val, size, supported",
     [
@@ -920,8 +886,8 @@ def test_scalar_switch_mixture():
     z_vv = Z1_rv.clone()
     z_vv.name = "z1"
 
-    fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
-    assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)
+    fgraph = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
+    assert isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableSwitchMixture)
 
     # building the identical graph but with a stack to check that mixture logps are identical
     Z2_rv = pt.stack((Y_rv, X_rv))[I_rv]
@@ -992,17 +958,17 @@ def test_switch_mixture_invalid_bcast():
     invalid_false_branch = pt.abs(pt.random.normal(size=()))
 
     valid_mix = pt.switch(valid_switch_cond, valid_true_branch, valid_false_branch)
-    fgraph, _, _ = construct_ir_fgraph({valid_mix: valid_mix.type()})
-    assert isinstance(fgraph.outputs[0].owner.op, MeasurableOp)
-    assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchMixture)
+    fgraph = construct_ir_fgraph({valid_mix: valid_mix.type()})
+    assert isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableOp)
+    assert isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableSwitchMixture)
 
     invalid_mix = pt.switch(invalid_switch_cond, valid_true_branch, valid_false_branch)
-    fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
-    assert not isinstance(fgraph.outputs[0].owner.op, MeasurableOp)
+    fgraph = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
+    assert not isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableOp)
 
     invalid_mix = pt.switch(valid_switch_cond, valid_true_branch, invalid_false_branch)
-    fgraph, _, _ = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
-    assert not isinstance(fgraph.outputs[0].owner.op, MeasurableOp)
+    fgraph = construct_ir_fgraph({invalid_mix: invalid_mix.type()})
+    assert not isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableOp)
 
 
 def test_ifelse_mixture_one_component():
@@ -1036,7 +1002,8 @@ def test_ifelse_mixture_multiple_components():
     if_var = pt.scalar("if_var", dtype="bool")
 
     comp_then1 = pt.random.normal(size=(2,), name="comp_true1")
-    comp_then2 = comp_then1 + pt.random.normal(size=(2, 2), name="comp_then2")
+    comp_then2 = comp_then1 + pt.random.normal(size=(2, 2))
+    comp_then2.name = "comp_then2"
     comp_else1 = pt.random.halfnormal(size=(4,), name="comp_else1")
     comp_else2 = pt.random.halfnormal(size=(4, 4), name="comp_else2")
diff --git a/tests/logprob/test_scan.py b/tests/logprob/test_scan.py
index a9e64e459c..09c7e2d952 100644
--- a/tests/logprob/test_scan.py
+++ b/tests/logprob/test_scan.py
@@ -505,7 +505,6 @@ def ref_logp(values, rho, sigma):
     )
 
 
-@pytest.mark.xfail(reason="Not implemented yet")
 def test_scan_multiple_output_types():
     """Test we can derive the logp for a scan that contains recurring and non-recurring measurable outputs."""
     [xs, ys, zs], _ = pytensor.scan(
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index 0c4879c002..17fe096e92 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -451,7 +451,7 @@ def test_sqrt_transform(self):
         # ICDF is not implemented for chisquare, so we have to test with another identity
         # sqrt(exponential(lam)) = rayleigh(1 / sqrt(2 * lam))
         lam = 2.5
-        y_rv = pt.sqrt(pt.random.exponential(scale=1 / lam))
+        y_rv = pt.sqrt(pt.random.exponential(scale=1 / lam, size=(4,)))
         y_vv = x_rv.clone()
         y_icdf_fn = pytensor.function([y_vv], icdf(y_rv, y_vv))
         q_test_val = np.r_[0.2, 0.5, 0.7, 0.9]
diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py
index d337e0317e..10cd36fc39 100644
--- a/tests/logprob/test_utils.py
+++ b/tests/logprob/test_utils.py
@@ -49,7 +49,7 @@
 from pymc import SymbolicRandomVariable, inputvars
 from pymc.distributions.transforms import Interval
-from pymc.logprob.abstract import MeasurableOp
+from pymc.logprob.abstract import MeasurableOp, valued_rv
 from pymc.logprob.basic import logp
 from pymc.logprob.utils import (
     ParameterValueError,
@@ -307,13 +307,23 @@ def scipy_logprob(obs, c):
 def test_check_potential_measurability():
     x1 = pt.random.normal()
+    x1_valued = valued_rv(x1, x1.type())
+
     x2 = pt.random.normal()
+    x2_valued = valued_rv(x2, x2.type())
+
     x3 = pt.scalar("x3")
-    y = pt.exp(x1 + x2 + x3)
 
     # In the first three cases, y is potentially measurable, because it has at least one unvalued RV input
-    assert check_potential_measurability([y], {})
-    assert check_potential_measurability([y], {x1})
-    assert check_potential_measurability([y], {x2})
+    y = pt.exp(x1 + x2 + x3)
+    assert check_potential_measurability([y])
+
+    y = pt.exp(x1_valued + x2 + x3)
+    assert check_potential_measurability([y])
+
+    y = pt.exp(x1 + x2_valued + x3)
+    assert check_potential_measurability([y])
+
     # y is not potentially measurable because both RV inputs are valued
-    assert not check_potential_measurability([y], {x1, x2})
+    y = pt.exp(x1_valued + x2_valued + x3)
+    assert not check_potential_measurability([y])
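With the `xfail` removed from `test_scan_multiple_output_types` above, deriving the logp of a scan that mixes recurring and non-recurring measurable outputs is expected to work. A hedged sketch of that kind of graph (variable names hypothetical, not taken from the test itself):

```python
import pytensor
import pytensor.tensor as pt

from pymc.logprob.basic import conditional_logp


def step(x_tm1):
    x = pt.random.normal(x_tm1)  # recurring (sit-sot) output
    y = pt.random.normal(x)  # non-recurring (nit-sot) output
    return x, y


[xs, ys], _ = pytensor.scan(
    fn=step,
    outputs_info=[pt.zeros(()), None],
    n_steps=3,
)

# `xs` is a sliced (Subtensor) output, `ys` a direct one; the rewrite above
# should now handle conditioning on both at once.
xs_vv = xs.clone()
ys_vv = ys.clone()
logp_terms = conditional_logp({xs: xs_vv, ys: ys_vv})
xs_logp, ys_logp = logp_terms[xs_vv], logp_terms[ys_vv]
```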