Skip to content

Commit

Permalink
Faster shape parameter scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
nspope committed Dec 3, 2023
1 parent d186ff9 commit a8924a5
Showing 1 changed file with 41 additions and 15 deletions.
56 changes: 41 additions & 15 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,8 @@ def __init__(self, *args, **kwargs):
# store factorization into messages: the edge ids pointing
# towards roots/leaves for each node
self.parent_factors, self.child_factors = self.factorize()
self.edge_delta = np.ones(self.ts.num_edges)
self.scale = np.ones(self.ts.num_nodes)

def factorize(self):
"""
Expand Down Expand Up @@ -1017,13 +1019,38 @@ def damp(posterior, message, eps=1e-3):
assert 0.0 < delta <= 1.0
return delta

def scale_posterior(self, n, max_shape):
assert self.posterior[n][0] > -1.0 and self.posterior[n][1] > 0.0
eta = min(1.0, (max_shape - 1.0) / abs(self.posterior[n, 0]))
# DEBUG
# foo = self.priors[n].copy()
# damped = False
# for j in self.parent_factors[n]:
# foo += self.parent_message[j] * self.scale[n]
# if self.edge_delta[j] != 1.0:
# damped = True
# for j in self.child_factors[n]:
# foo += self.child_message[j] * self.scale[n]
# if self.edge_delta[j] != 1.0:
# damped = True
# print("SCALING", n, damped, self.posterior[n], foo)
# END DEBUG
self.posterior[n] *= eta
self.scale[n] *= eta
# if eta < 1.0:
# for j in self.parent_factors[n]:
# self.parent_message[j] *= eta
# for j in self.child_factors[n]:
# self.child_message[j] *= eta

def propagate(
self, *, edges, desc=None, progress=None, max_shape=1000, use_laplace=True
):
"""
Update approximating factor for each edge
"""
assert max_shape >= 1.0
eps = 1.0 / max_shape
if progress is None:
progress = self.progress
for edge in tqdm(
Expand All @@ -1035,17 +1062,20 @@ def propagate(
if p in self.fixednodes:
raise ValueError("Internal nodes can not be fixed in EP algorithm")
# Damped downdate to ensure proper cavity distributions
parent_delta = self.damp(
self.posterior[p], self.parent_message[i], 1 / max_shape
)
child_delta = self.damp(
self.posterior[c], self.child_message[i], 1 / max_shape
)
parent_message = self.parent_message[i] * self.scale[p]
child_message = self.child_message[i] * self.scale[c]
parent_delta = self.damp(self.posterior[p], parent_message, eps)
child_delta = self.damp(self.posterior[c], child_message, eps)
delta = min(parent_delta, child_delta)
self.edge_delta[i] = delta
# The cavity posteriors: the approximation omitting the variational
# factor for this edge.
parent_cavity = self.posterior[p] - delta * self.parent_message[i]
child_cavity = self.posterior[c] - delta * self.child_message[i]
parent_cavity = (
self.posterior[p] - delta * parent_message
) # self.parent_message[i]
child_cavity = (
self.posterior[c] - delta * child_message
) # self.child_message[i]
# The edge likelihood, scaled by the damping factor
edge_likelihood = delta * self.likelihoods[i]
# The target posterior: the cavity multiplied by the edge
Expand All @@ -1055,8 +1085,8 @@ def propagate(
)
# The messages: the difference in natural parameters between the
# target and cavity posteriors.
self.parent_message[i] += parent_post - self.posterior[p]
self.child_message[i] += child_post - self.posterior[c]
self.parent_message[i] += (parent_post - self.posterior[p]) / self.scale[p]
self.child_message[i] += (child_post - self.posterior[c]) / self.scale[c]
# Contribution to the marginal likelihood from the edge
self.log_partition[i] = logconst # TODO: incomplete
# Constrain the messages so that the gamma shape parameter for each
Expand All @@ -1067,11 +1097,7 @@ def propagate(
assert self.posterior[n][0] > -1.0 and self.posterior[n][1] > 0.0
eta = min(1.0, (max_shape - 1.0) / abs(self.posterior[n, 0]))
self.posterior[n] *= eta
if eta < 1.0:
for j in self.parent_factors[n]:
self.parent_message[j] *= eta
for j in self.child_factors[n]:
self.child_message[j] *= eta
self.scale[n] *= eta

def iterate(
self, *, iter_num=None, progress=None, max_shape=1000, use_laplace=True
Expand Down

0 comments on commit a8924a5

Please sign in to comment.