Skip to content

Commit

Permalink
Reset scale at end of iteration
Browse files Browse the repository at this point in the history
Edit docstring
  • Loading branch information
nspope committed Jan 10, 2024
1 parent a388028 commit 8c2e529
Showing 1 changed file with 13 additions and 17 deletions.
30 changes: 13 additions & 17 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,10 +1001,6 @@ def __init__(self, *args, global_prior, **kwargs):
self.posterior[p] += self.likelihoods[i]
# self.log_partition[i] += ... # TODO

# scaling factor for posterior: posterior is the sum of messages and
# prior, multiplied by a scaling term in (0, 1]
self.scale = np.ones(self.ts.num_nodes)

@staticmethod
def factorize(edge_list, fixed_nodes):
"""Split edges into internal and external"""
Expand All @@ -1023,13 +1019,12 @@ def factorize(edge_list, fixed_nodes):
return internal, external

@staticmethod
@numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
@numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8, b1)")
def propagate_likelihood(
edges,
likelihoods,
posterior,
messages,
scale,
log_partition,
max_shape,
min_kl,
Expand All @@ -1048,8 +1043,6 @@ def propagate_likelihood(
:param ndarray messages: array of dimension `[num_edges, 2, 2]`
containing parent/child messages (natural parameters) for each edge,
updated in-place.
:param ndarray scale: array of dimension `[num_nodes]`
containing the scaling factor for the posterior, updated in place
:param ndarray log_partition: array of dimension `[num_edges]`
containing the approximate normalizing constants per edge,
updated in-place.
Expand All @@ -1059,7 +1052,7 @@ def propagate_likelihood(

# Bound the shape parameter for the posterior and cavity distributions
# so that lower_cavi < lower_post < upper_post < upper_cavi.
upper_post = max_shape - 1.0
upper_post = 1.0 * max_shape - 1.0
lower_post = 1.0 / max_shape - 1.0
upper_cavi = 2.0 * max_shape - 1.0
lower_cavi = 0.5 / max_shape - 1.0
Expand All @@ -1086,6 +1079,7 @@ def posterior_damping(x):
assert 0.0 < d <= 1.0
return d

scale = np.ones(posterior.shape[0])
for i, p, c in edges:
# Damped downdate to ensure proper cavity distributions
parent_message = messages[i, 0] * scale[p]
Expand Down Expand Up @@ -1119,12 +1113,17 @@ def posterior_damping(x):
scale[p] *= parent_eta
scale[c] *= child_eta

# move the scaling term into the messages
for i, p, c in edges:
messages[i, 0] *= scale[p]
messages[i, 1] *= scale[c]

return 0.0 # TODO, placeholder

@staticmethod
@numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)")
@numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8, i4, f8)")
def propagate_prior(
nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol
nodes, global_prior, posterior, messages, max_shape, em_maxitt, em_reltol
):
"""
Update approximating factors for global prior at each node.
Expand Down Expand Up @@ -1164,16 +1163,16 @@ def posterior_damping(x):
return d

cavity = np.zeros(posterior.shape)
cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis]
cavity[nodes] = posterior[nodes] - messages[nodes]
global_prior, posterior[nodes] = mixture.fit_gamma_mixture(
global_prior, cavity[nodes], em_maxitt, em_reltol, False
)
messages[nodes] = (posterior[nodes] - cavity[nodes]) / scale[nodes, np.newaxis]
messages[nodes] = posterior[nodes] - cavity[nodes]

for n in nodes:
eta = posterior_damping(posterior[n])
posterior[n] *= eta
scale[n] *= eta
messages[n] *= eta

return 0.0

Expand All @@ -1189,7 +1188,6 @@ def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
self.global_prior,
self.posterior,
self.prior_messages,
self.scale,
max_shape,
em_maxitt,
em_reltol,
Expand All @@ -1201,7 +1199,6 @@ def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
self.likelihoods,
self.posterior,
self.messages,
self.scale,
self.log_partition,
max_shape,
min_kl,
Expand All @@ -1213,7 +1210,6 @@ def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
self.likelihoods,
self.posterior,
self.messages,
self.scale,
self.log_partition,
max_shape,
min_kl,
Expand Down

0 comments on commit 8c2e529

Please sign in to comment.