diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py
index 123c3394b9..1ac2bc953e 100644
--- a/pymc/logprob/censoring.py
+++ b/pymc/logprob/censoring.py
@@ -34,22 +34,25 @@
 #   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 #   SOFTWARE.
 
-from typing import List, Optional
+from typing import Callable, Container, Generator, Iterable, List, Optional, Set, Tuple
 
 import numpy as np
 import pytensor.tensor as pt
 
-from pytensor.graph.basic import Node
+from pytensor.graph.basic import Node, walk
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
+from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven, Switch
 from pytensor.scalar.basic import clip as scalar_clip
+from pytensor.scalar.basic import switch as scalar_switch
+from pytensor.tensor.basic import switch as switch
 from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
-from pytensor.tensor.var import TensorConstant
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.var import TensorConstant, TensorVariable
 
-from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
+from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob, _logprob_helper
 from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
-from pymc.logprob.utils import CheckParameterValue
+from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
 
 
 class MeasurableClip(MeasurableElemwise):
@@ -237,3 +240,183 @@ def round_logprob(op, values, base_rv, **kwargs):
     from pymc.math import logdiffexp
 
     return logdiffexp(logcdf_upper, logcdf_lower)
+
+
+class MeasurableSwitchEncoding(MeasurableElemwise):
+    """A placeholder used to specify the log-likelihood for an encoded RV sub-graph."""
+
+    valid_scalar_types = (Switch,)
+
+
+measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch)
+
+
+@node_rewriter(tracks=[switch])
+def find_measurable_switch_encoding(
+    fgraph: FunctionGraph, node: Node
+) -> Optional[List[MeasurableSwitchEncoding]]:
+    """Rewrite a ``switch`` node into a ``MeasurableSwitchEncoding`` node.
+
+    The rewrite is skipped (``None`` is returned) when the condition or the
+    measurable branch has a broadcastable dimension, when the condition cannot
+    be made measurable, when the condition and the measurable branch do not
+    share a single source of measurability, or when more than one branch is
+    measurable.  If the measurable branch comes first, the condition is inverted
+    and the branches are swapped so that the first branch is always the encoding.
+    """
+    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
+
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    valued_rvs = rv_map_feature.rv_values.keys()
+
+    switch_condn, *components = node.inputs
+
+    # broadcasting of switch condition is not supported
+    if switch_condn.ndim != 0 and any(switch_condn.type.broadcastable):
+        return None
+
+    # this automatically checks the measurability of the switch condition and
+    # converts switch to MeasurableSwitch
+    if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]:
+        return None
+
+    measurable_comp_idx = next(
+        (
+            idx
+            for idx, component in enumerate(components)
+            if check_potential_measurability([component], valued_rvs)
+        ),
+        -1,
+    )
+
+    # If at least one of the branches is measurable
+    if measurable_comp_idx != -1:
+        measurable_component = components[measurable_comp_idx]
+
+        # broadcasting of the measurable component is not supported
+        if measurable_component.ndim != 0 and any(measurable_component.type.broadcastable):
+            return None
+
+        if not compare_measurability_source((switch_condn, measurable_component), valued_rvs):
+            return None
+
+        measurable_inputs = rv_map_feature.request_measurable(components)
+        # Maximum one branch allowed to be measurable
+        if len(measurable_inputs) > 1:
+            return None
+
+        if measurable_comp_idx == 0:
+            # changing the first branch of switch to always be the encoding
+            encoded_rv = measurable_switch_encoding.make_node(
+                pt.invert(switch_condn), *components[::-1]
+            ).default_output()
+            # FIXME: For graphs like y = pt.switch(x > 0.5, x, 0.3), they should be rewritten
+            # to pt.switch(x <= 0.5, 0.3, x).
+            # But the invert Op does not get converted to its Measurable counterpart.
+
+            return [encoded_rv]
+
+    encoded_rv = measurable_switch_encoding.make_node(switch_condn, *components).default_output()
+
+    return [encoded_rv]
+
+
+@_logprob.register(MeasurableSwitchEncoding)
+def switch_encoding_logprob(op, values, *inputs, **kwargs):
+    """Return the log-probability of a switch whose two branches are encodings.
+
+    The value is compared with each branch: where it equals the first branch the
+    log-probability of the condition being ``True`` is used, where it equals the
+    second branch that of the condition being ``False``, and ``-inf`` elsewhere.
+    """
+    (value,) = values
+
+    switch_condn, *components = inputs
+
+    # Forward ``kwargs`` to both helper calls so the two branches are consistent.
+    true_logprob = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs)
+    false_logprob = _logprob_helper(switch_condn, pt.as_tensor(np.array(False)), **kwargs)
+
+    # Right now, this only works for switch with both encoding branches.
+    logprob = pt.switch(
+        pt.eq(value, components[0]),
+        true_logprob,
+        pt.switch(pt.eq(value, components[1]), false_logprob, -np.inf),
+    )
+
+    # TODO: Calculate logprob for switch with one measurable component. If RV is discrete,
+    # give preference over encoding.
+
+    return logprob
+
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_switch_encoding", find_measurable_switch_encoding, "basic", "censoring"
+)
+
+
+def compare_measurability_source(
+    inputs: Tuple[TensorVariable, ...], valued_rvs: Container[TensorVariable]
+) -> bool:
+    """Check that ``inputs`` share exactly one unvalued random-variable ancestor.
+
+    The ancestors of each input are walked without going past random variables
+    or ``valued_rvs``; every ancestor that is the output of a ``RandomVariable``
+    and is not in ``valued_rvs`` is collected.  Returns ``True`` only when the
+    collected set holds exactly one variable.
+    """
+    # The stop set is the same for every input, so build it only once.
+    stop_at_vars = set(valued_rvs)
+    ancestor_var_set = set()
+
+    # retrieve the source of measurability for all elements in 'inputs' separately.
+    for inp in inputs:
+        for ancestor_var in walk_model([inp], walk_past_rvs=False, stop_at_vars=stop_at_vars):
+            if (
+                ancestor_var.owner
+                and isinstance(ancestor_var.owner.op, RandomVariable)
+                and ancestor_var not in valued_rvs
+            ):
+                ancestor_var_set.add(ancestor_var)
+
+    return len(ancestor_var_set) == 1
+
+
+def walk_model(
+    graphs: Iterable[TensorVariable],
+    walk_past_rvs: bool = False,
+    stop_at_vars: Optional[Set[TensorVariable]] = None,
+    expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [],
+) -> Generator[TensorVariable, None, None]:
+    """Walk ``graphs`` depth-first and yield the variables that are visited.
+
+    Parameters
+    ----------
+    graphs
+        The variables from which the walk starts.
+    walk_past_rvs
+        If ``True``, the inputs of ``RandomVariable`` nodes are visited as well.
+    stop_at_vars
+        Variables whose inputs are not visited.
+    expand_fn
+        Called on every visited variable; the variables it returns are visited
+        in addition to the inputs of the variable's owner.
+    """
+    if stop_at_vars is None:
+        stop_at_vars = set()
+
+    def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]:
+        new_vars = expand_fn(var)
+
+        if (
+            var.owner
+            and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
+            and (var not in stop_at_vars)
+        ):
+            new_vars.extend(reversed(var.owner.inputs))
+
+        return new_vars
+
+    yield from walk(graphs, expand, bfs=False)