Skip to content

Commit

Permalink
Introduce valued variables in logprob IR
Browse files Browse the repository at this point in the history
This avoids rewrites across conditioning points, which could break dependencies

Also extend logprob derivation of scans with multiple valued output types
  • Loading branch information
ricardoV94 committed Sep 6, 2024
1 parent 600a1f3 commit 36312a2
Show file tree
Hide file tree
Showing 21 changed files with 583 additions and 799 deletions.
45 changes: 44 additions & 1 deletion pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from collections.abc import Sequence
from functools import singledispatch

from pytensor.graph.op import Op
from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise
Expand Down Expand Up @@ -165,3 +165,46 @@ def __init__(self, scalar_op, *args, **kwargs):

def __str__(self):
return f"Measurable{super().__str__()}"


class ValuedRV(Op):
    r"""Represents the association of a measurable variable and its value.

    A `ValuedRV` node represents the pair :math:`(Y, y)`, where
    :math:`Y` is a random variable and :math:`y \sim Y`.

    Log-probability (densities) are functions over these pairs, which makes
    these nodes in a graph an intermediate form that serves to construct a
    log-probability from a model graph.
    """

    def make_node(self, rv, value):
        """Pair the random variable `rv` with its value variable `value`.

        The single output mirrors `rv`'s type and name, so downstream
        rewrites can use the valued node as a stand-in for `rv` itself.
        """
        assert isinstance(rv, Variable)
        assert isinstance(value, Variable)
        return Apply(self, [rv, value], [rv.type(name=rv.name)])

    def perform(self, node, inputs, out):
        # This Op is an IR-only marker: it must be removed before the graph
        # is compiled, so evaluating it is always an error.
        raise NotImplementedError("ValuedRV should not be present in the final graph!")

    def infer_shape(self, fgraph, node, input_shapes):
        # The output has the same shape as the wrapped random variable.
        return [input_shapes[0]]


valued_rv = ValuedRV()


class PromisedValuedRV(Op):
    r"""IR marker for a variable that is promised a value variable during logprob derivation."""

    def make_node(self, rv):
        assert isinstance(rv, Variable)
        # Output mirrors the marked variable's type and name.
        out = rv.type(name=rv.name)
        return Apply(self, [rv], [out])

    def perform(self, node, inputs, out):
        # IR-only marker: evaluating it means cleanup never removed the node.
        raise NotImplementedError("PromisedValuedRV should not be present in the final graph!")

    def infer_shape(self, fgraph, node, input_shapes):
        # Shape passes through unchanged from the single input.
        return input_shapes[:1]


promised_valued_rv = PromisedValuedRV()
115 changes: 51 additions & 64 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,17 @@

import warnings

from collections import deque
from collections.abc import Sequence
from typing import TypeAlias

import numpy as np
import pytensor.tensor as pt

from pytensor import config
from pytensor.graph.basic import (
Constant,
Variable,
ancestors,
graph_inputs,
io_toposort,
)
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
from pytensor.tensor.variable import TensorVariable

Expand All @@ -65,7 +60,7 @@
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
from pymc.logprob.transform_value import TransformValuesRewrite
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph
from pymc.pytensorf import replace_vars_in_graphs

