Skip to content

Commit

Permalink
Merge pull request #300 from hyanwong/tqdm-auto
Browse files Browse the repository at this point in the history
Make nicer progress bars for iteration
  • Loading branch information
hyanwong authored Jul 15, 2023
2 parents a83df44 + 56ed7a1 commit 7282a30
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
9 changes: 6 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,12 @@ def test_iterative_progress(self, tmp_path, capfd):
(out, err) = capfd.readouterr()
assert out == ""
# run_tsdate_cli print logging to stderr
assert err.count("Expectation Propagation: 100%") == 2
assert err.count("EP (iter 2, rootwards): 100%") == 1
assert err.count("rootwards): 100%") == err.count("leafwards): 100%")
assert err.count("Expectation Propagation: 100%") == 1
# The capfd fixture doesn't end up capturing progress bars with
# leave=False (they get deleted) so we can't see these in the output
# assert err.count("Iteration 1: 100%") == 1
# assert err.count("Rootwards: 100%") > 1
# assert err.count("Rootwards: 100%") == err.count("Leafwards: 100%")


class TestEndToEnd(RunCLI):
Expand Down
31 changes: 18 additions & 13 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,9 @@ def propagate(self, *, edges, desc=None, progress=None):
if progress is None:
progress = self.progress
# TODO: this will still converge if parallelized (potentially slower)
for edge in tqdm(edges, desc, total=self.ts.num_edges, disable=not progress):
for edge in tqdm(
edges, desc, total=self.ts.num_edges, disable=not progress, leave=False
):
if edge.child in self.fixednodes:
continue
if edge.parent in self.fixednodes:
Expand Down Expand Up @@ -1039,17 +1041,17 @@ def iterate(self, *, iter_num=None, progress=None):
Update edge factors from leaves to root then from root to leaves,
and return approximate log marginal likelihood
"""
desc = "Expectation Propagation"
if iter_num: # Show iteration number if not first iteration
desc = f"EP (iter {iter_num + 1:>2}, rootwards)"
self.propagate(
edges=self.edges_by_parent_asc(grouped=False), desc=desc, progress=progress
)
if iter_num:
desc = f"EP (iter {iter_num + 1:>2}, leafwards)"
self.propagate(
edges=self.edges_by_child_desc(grouped=False), desc=desc, progress=progress
)
if progress is None:
progress = self.progress
it = iter_num + 1 # For display purposes: show 1-based iteration
tasks = {
"Rootwards": self.edges_by_parent_asc,
"Leafwards": self.edges_by_child_desc,
}
for desc, func in tqdm(
tasks.items(), f"Iteration {it}", disable=not progress, leave=False
):
self.propagate(edges=func(grouped=False), desc=desc, progress=progress)
# TODO
# marginal_lik = np.sum(self.factor_norm)
# return marginal_lik
Expand Down Expand Up @@ -1538,7 +1540,10 @@ def variational_dates(
)

dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
for it in range(max_iterations):
for it in tqdm(
np.arange(max_iterations),
desc="Expectation Propagation",
):
dynamic_prog.iterate(iter_num=it)
posterior = dynamic_prog.posterior
tree_sequence, mn_post, _ = variational_mean_var(
Expand Down

0 comments on commit 7282a30

Please sign in to comment.