From 41cd3aba400aa11fd86f55d9c134ab35139e3506 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 19 Jul 2023 08:33:48 +0530 Subject: [PATCH 01/14] Add logprob derivation for switch encoding graphs --- pymc/logprob/censoring.py | 161 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 6 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 123c3394b9..1ac2bc953e 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -34,22 +34,25 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import List, Optional +from typing import Callable, Container, Generator, Iterable, List, Optional, Set, Tuple import numpy as np import pytensor.tensor as pt -from pytensor.graph.basic import Node +from pytensor.graph.basic import Node, walk from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven +from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven, Switch from pytensor.scalar.basic import clip as scalar_clip +from pytensor.scalar.basic import switch as scalar_switch +from pytensor.tensor.basic import switch as switch from pytensor.tensor.math import ceil, clip, floor, round_half_to_even -from pytensor.tensor.var import TensorConstant +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.var import TensorConstant, TensorVariable -from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob +from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob, _logprob_helper from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import CheckParameterValue +from pymc.logprob.utils import CheckParameterValue, check_potential_measurability class MeasurableClip(MeasurableElemwise): @@ -237,3 +240,149 @@ def round_logprob(op, values, base_rv, **kwargs): from pymc.math import logdiffexp return logdiffexp(logcdf_upper, logcdf_lower) + + +class MeasurableSwitchEncoding(MeasurableElemwise): + """A placeholder used to specify the log-likelihood for a encoded RV sub-graph.""" + + valid_scalar_types = (Switch,) + + +measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch) + + +@node_rewriter(tracks=[switch]) +def find_measurable_switch_encoding( + fgraph: FunctionGraph, node: Node +) -> Optional[List[MeasurableSwitchEncoding]]: + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + valued_rvs = rv_map_feature.rv_values.keys() + + switch_condn, *components = node.inputs + + # broadcasting of switch condition is not supported + if switch_condn.ndim != 0: + if any(switch_condn.type.broadcastable): + return None + + if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]: + return None + # this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch + + measurable_comp_idx = next( + ( + idx + for idx, component in enumerate(components) + if check_potential_measurability([component], valued_rvs) + ), + -1, + ) + + # If at least one of the branches is measurable + if measurable_comp_idx != -1: + measurable_component = components[measurable_comp_idx] + + # broadcasting of the measurable component is not supported + if measurable_component.ndim != 0 and any(measurable_component.type.broadcastable): + return None + + if not compare_measurability_source([switch_condn, measurable_component], valued_rvs): + return None + + measurable_inputs = rv_map_feature.request_measurable(components) + # Maximum one branch allowed to be measurable + if len(measurable_inputs) > 1: + return None + + if measurable_comp_idx == 0: + # changing the first branch of switch to always be the encoding + encoded_rv = measurable_switch_encoding.make_node( + pt.invert(switch_condn), *components[::-1] + ).default_output() + # FIXME: For graphs like y = pt.switch(x > 0.5, x, 0.3), they should be rewritten + # to pt.switch(x <= 0.5, 0.3, x). + # But the invert Op does not get converted to its Measurable counterpart. + + return [encoded_rv] + + encoded_rv = measurable_switch_encoding.make_node(switch_condn, *components).default_output() + + return [encoded_rv] + + +@_logprob.register(MeasurableSwitchEncoding) +def switch_encoding_logprob(op, values, *inputs, **kwargs): + (value,) = values + + switch_condn, *components = inputs + + # Right now, this only works for switch with both encoding branches. + logprob = pt.switch( + pt.eq(value, components[0]), + _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs), + pt.switch( + pt.eq(value, components[1]), + _logprob_helper(switch_condn, pt.as_tensor(np.array(False))), + -np.inf, + ), + ) + + # TODO: Calculate logprob for switch with one measurable component If RV is discrete, + # give preference over encoding. + + return logprob + + +measurable_ir_rewrites_db.register( + "find_measurable_switch_encoding", find_measurable_switch_encoding, "basic", "censoring" +) + + +def compare_measurability_source( + inputs: Tuple[TensorVariable], valued_rvs: Container[TensorVariable] +) -> bool: + ancestor_var_set = set() + + # retrieve the source of measurability for all elements in 'inputs' separately. + for inp in inputs: + for ancestor_var in walk_model( + [inp], + walk_past_rvs=False, + stop_at_vars=set(valued_rvs), + ): + if ( + ancestor_var.owner + and isinstance(ancestor_var.owner.op, RandomVariable) + and ancestor_var not in valued_rvs + ): + ancestor_var_set.add(ancestor_var) + + return len(ancestor_var_set) == 1 + + +def walk_model( + graphs: Iterable[TensorVariable], + walk_past_rvs: bool = False, + stop_at_vars: Optional[Set[TensorVariable]] = None, + expand_fn: Callable[[TensorVariable], List[TensorVariable]] = lambda var: [], +) -> Generator[TensorVariable, None, None]: + if stop_at_vars is None: + stop_at_vars = set() + + def expand(var: TensorVariable, stop_at_vars=stop_at_vars) -> List[TensorVariable]: + new_vars = expand_fn(var) + + if ( + var.owner + and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable)) + and (var not in stop_at_vars) + ): + new_vars.extend(reversed(var.owner.inputs)) + + return new_vars + + yield from walk(graphs, expand, False) From 74f42c399f68494aaadfba759a9e9ae57aee01b4 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 19 Jul 2023 08:34:30 +0530 Subject: [PATCH 02/14] Tests --- tests/logprob/test_censoring.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 46c0a69d3f..eddcb47ec1 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -261,3 +261,36 @@ def test_rounding(rounding_op): logprob.eval({xr_vv: test_value}), expected_logp, ) + + +def test_switch_encoding_both_branches(): + x_rv = pt.random.normal(0.5, 1) + y_rv = pt.switch(x_rv < 0.3, 1, 2) + + y_vv = y_rv.clone() + ref_scipy = st.norm(0.5, 1) + + logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) + + assert logp_fn(3) == -np.inf + + assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3)) + assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3)) + + +@pytest.mark.skip(reason="Logprob calculation for measurable branches not added") +def test_switch_encoding_second_branch_measurable(): + x_rv = pt.random.normal(0.5, 1) + y_rv = pt.switch(x_rv < 0.3, 1, x_rv) + + y_vv = y_rv.clone() + ref_scipy = st.norm(0.5, 1) + + logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) + + assert logp_fn(3) == -np.inf + + assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3)) + assert np.isclose(logp_fn(0.2), -np.inf) From 4134b0db768beddabd108e7a9c91622f3b4908a5 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Thu, 20 Jul 2023 15:09:24 +0530 Subject: [PATCH 03/14] Rectify broadcasting check and identification of measurable component --- pymc/logprob/censoring.py | 41 ++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 1ac2bc953e..6aa1f2f1b9 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -265,37 +265,34 @@ def find_measurable_switch_encoding( switch_condn, *components = node.inputs # broadcasting of switch condition is not supported - if switch_condn.ndim != 0: - if any(switch_condn.type.broadcastable): - return None + if switch_condn.type.broadcastable != node.outputs[0].type.broadcastable: + return None if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]: return None - # this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch + # this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch - measurable_comp_idx = next( - ( - idx - for idx, component in enumerate(components) - if check_potential_measurability([component], valued_rvs) - ), - -1, - ) + measurable_comp_list = [ + idx + for idx, component in enumerate(components) + if check_potential_measurability([component], valued_rvs) + ] + + # Maximum one branch allowed to be measurable + if len(measurable_comp_list) > 1: + return None # If at least one of the branches is measurable - if measurable_comp_idx != -1: + if len(measurable_comp_list) == 1: + measurable_comp_idx = measurable_comp_list[0] measurable_component = components[measurable_comp_idx] # broadcasting of the measurable component is not supported - if measurable_component.ndim != 0 and any(measurable_component.type.broadcastable): - return None - - if not compare_measurability_source([switch_condn, measurable_component], valued_rvs): - return None - - measurable_inputs = rv_map_feature.request_measurable(components) - # Maximum one branch allowed to be measurable - if len(measurable_inputs) > 1: + if ( + (measurable_component.type.broadcastable != node.outputs[0].broadcastable) + or (not compare_measurability_source([switch_condn, measurable_component], valued_rvs)) + or (not rv_map_feature.request_measurable([measurable_component])) + ): return None if measurable_comp_idx == 0: From fd7d4b936c5b1710aa91f0156f0c41ef1d38ed5b Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Fri, 21 Jul 2023 16:09:41 +0530 Subject: [PATCH 04/14] Support inverting the order of switch branches --- pymc/logprob/censoring.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 6aa1f2f1b9..9be81c44e9 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -51,6 +51,7 @@ from pytensor.tensor.var import TensorConstant, TensorVariable from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob, _logprob_helper +from pymc.logprob.binary import MeasurableBitwise from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import CheckParameterValue, check_potential_measurability @@ -297,12 +298,13 @@ def find_measurable_switch_encoding( if measurable_comp_idx == 0: # changing the first branch of switch to always be the encoding + inverted_switch = pt.invert(switch_condn) + + bitwise_op = MeasurableBitwise(inverted_switch.owner.op.scalar_op) + measurable_inverted_switch = bitwise_op.make_node(switch_condn).default_output() encoded_rv = measurable_switch_encoding.make_node( - pt.invert(switch_condn), *components[::-1] + measurable_inverted_switch, *components[::-1] ).default_output() - # FIXME: For graphs like y = pt.switch(x > 0.5, x, 0.3), they should be rewritten - # to pt.switch(x <= 0.5, 0.3, x). - # But the invert Op does not get converted to its Measurable counterpart. return [encoded_rv] From 6fd4167d873b89c14a503056c91eeaad1531fb8f Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Tue, 1 Aug 2023 16:20:53 +0530 Subject: [PATCH 05/14] request_measurable on switch after checking potential measurability --- pymc/logprob/censoring.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 9be81c44e9..e0d1a61520 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -269,16 +269,16 @@ def find_measurable_switch_encoding( if switch_condn.type.broadcastable != node.outputs[0].type.broadcastable: return None - if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]: - return None - # this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch - measurable_comp_list = [ idx for idx, component in enumerate(components) if check_potential_measurability([component], valued_rvs) ] + # this automatically checks the measurability of the switch condition and converts switch to MeasurableSwitch + if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]: + return None + # Maximum one branch allowed to be measurable if len(measurable_comp_list) > 1: return None From f8a1c770eff75c5a0b64121cbbc37a55fb0987e8 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 2 Aug 2023 17:46:19 +0530 Subject: [PATCH 06/14] not allowing discrete RVs --- pymc/logprob/censoring.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index e0d1a61520..f9728356a7 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -279,6 +279,10 @@ def find_measurable_switch_encoding( if rv_map_feature.request_measurable([switch_condn]) != [switch_condn]: return None + [base_var] = rv_map_feature.request_measurable([switch_condn.owner.inputs[0]]) + if base_var.dtype.startswith("int"): + return None + # Maximum one branch allowed to be measurable if len(measurable_comp_list) > 1: return None From 73d79794ad26d8c06fd5524e1e2bc0a095b7b164 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Thu, 3 Aug 2023 07:38:41 +0530 Subject: [PATCH 07/14] Add logp derivation for switch with one measurable branch --- pymc/logprob/censoring.py | 43 +++++++++++++++++++++++---------- tests/logprob/test_censoring.py | 9 +++---- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index f9728356a7..da36086cdb 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -54,6 +54,7 @@ from pymc.logprob.binary import MeasurableBitwise from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db from pymc.logprob.utils import CheckParameterValue, check_potential_measurability +from pymc.pytensorf import replace_rvs_by_values class MeasurableClip(MeasurableElemwise): @@ -247,6 +248,8 @@ class MeasurableSwitchEncoding(MeasurableElemwise): """A placeholder used to specify the log-likelihood for a encoded RV sub-graph.""" valid_scalar_types = (Switch,) + # number of measurable branches to facilitate correct logprob calculation + measurable_branches = 0 measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch) @@ -292,6 +295,8 @@ def find_measurable_switch_encoding( measurable_comp_idx = measurable_comp_list[0] measurable_component = components[measurable_comp_idx] + measurable_switch_encoding.measurable_branches = 1 + # broadcasting of the measurable component is not supported if ( (measurable_component.type.broadcastable != node.outputs[0].broadcastable) @@ -323,19 +328,31 @@ def switch_encoding_logprob(op, values, *inputs, **kwargs): switch_condn, *components = inputs - # Right now, this only works for switch with both encoding branches. - logprob = pt.switch( - pt.eq(value, components[0]), - _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs), - pt.switch( - pt.eq(value, components[1]), - _logprob_helper(switch_condn, pt.as_tensor(np.array(False))), - -np.inf, - ), - ) - - # TODO: Calculate logprob for switch with one measurable component If RV is discrete, - # give preference over encoding. + if op.measurable_branches == 0: + logprob = pt.switch( + pt.eq(value, components[0]), + _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs), + pt.switch( + pt.eq(value, components[1]), + _logprob_helper(switch_condn, pt.as_tensor(np.array(False))), + -np.inf, + ), + ) + else: + base_var = components[1] # there needs to be a better way to obtain the base variable. + + logp_first_branch = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs) + + (switch_condn,) = replace_rvs_by_values([switch_condn], rvs_to_values={base_var: value}) + logprob = pt.switch( + pt.eq(value, components[0]), + logp_first_branch, + pt.switch( + pt.invert(switch_condn), + _logprob_helper(base_var, value, **kwargs), + -np.inf, + ), + ) return logprob diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index eddcb47ec1..41620d46d4 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -279,10 +279,9 @@ def test_switch_encoding_both_branches(): assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3)) -@pytest.mark.skip(reason="Logprob calculation for measurable branches not added") def test_switch_encoding_second_branch_measurable(): x_rv = pt.random.normal(0.5, 1) - y_rv = pt.switch(x_rv < 0.3, 1, x_rv) + y_rv = pt.switch(x_rv < 1, 1, x_rv) y_vv = y_rv.clone() ref_scipy = st.norm(0.5, 1) @@ -290,7 +289,7 @@ def test_switch_encoding_second_branch_measurable(): logprob = logp(y_rv, y_vv) logp_fn = pytensor.function([y_vv], logprob) - assert logp_fn(3) == -np.inf + assert logp_fn(0.5) == -np.inf - assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3)) - assert np.isclose(logp_fn(0.2), -np.inf) + assert np.isclose(logp_fn(1), ref_scipy.logcdf(1)) + assert np.isclose(logp_fn(1.2), ref_scipy.logpdf(1.2)) From d4be8b5cde2ff790567f953299be48728ba94edf Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 7 Aug 2023 19:14:53 +0530 Subject: [PATCH 08/14] Tests for measurable branches and deny discrete --- tests/logprob/test_censoring.py | 35 ++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 41620d46d4..fdcfcf0781 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -279,17 +279,38 @@ def test_switch_encoding_both_branches(): assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3)) -def test_switch_encoding_second_branch_measurable(): - x_rv = pt.random.normal(0.5, 1) - y_rv = pt.switch(x_rv < 1, 1, x_rv) +@pytest.mark.parametrize( + "measurable_idx, test_values, exp_logp", + [ + (1, (0.9, 1, 1.5), (-np.inf, st.norm(0.5, 1).logcdf(1), st.norm(0.5, 1).logpdf(1.5))), + (0, (1.5, 1, 0.9), (-np.inf, st.norm(0.5, 1).logsf(1), st.norm(0.5, 1).logpdf(0.9))), + ], +) +def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_logp): + x_rv = pt.random.normal(0.5, 1) # should not be defined again ideally + branches = (1, x_rv) if measurable_idx == 1 else (x_rv, 1) + + y_rv = pt.switch(x_rv < 1, *branches) y_vv = y_rv.clone() - ref_scipy = st.norm(0.5, 1) logprob = logp(y_rv, y_vv) + logp_fn = pytensor.function([y_vv], logprob) - assert logp_fn(0.5) == -np.inf + for i, j in zip(test_values, exp_logp): + assert np.isclose(logp_fn(i), j) - assert np.isclose(logp_fn(1), ref_scipy.logcdf(1)) - assert np.isclose(logp_fn(1.2), ref_scipy.logpdf(1.2)) + +def test_switch_encoding_discrete_fail(): + x_rv = pt.random.poisson(2) + y_rv = pt.switch(x_rv > 3, x_rv, 1) + + y_vv = x_rv.clone() + y_vv_test = 1 + + with pytest.raises( + NotImplementedError, + match="Logprob method not implemented", + ): + logp(y_rv, y_vv).eval({y_vv: y_vv_test}) From 9053937e38efe6037f551060dd99ff54f0288ce5 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 9 Aug 2023 04:35:39 +0530 Subject: [PATCH 09/14] add measurable branches and props to init --- pymc/logprob/censoring.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index da36086cdb..8037371ffe 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -248,11 +248,12 @@ class MeasurableSwitchEncoding(MeasurableElemwise): """A placeholder used to specify the log-likelihood for a encoded RV sub-graph.""" valid_scalar_types = (Switch,) - # number of measurable branches to facilitate correct logprob calculation - measurable_branches = 0 - -measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch) + def __init__(self, measurable_branches): + super().__init__(scalar_switch) + self.__props__ = super().__props__ + ("measurable_branches",) + self.measurable_branches = measurable_branches + # number of measurable branches to facilitate correct logprob calculation @node_rewriter(tracks=[switch]) @@ -286,6 +287,9 @@ def find_measurable_switch_encoding( if base_var.dtype.startswith("int"): return None + # default number of measurable branches is zero + measurable_switch_encoding = MeasurableSwitchEncoding(measurable_branches=0) + # Maximum one branch allowed to be measurable if len(measurable_comp_list) > 1: return None @@ -339,7 +343,7 @@ def switch_encoding_logprob(op, values, *inputs, **kwargs): ), ) else: - base_var = components[1] # there needs to be a better way to obtain the base variable. + base_var = components[1] logp_first_branch = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs) From 67a634cd7634ec28db28285ee8a28db068c2bccc Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 9 Aug 2023 04:40:29 +0530 Subject: [PATCH 10/14] remove comment --- tests/logprob/test_censoring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index fdcfcf0781..74a5e004ce 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -287,7 +287,7 @@ def test_switch_encoding_both_branches(): ], ) def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_logp): - x_rv = pt.random.normal(0.5, 1) # should not be defined again ideally + x_rv = pt.random.normal(0.5, 1) branches = (1, x_rv) if measurable_idx == 1 else (x_rv, 1) y_rv = pt.switch(x_rv < 1, *branches) From e73293f4e05d8d0e60e70d5a342e3080c6e6a41a Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 14 Aug 2023 08:17:38 +0530 Subject: [PATCH 11/14] Add test for broadcastability --- tests/logprob/test_censoring.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 74a5e004ce..d46d711ccd 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -43,6 +43,9 @@ from pymc import logp from pymc.logprob import conditional_logp +from pymc.logprob.abstract import MeasurableVariable +from pymc.logprob.censoring import MeasurableSwitchEncoding +from pymc.logprob.rewriting import construct_ir_fgraph from pymc.logprob.transforms import LogTransform, TransformValuesRewrite from pymc.testing import assert_no_rvs @@ -302,6 +305,27 @@ def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_ assert np.isclose(logp_fn(i), j) +def test_switch_encoding_invalid_bcast(): + x_rv = pt.random.normal(0.5, 1, size=(4,)) + + switch_cond = x_rv < 0.3 + + valid_true_branch = pt.vector("valid_true_branch") + valid_false_branch = pt.vector("valid_false_branch") + + invalid_false_branch = pt.matrix("invalid_false_branch") + + valid_encoding = pt.switch(switch_cond, valid_true_branch, valid_false_branch) + fgraph, _, _ = construct_ir_fgraph({valid_encoding: valid_encoding.type()}) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchEncoding) + + invalid_encoding = pt.switch(switch_cond, valid_true_branch, invalid_false_branch) + fgraph, _, _ = construct_ir_fgraph({invalid_encoding: invalid_encoding.type()}) + assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) + assert not isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchEncoding) + + def test_switch_encoding_discrete_fail(): x_rv = pt.random.poisson(2) y_rv = pt.switch(x_rv > 3, x_rv, 1) From 78e9bc5ee57a1fbdcf972e3a6a670d1b6a6709b6 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 14 Aug 2023 08:20:04 +0530 Subject: [PATCH 12/14] Add docstring in test saying discrete is not supported yet --- tests/logprob/test_censoring.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index d46d711ccd..64cee57a4e 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -327,6 +327,7 @@ def test_switch_encoding_invalid_bcast(): def test_switch_encoding_discrete_fail(): + """We do not support the encoding graphs of discrete RVs yet""" x_rv = pt.random.poisson(2) y_rv = pt.switch(x_rv > 3, x_rv, 1) From 80093e56d5d9131f2a579604099bad35281d9c49 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 14 Aug 2023 08:22:06 +0530 Subject: [PATCH 13/14] Modify encoding from int to float in tests --- tests/logprob/test_censoring.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index 64cee57a4e..ad7f53a318 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -268,7 +268,7 @@ def test_rounding(rounding_op): def test_switch_encoding_both_branches(): x_rv = pt.random.normal(0.5, 1) - y_rv = pt.switch(x_rv < 0.3, 1, 2) + y_rv = pt.switch(x_rv < 0.3, 1.0, 2.0) y_vv = y_rv.clone() ref_scipy = st.norm(0.5, 1) @@ -276,7 +276,7 @@ def test_switch_encoding_both_branches(): logprob = logp(y_rv, y_vv) logp_fn = pytensor.function([y_vv], logprob) - assert logp_fn(3) == -np.inf + assert logp_fn(1.5) == -np.inf assert np.isclose(logp_fn(1), ref_scipy.logcdf(0.3)) assert np.isclose(logp_fn(2), ref_scipy.logsf(0.3)) From dc500611a34cd7b27285e8ea10f2ebad3f2cf8c1 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Mon, 28 Aug 2023 08:37:16 -0400 Subject: [PATCH 14/14] test invalid broadcast of switch --- tests/logprob/test_censoring.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/logprob/test_censoring.py b/tests/logprob/test_censoring.py index ad7f53a318..e53e7e85b0 100644 --- a/tests/logprob/test_censoring.py +++ b/tests/logprob/test_censoring.py @@ -306,24 +306,23 @@ def test_switch_encoding_one_branch_measurable(measurable_idx, test_values, exp_ def test_switch_encoding_invalid_bcast(): - x_rv = pt.random.normal(0.5, 1, size=(4,)) + x_rv = pt.random.normal(0.5, 1) - switch_cond = x_rv < 0.3 + y_rv = pt.switch(x_rv < 0.3, 0.0, 1.0) + y_rv_invalid = pt.switch(x_rv < 0.3, [0.0, 0.5], 1.0) - valid_true_branch = pt.vector("valid_true_branch") - valid_false_branch = pt.vector("valid_false_branch") + y_vv = y_rv.clone() + y_vv_invalid = y_rv_invalid.clone() - invalid_false_branch = pt.matrix("invalid_false_branch") + y_test = 1.0 - valid_encoding = pt.switch(switch_cond, valid_true_branch, valid_false_branch) - fgraph, _, _ = construct_ir_fgraph({valid_encoding: valid_encoding.type()}) - assert isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) - assert isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchEncoding) + assert np.isclose(logp(y_rv, y_vv).eval({y_vv: y_test}), st.norm(0.5, 1).logsf(0.3)) - invalid_encoding = pt.switch(switch_cond, valid_true_branch, invalid_false_branch) - fgraph, _, _ = construct_ir_fgraph({invalid_encoding: invalid_encoding.type()}) - assert not isinstance(fgraph.outputs[0].owner.op, MeasurableVariable) - assert not isinstance(fgraph.outputs[0].owner.op, MeasurableSwitchEncoding) + with pytest.raises( + NotImplementedError, + match="Logprob method not implemented", + ): + logp(y_rv_invalid, y_vv_invalid).eval({y_vv_invalid: y_test}) def test_switch_encoding_discrete_fail():