From 9053937e38efe6037f551060dd99ff54f0288ce5 Mon Sep 17 00:00:00 2001 From: shreyas3156 Date: Wed, 9 Aug 2023 04:35:39 +0530 Subject: [PATCH] add measurable branches and props to init --- pymc/logprob/censoring.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index da36086cdb..8037371ffe 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -248,11 +248,12 @@ class MeasurableSwitchEncoding(MeasurableElemwise): """A placeholder used to specify the log-likelihood for a encoded RV sub-graph.""" valid_scalar_types = (Switch,) - # number of measurable branches to facilitate correct logprob calculation - measurable_branches = 0 - -measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch) + def __init__(self, measurable_branches): + super().__init__(scalar_switch) + self.__props__ = super().__props__ + ("measurable_branches",) + self.measurable_branches = measurable_branches + # number of measurable branches to facilitate correct logprob calculation @node_rewriter(tracks=[switch]) @@ -286,6 +287,9 @@ def find_measurable_switch_encoding( if base_var.dtype.startswith("int"): return None + # default number of measurable branches is zero + measurable_switch_encoding = MeasurableSwitchEncoding(measurable_branches=0) + # Maximum one branch allowed to be measurable if len(measurable_comp_list) > 1: return None @@ -339,7 +343,7 @@ def switch_encoding_logprob(op, values, *inputs, **kwargs): ), ) else: - base_var = components[1] # there needs to be a better way to obtain the base variable. + base_var = components[1] logp_first_branch = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs)