Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/lineagetree/_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,32 @@ def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[i
leaves : list of int
List of leaf node ids.
depths : dict mapping int to int
Dictionary mapping node ids to their depth in the tree.
Dictionary mapping all node ids to their depth in the tree.
"""
leaves = []
depths = {}

times = lnks_tms["times"]
links = lnks_tms["links"]

# Stack for DFS: (node, current_depth, parent_depth)
# Stack for DFS: (node, parent_depth)
stack = [(root, 0)]

while stack:
node, parent_depth = stack.pop()
parent_node, parent_depth = stack.pop()
depths[parent_node] = parent_depth
succ = links.get(parent_node, [])

node_depth = parent_depth + lnks_tms["times"].get(node, 0)
depths[node] = parent_depth

succ = lnks_tms["links"].get(node, [])

if not succ: # This is a leaf
leaves.append(node)
leaves.append(parent_node)
else:
if len(succ) == 1: # in this case, times[parent_node] is equal to the length of the chain
child_depth = parent_depth + times[parent_node] - 1
else: # in this case, times[parent_node] is 0
child_depth = parent_depth + 1
# Add children to stack (reverse order to maintain left-to-right traversal)
for child in reversed(succ):
stack.append((child, node_depth))
stack.append((child, child_depth))

return leaves, depths

Expand Down
10 changes: 5 additions & 5 deletions src/lineagetree/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def plot_all_lineages(
lT: LineageTree,
nodes: list | None = None,
last_time_point_to_consider: int | None = None,
nrows: int = 2,
nrows: int = 1,
figsize: tuple[int, int] = (10, 15),
dpi: int = 100,
fontsize: int = 15,
Expand All @@ -232,7 +232,7 @@ def plot_all_lineages(
For example if start_time is 10, then all trees that begin
on tp 10 or before are calculated. Defaults to None, where
it will plot all the roots that exist on `lT.t_b`.
nrows : int, default=2
nrows : int, default=1
How many rows of plots should be printed.
figsize : tuple, default=(10, 15)
The size of the figure.
Expand Down Expand Up @@ -290,7 +290,7 @@ def plot_all_lineages(
raise Exception(
f"Not enough axes, they should be at least {len(graphs)}."
)
flat_axes = axes.flatten()
flat_axes = axes.flatten() if hasattr(axes, "flatten") else [axes]
ax2root = {}
min_width, min_height = float("inf"), float("inf")
for ax in flat_axes:
Expand Down Expand Up @@ -326,8 +326,8 @@ def plot_all_lineages(
"edgecolor": "green",
},
)
[figure.delaxes(ax) for ax in axes.flatten() if not ax.has_data()]
return axes.flatten()[0].get_figure(), axes, ax2root
[figure.delaxes(ax) for ax in flat_axes if not ax.has_data()]
return flat_axes[0].get_figure(), axes, ax2root


def plot_subtree(
Expand Down
Loading