diff --git a/tsdate/core.py b/tsdate/core.py
index fbe5a736..ee381c37 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -997,10 +997,6 @@ def __init__(self, *args, global_prior, **kwargs):
             self.posterior[p] += self.likelihoods[i]
             # self.log_partition[i] += ...  # TODO
 
-        # scaling factor for posterior: posterior is the sum of messages and
-        # prior, multiplied by a scaling term in (0, 1]
-        self.scale = np.ones(self.ts.num_nodes)
-
     @staticmethod
     def factorize(edge_list, fixed_nodes):
         """Split edges into internal and external"""
@@ -1019,13 +1015,12 @@ def factorize(edge_list, fixed_nodes):
         return internal, external
 
     @staticmethod
-    @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
+    @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8, b1)")
     def propagate_likelihood(
         edges,
         likelihoods,
         posterior,
         messages,
-        scale,
         log_partition,
         max_shape,
         min_kl,
@@ -1044,8 +1039,6 @@ def propagate_likelihood(
         :param ndarray messages: array of dimension `[num_edges, 2, 2]`
             containing parent/child messages (natural parameters) for each
             edge, updated in-place.
-        :param ndarray scale: array of dimension `[num_nodes]`
-            containing the scaling factor for the posterior, updated in place
         :param ndarray log_partition: array of dimension `[num_edges]`
             containing the approximate normalizing constants per edge,
             updated in-place.
@@ -1055,7 +1048,7 @@
 
         # Bound the shape parameter for the posterior and cavity distributions
         # so that lower_cavi < lower_post < upper_post < upper_cavi.
-        upper_post = max_shape - 1.0
+        upper_post = 1.0 * max_shape - 1.0
         lower_post = 1.0 / max_shape - 1.0
         upper_cavi = 2.0 * max_shape - 1.0
         lower_cavi = 0.5 / max_shape - 1.0
@@ -1082,6 +1075,7 @@ def posterior_damping(x):
             assert 0.0 < d <= 1.0
             return d
 
+        scale = np.ones(posterior.shape[0])
         for i, p, c in edges:
             # Damped downdate to ensure proper cavity distributions
             parent_message = messages[i, 0] * scale[p]
@@ -1115,12 +1109,17 @@ def posterior_damping(x):
             scale[p] *= parent_eta
             scale[c] *= child_eta
 
+        # move the scaling term into the messages
+        for i, p, c in edges:
+            messages[i, 0] *= scale[p]
+            messages[i, 1] *= scale[c]
+
         return 0.0  # TODO, placeholder
 
     @staticmethod
-    @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)")
+    @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8, i4, f8)")
     def propagate_prior(
-        nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol
+        nodes, global_prior, posterior, messages, max_shape, em_maxitt, em_reltol
     ):
         """
         Update approximating factors for global prior at each node.
@@ -1160,16 +1159,16 @@ def posterior_damping(x):
             return d
 
         cavity = np.zeros(posterior.shape)
-        cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis]
+        cavity[nodes] = posterior[nodes] - messages[nodes]
         global_prior, posterior[nodes] = mixture.fit_gamma_mixture(
             global_prior, cavity[nodes], em_maxitt, em_reltol, False
         )
-        messages[nodes] = (posterior[nodes] - cavity[nodes]) / scale[nodes, np.newaxis]
+        messages[nodes] = posterior[nodes] - cavity[nodes]
 
         for n in nodes:
             eta = posterior_damping(posterior[n])
             posterior[n] *= eta
-            scale[n] *= eta
+            messages[n] *= eta
 
         return 0.0
 
@@ -1185,7 +1184,6 @@ def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
             self.global_prior,
             self.posterior,
             self.prior_messages,
-            self.scale,
             max_shape,
             em_maxitt,
             em_reltol,
@@ -1197,7 +1195,6 @@ def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
             self.likelihoods,
             self.posterior,
             self.messages,
-            self.scale,
             self.log_partition,
             max_shape,
             min_kl,
@@ -1209,7 +1206,6 @@ def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
             self.likelihoods,
             self.posterior,
             self.messages,
-            self.scale,
             self.log_partition,
             max_shape,
             min_kl,
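For context: the net effect of this patch is to demote `scale` from persistent per-node state (threaded through every numba kernel) to a local variable inside `propagate_likelihood`, whose accumulated value is folded into the stored messages before the kernel returns. Downstream consumers such as `propagate_prior` can then use the messages directly. The following is a minimal NumPy sketch (toy shapes and stand-in values, not the module's actual data layout) of why the cavity computation is unchanged by this refactor:

```python
import numpy as np

# Toy per-node state; `posterior` and `messages` hold natural
# parameters per node, `scale` holds per-node factors in (0, 1]
# accumulated by the damping steps.  Values are arbitrary stand-ins.
rng = np.random.default_rng(0)
num_nodes = 4
posterior = rng.random((num_nodes, 2))
messages = rng.random((num_nodes, 2))
scale = np.array([1.0, 0.5, 0.25, 1.0])
nodes = np.arange(num_nodes)

# Old convention: stored messages must be rescaled at every use.
cavity_old = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis]

# New convention: fold the scale into the stored messages once, after
# propagation; downstream code then subtracts the messages directly
# and the `scale` argument disappears from the kernel signatures.
messages[nodes] *= scale[nodes, np.newaxis]
cavity_new = posterior[nodes] - messages[nodes]

assert np.allclose(cavity_old, cavity_new)
```

Since the fold happens once per sweep, the fixed point of the EP updates is unaffected; the patch only changes where the bookkeeping lives.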