TensorLike: TypeAlias = Variable | float | np.ndarray
Expand Down Expand Up @@ -211,7 +206,8 @@ def normal_logp(value, mu, sigma):
return _logprob_helper(rv, value, **kwargs)
except NotImplementedError:
fgraph, _, _ = construct_ir_fgraph({rv: value})
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
[ir_valued_var] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_var.owner.inputs
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
Expand Down Expand Up @@ -309,8 +305,9 @@ def normal_logcdf(value, mu, sigma):
except NotImplementedError:
# Try to rewrite rv
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _logcdf_helper(ir_rv, value, **kwargs)
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _logcdf_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph(expr)
Expand Down Expand Up @@ -391,8 +388,9 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
except NotImplementedError:
# Try to rewrite rv
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
[ir_rv] = fgraph.outputs
expr = _icdf_helper(ir_rv, value, **kwargs)
[ir_valued_rv] = fgraph.outputs
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
expr = _icdf_helper(ir_rv, ir_value, **kwargs)
cleanup_ir([expr])
if warn_rvs:
_warn_rvs_in_inferred_graph(expr)
Expand Down Expand Up @@ -481,27 +479,9 @@ def conditional_logp(
if extra_rewrites is not None:
extra_rewrites.rewrite(fgraph)

rv_remapper = fgraph.preserve_rv_mappings

# This is the updated random-to-value-vars map with the lifted/rewritten
# variables. The rewrites are supposed to produce new
# `MeasurableOp`s whose variables are amenable to `_logprob`.
updated_rv_values = rv_remapper.rv_values

# Some rewrites also transform the original value variables. This is the
# updated map from the new value variables to the original ones, which
# we want to use as the keys in the final dictionary output
original_values = rv_remapper.original_values

# When a `_logprob` has been produced for a `MeasurableOp` node, all
# other references to it need to be replaced with its value-variable all
# throughout the `_logprob`-produced graphs. The following `dict`
# cumulatively maintains remappings for all the variables/nodes that needed
# to be recreated after replacing `MeasurableOp` variables with their
# value-variables. Since these replacements work in topological order, all
# the necessary value-variable replacements should be present for each
# node.
replacements = updated_rv_values.copy()
# Walk the graph from its inputs to its outputs and construct the
# log-probability
replacements = {}

# To avoid cloning the value variables (or ancestors of value variables),
# we map them to themselves in the `replacements` `dict`
Expand All @@ -516,71 +496,78 @@ def conditional_logp(

# Walk the graph from its inputs to its outputs and construct the
# log-probability
q = deque(fgraph.toposort())
logprob_vars = {}

while q:
node = q.popleft()
values_to_logprobs = {}
original_values = tuple(rv_values.values())

# TODO: This seems too convoluted, can we just replace all RVs by their values,
# except for the fgraph outputs (for which we want to call _logprob on)?
for node in fgraph.toposort():
if not isinstance(node.op, MeasurableOp):
continue

q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
valued_nodes = get_related_valued_nodes(node, fgraph)

if not q_values:
if not valued_nodes:
continue

node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
node_values = [valued_var.inputs[1] for valued_var in valued_nodes]
node_output_idxs = [
fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes
]

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map for the nodes that follow.
for node_rv, node_value in zip(node_rvs, node_values):
replacements[node_rv] = node_value

remapped_vars = replace_vars_in_graphs(
graphs=q_values + list(node.inputs),
graphs=node_values + list(node.inputs),
replacements=replacements,
)
q_values = remapped_vars[: len(q_values)]
q_rv_inputs = remapped_vars[len(q_values) :]
node_values = remapped_vars[: len(node_values)]
node_inputs = remapped_vars[len(node_values) :]

q_logprob_vars = _logprob(
node_logprobs = _logprob(
node.op,
q_values,
*q_rv_inputs,
node_values,
*node_inputs,
**kwargs,
)

if not isinstance(q_logprob_vars, list | tuple):
q_logprob_vars = [q_logprob_vars]
if not isinstance(node_logprobs, list | tuple):
node_logprobs = [node_logprobs]

for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars):
q_value_var = original_values[q_value_var]
for node_output_idx, node_value, node_logprob in zip(
node_output_idxs, node_values, node_logprobs
):
original_value = original_values[node_output_idx]

if q_value_var.name:
q_logprob_var.name = f"{q_value_var.name}_logprob"
if original_value.name:
node_logprob.name = f"{original_value.name}_logprob"

if q_value_var in logprob_vars:
if original_value in values_to_logprobs:
raise ValueError(
f"More than one logprob term was assigned to the value var {q_value_var}"
f"More than one logprob term was assigned to the value var {original_value}"
)

logprob_vars[q_value_var] = q_logprob_var

# Recompute test values for the changes introduced by the replacements above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
compute_test_value(node)
values_to_logprobs[original_value] = node_logprob

missing_value_terms = set(original_values.values()) - set(logprob_vars.keys())
missing_value_terms = set(original_values) - set(values_to_logprobs)
if missing_value_terms:
raise RuntimeError(
f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
)

logprob_expressions = list(logprob_vars.values())
cleanup_ir(logprob_expressions)
logprobs = list(values_to_logprobs.values())
cleanup_ir(logprobs)

if warn_rvs:
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions)
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
if rvs_in_logp_expressions:
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)

return logprob_vars
return values_to_logprobs


def transformed_conditional_logp(
Expand Down
18 changes: 5 additions & 13 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables


class MeasurableComparison(MeasurableElemwise):
Expand All @@ -40,11 +40,7 @@ class MeasurableComparison(MeasurableElemwise):

@node_rewriter(tracks=[gt, lt, ge, le])
def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

measurable_inputs = rv_map_feature.request_measurable(node.inputs)
measurable_inputs = filter_measurable_variables(node.inputs)

if len(measurable_inputs) != 1:
return None
Expand All @@ -62,7 +58,7 @@ def find_measurable_comparisons(fgraph: FunctionGraph, node: Node) -> list[Tenso
const = node.inputs[(measurable_var_idx + 1) % 2]

# check for potential measurability of const
if check_potential_measurability([const], rv_map_feature.rv_values.keys()):
if check_potential_measurability([const]):
return None

node_scalar_op = node.op.scalar_op
Expand Down Expand Up @@ -132,16 +128,12 @@ class MeasurableBitwise(MeasurableElemwise):

@node_rewriter(tracks=[invert])
def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

base_var = node.inputs[0]

if not base_var.dtype.startswith("bool"):
raise None

if not rv_map_feature.request_measurable([base_var]):
if not filter_measurable_variables([base_var]):
return None

node_scalar_op = node.op.scalar_op
Expand Down
16 changes: 4 additions & 12 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables


class MeasurableClip(MeasurableElemwise):
Expand All @@ -65,11 +65,7 @@ class MeasurableClip(MeasurableElemwise):
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if not rv_map_feature.request_measurable(node.inputs):
if not filter_measurable_variables(node.inputs):
return None

base_var, lower_bound, upper_bound = node.inputs
Expand Down Expand Up @@ -158,11 +154,7 @@ class MeasurableRound(MeasurableElemwise):

@node_rewriter(tracks=[ceil, floor, round_half_to_even])
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> list[TensorVariable] | None:
rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if not rv_map_feature.request_measurable(node.inputs):
if not filter_measurable_variables(node.inputs):
return None

[base_var] = node.inputs
Expand Down
Loading

0 comments on commit 36312a2

Please sign in to comment.