Skip to content

Commit

Permalink
add measurable branches and props to init
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Aug 8, 2023
1 parent d4be8b5 commit 9053937
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,12 @@ class MeasurableSwitchEncoding(MeasurableElemwise):
"""A placeholder used to specify the log-likelihood for a encoded RV sub-graph."""

valid_scalar_types = (Switch,)
# number of measurable branches to facilitate correct logprob calculation
measurable_branches = 0


measurable_switch_encoding = MeasurableSwitchEncoding(scalar_switch)
def __init__(self, measurable_branches):
super().__init__(scalar_switch)
self.__props__ = super().__props__ + ("measurable_branches",)
self.measurable_branches = measurable_branches

Check warning on line 255 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L253-L255

Added lines #L253 - L255 were not covered by tests
# number of measurable branches to facilitate correct logprob calculation


@node_rewriter(tracks=[switch])
Expand Down Expand Up @@ -286,6 +287,9 @@ def find_measurable_switch_encoding(
if base_var.dtype.startswith("int"):
return None

Check warning on line 288 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L286-L288

Added lines #L286 - L288 were not covered by tests

# default number of measurable branches is zero
measurable_switch_encoding = MeasurableSwitchEncoding(measurable_branches=0)

Check warning on line 291 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L291

Added line #L291 was not covered by tests

# Maximum one branch allowed to be measurable
if len(measurable_comp_list) > 1:
return None

Check warning on line 295 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L294-L295

Added lines #L294 - L295 were not covered by tests
Expand Down Expand Up @@ -339,7 +343,7 @@ def switch_encoding_logprob(op, values, *inputs, **kwargs):
),
)
else:
base_var = components[1] # there needs to be a better way to obtain the base variable.
base_var = components[1]

Check warning on line 346 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L346

Added line #L346 was not covered by tests

logp_first_branch = _logprob_helper(switch_condn, pt.as_tensor(np.array(True)), **kwargs)

Check warning on line 348 in pymc/logprob/censoring.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/censoring.py#L348

Added line #L348 was not covered by tests

Expand Down

0 comments on commit 9053937

Please sign in to comment.