Skip to content

Commit

Permalink
Add a temporary debugging mode
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Dec 6, 2023
1 parent b49b960 commit e7a1fdc
Showing 1 changed file with 37 additions and 4 deletions.
41 changes: 37 additions & 4 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,9 +1011,19 @@ def factorize(edge_list, fixed_nodes):
return internal, external

@staticmethod
@numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
@numba.njit(
"f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1, b1)"
)
def propagate(
edges, likelihoods, posterior, messages, scale, log_partition, max_shape, min_kl
edges,
likelihoods,
posterior,
messages,
scale,
log_partition,
max_shape,
min_kl,
debug,
):
"""
Update approximating factors for each edge, returning average relative
Expand Down Expand Up @@ -1056,41 +1066,61 @@ def posterior_damping(x):
return d

for i, p, c in edges:
if debug:
print("---\nedge:", i, "parent:", p, "child:", c)
# Damped downdate to ensure proper cavity distributions
parent_message = messages[i, 0] * scale[p]
child_message = messages[i, 1] * scale[c]
if debug:
print("p-mess:", parent_message, "c-mess:", child_message)
parent_delta = cavity_damping(posterior[p], parent_message)
child_delta = cavity_damping(posterior[c], child_message)
delta = min(parent_delta, child_delta)
if debug:
print("delta:", delta)
# The cavity posteriors: the approximation omitting the variational
# factors for this edge.
parent_cavity = posterior[p] - delta * parent_message
child_cavity = posterior[c] - delta * child_message
if debug:
print("p-cavi:", parent_cavity, "c-cavi:", child_cavity)
# The edge likelihood, scaled by the damping factor
edge_likelihood = delta * likelihoods[i]
if debug:
print("e-like:", edge_likelihood)
# The target posterior: the cavity multiplied by the edge
# likelihood then projected onto a gamma via moment matching.
logconst, parent_post, child_post = approx.gamma_projection(
parent_cavity, child_cavity, edge_likelihood, min_kl
)
if debug:
print("logconst:", logconst)
if debug:
print("p-post:", parent_post, "c-post:", child_post)
# The messages: the difference in natural parameters between the
# target and cavity posteriors.
messages[i, 0] += (parent_post - posterior[p]) / scale[p]
messages[i, 1] += (child_post - posterior[c]) / scale[c]
if debug:
print("p-updt:", messages[i, 0], "c-updt:", messages[i, 1])
# Contribution to the marginal likelihood from the edge
log_partition[i] = logconst # TODO: incomplete
# Constrain the messages so that the gamma shape parameter for each
# posterior is bounded (e.g. set a maximum precision for log(age)).
parent_eta = posterior_damping(parent_post)
child_eta = posterior_damping(child_post)
if debug:
print("p-scal:", parent_eta, "c-scal:", child_eta)
posterior[p] = parent_eta * parent_post
posterior[c] = child_eta * child_post
scale[p] *= parent_eta
scale[c] *= child_eta
if debug:
print("p-end:", posterior[p], "c-end:", posterior[c])

return 0.0 # TODO, placeholder

def iterate(self, max_shape=1000, min_kl=True):
def iterate(self, max_shape=1000, min_kl=True, debug=False):
"""
Update edge factors from leaves to root then from root to leaves,
and return approximate log marginal likelihood (TODO)
Expand All @@ -1105,6 +1135,7 @@ def iterate(self, max_shape=1000, min_kl=True):
self.log_partition,
max_shape,
min_kl,
debug,
)
self.propagate(
self.edges[::-1],
Expand All @@ -1115,6 +1146,7 @@ def iterate(self, max_shape=1000, min_kl=True):
self.log_partition,
max_shape,
min_kl,
debug,
)

# TODO
Expand Down Expand Up @@ -1567,6 +1599,7 @@ def variational_dates(
num_threads=None, # Unused, matches get_dates()
probability_space=None, # Can only be None, simply to match get_dates()
ignore_oldest_root=False, # Can only be False, simply to match get_dates()
debug=False, # Print a ton of extra information
):
"""
Infer dates for the nodes in a tree sequence using expectation propagation,
Expand Down Expand Up @@ -1662,7 +1695,7 @@ def variational_dates(
desc="Expectation Propagation",
disable=not progress,
):
dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl)
dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl, debug=debug)

posterior = priors.clone_with_new_data(
grid_data=dynamic_prog.posterior[priors.nonfixed_nodes, :]
Expand Down

0 comments on commit e7a1fdc

Please sign in to comment.