diff --git a/src/lineagetree/_core/utils.py b/src/lineagetree/_core/utils.py index cb7510c..26c9db0 100644 --- a/src/lineagetree/_core/utils.py +++ b/src/lineagetree/_core/utils.py @@ -79,28 +79,32 @@ def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[i leaves : list of int List of leaf node ids. depths : dict mapping int to int - Dictionary mapping node ids to their depth in the tree. + Dictionary mapping all node ids to their depth in the tree. """ leaves = [] depths = {} + + times = lnks_tms["times"] + links = lnks_tms["links"] - # Stack for DFS: (node, current_depth, parent_depth) + # Stack for DFS: (node, parent_depth) stack = [(root, 0)] while stack: - node, parent_depth = stack.pop() + parent_node, parent_depth = stack.pop() + depths[parent_node] = parent_depth + succ = links.get(parent_node, []) - node_depth = parent_depth + lnks_tms["times"].get(node, 0) - depths[node] = parent_depth - - succ = lnks_tms["links"].get(node, []) - if not succ: # This is a leaf - leaves.append(node) + leaves.append(parent_node) else: + if len(succ) == 1: # in this case, times[parent_node] is equal to the length of the chain + child_depth = parent_depth + times[parent_node] - 1 + else: # in this case, times[parent_node] is 0 + child_depth = parent_depth + 1 # Add children to stack (reverse order to maintain left-to-right traversal) for child in reversed(succ): - stack.append((child, node_depth)) + stack.append((child, child_depth)) return leaves, depths diff --git a/src/lineagetree/plot.py b/src/lineagetree/plot.py index fc2084b..f644b6f 100644 --- a/src/lineagetree/plot.py +++ b/src/lineagetree/plot.py @@ -211,7 +211,7 @@ def plot_all_lineages( lT: LineageTree, nodes: list | None = None, last_time_point_to_consider: int | None = None, - nrows: int = 2, + nrows: int = 1, figsize: tuple[int, int] = (10, 15), dpi: int = 100, fontsize: int = 15, @@ -232,7 +232,7 @@ def plot_all_lineages( For example if start_time is 10, then all trees that begin on tp 10 or before are calculated. Defaults to None, where it will plot all the roots that exist on `lT.t_b`. - nrows : int, default=2 + nrows : int, default=1 How many rows of plots should be printed. figsize : tuple, default=(10, 15) The size of the figure. @@ -290,7 +290,7 @@ def plot_all_lineages( raise Exception( f"Not enough axes, they should be at least {len(graphs)}." ) - flat_axes = axes.flatten() + flat_axes = axes.flatten() if hasattr(axes, "flatten") else [axes] ax2root = {} min_width, min_height = float("inf"), float("inf") for ax in flat_axes: @@ -326,8 +326,8 @@ def plot_all_lineages( "edgecolor": "green", }, ) - [figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()] - return axes.flatten()[0].get_figure(), axes, ax2root + [figure.delaxes(ax) for ax in flat_axes if not ax.has_data()] + return flat_axes[0].get_figure(), axes, ax2root def plot_subtree(