Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 120 additions & 33 deletions src/lineagetree/_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,116 @@ def create_links_and_chains(
return {"links": links, "times": times, "root": roots}


def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[int], dict[int, int]]:
"""Find all leaves and calculate depths for all nodes using iterative approach.

Parameters
----------
lnks_tms : dict
A dictionary created by create_links_and_chains.
root : int
The id of the root node.

Returns
-------
leaves : list of int
List of leaf node ids.
depths : dict mapping int to int
Dictionary mapping node ids to their depth in the tree.
"""
leaves = []
depths = {}

# Stack for DFS: (node, current_depth, parent_depth)
stack = [(root, 0)]

while stack:
node, parent_depth = stack.pop()

node_depth = parent_depth + lnks_tms["times"].get(node, 0)
depths[node] = parent_depth

succ = lnks_tms["links"].get(node, [])

if not succ: # This is a leaf
leaves.append(node)
else:
# Add children to stack (reverse order to maintain left-to-right traversal)
for child in reversed(succ):
stack.append((child, node_depth))

return leaves, depths


def _calculate_leaf_positions(leaves: list[int], width: int, xcenter: int) -> dict[int, float]:
"""Calculate uniform x-positions for leaves."""
num_leaves = len(leaves)
if num_leaves == 1:
return {leaves[0]: xcenter}

leaf_spacing = width / (num_leaves - 1)
return {
leaf: xcenter - width/2 + i * leaf_spacing
for i, leaf in enumerate(leaves)
}


def _assign_positions_iterative(
lnks_tms: dict,
root: int,
depths: dict[int, int],
leaf_x_positions: dict[int, float],
vert_gap: int,
ycenter: int
) -> dict[int, list[float]]:
"""Assign positions to nodes using iterative post-order traversal."""
pos_node = {}

# First pass: build parent-child relationships and find processing order
children_map = lnks_tms["links"]

# Reverse-order traversal using two stacks
stack1 = [root]
stack2 = []

# This while loop stores nodes in stack2 so that children are processed before parents
while stack1:
node = stack1.pop()
stack2.append(node)
stack1.extend(children_map.get(node, []))

# Process nodes in reverse-order (children before parents)
while stack2:
node = stack2.pop()
succ = children_map.get(node, [])

if not succ: # This is a leaf
pos_node[node] = [
leaf_x_positions[node],
ycenter - depths[node] * vert_gap
]
elif len(succ) == 1:
# Single child: place directly above
pos_node[node] = [
pos_node[succ[0]][0],
ycenter - depths[node] * vert_gap
]
else:
# Multiple children: place at center of children
child_x_positions = [pos_node[child][0] for child in succ]
center_x = sum(child_x_positions) / len(child_x_positions)
pos_node[node] = [
center_x,
ycenter - depths[node] * vert_gap
]

return pos_node


def hierarchical_pos(
lnks_tms: dict, root, width=1000, vert_gap=2, xcenter=0, ycenter=0
) -> dict[int, list[float]] | None:
"""Calculates the position of each node on the tree graph.
"""Calculates the position of each node on the tree graph with uniform leaf spacing.

Parameters
----------
Expand All @@ -88,42 +194,23 @@ def hierarchical_pos(
-------
dict mapping int to list of float
Provides a dictionary that contains the id of each node as keys and its 2-d position on the
tree graph as values.
tree graph as values. Leaves are uniformly spaced on the x-axis.
If the root requested does not exists, None is then returned
"""
to_do = [root]
if root not in lnks_tms["times"]:
return None
pos_node = {root: [xcenter, ycenter]}
prev_width = {root: width / 2}
while to_do:
curr = to_do.pop()
succ = lnks_tms["links"].get(curr, [])
if len(succ) == 0:
continue
elif len(succ) == 1:
pos_node[succ[0]] = [
pos_node[curr][0],
pos_node[curr][1]
- lnks_tms["times"].get(curr, 0)
+ min(vert_gap, lnks_tms["times"].get(curr, 0)),
]
to_do.extend(succ)
prev_width[succ[0]] = prev_width[curr]
elif len(succ) == 2:
pos_node[succ[0]] = [
pos_node[curr][0] - prev_width[curr] / 2,
pos_node[curr][1] - vert_gap,
]
pos_node[succ[1]] = [
pos_node[curr][0] + prev_width[curr] / 2,
pos_node[curr][1] - vert_gap,
]
to_do.extend(succ)
prev_width[succ[0]], prev_width[succ[1]] = (
prev_width[curr] / 2,
prev_width[curr] / 2,
)

# Find all leaves and calculate depths
leaves, depths = _find_leaves_and_depths_iterative(lnks_tms, root)

# Calculate uniform x-positions for leaves
leaf_x_positions = _calculate_leaf_positions(leaves, width, xcenter)

# Assign positions using iterative approach
pos_node = _assign_positions_iterative(
lnks_tms, root, depths, leaf_x_positions, vert_gap, ycenter
)

return pos_node


Expand Down
Loading