diff --git a/src/lineagetree/_core/utils.py b/src/lineagetree/_core/utils.py index 1e2c57e..cb7510c 100644 --- a/src/lineagetree/_core/utils.py +++ b/src/lineagetree/_core/utils.py @@ -64,10 +64,116 @@ def create_links_and_chains( return {"links": links, "times": times, "root": roots} +def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[int], dict[int, int]]: + """Find all leaves and calculate depths for all nodes using iterative approach. + + Parameters + ---------- + lnks_tms : dict + A dictionary created by create_links_and_chains. + root : int + The id of the root node. + + Returns + ------- + leaves : list of int + List of leaf node ids. + depths : dict mapping int to int + Dictionary mapping node ids to their depth in the tree. + """ + leaves = [] + depths = {} + + # Stack for DFS: (node, current_depth, parent_depth) + stack = [(root, 0)] + + while stack: + node, parent_depth = stack.pop() + + node_depth = parent_depth + lnks_tms["times"].get(node, 0) + depths[node] = parent_depth + + succ = lnks_tms["links"].get(node, []) + + if not succ: # This is a leaf + leaves.append(node) + else: + # Add children to stack (reverse order to maintain left-to-right traversal) + for child in reversed(succ): + stack.append((child, node_depth)) + + return leaves, depths + + +def _calculate_leaf_positions(leaves: list[int], width: int, xcenter: int) -> dict[int, float]: + """Calculate uniform x-positions for leaves.""" + num_leaves = len(leaves) + if num_leaves == 1: + return {leaves[0]: xcenter} + + leaf_spacing = width / (num_leaves - 1) + return { + leaf: xcenter - width/2 + i * leaf_spacing + for i, leaf in enumerate(leaves) + } + + +def _assign_positions_iterative( + lnks_tms: dict, + root: int, + depths: dict[int, int], + leaf_x_positions: dict[int, float], + vert_gap: int, + ycenter: int +) -> dict[int, list[float]]: + """Assign positions to nodes using iterative post-order traversal.""" + pos_node = {} + + # First pass: build parent-child relationships and find processing order + children_map = lnks_tms["links"] + + # Reverse-order traversal using two stacks + stack1 = [root] + stack2 = [] + + # This while loop stores nodes in stack2 so that children are processed before parents + while stack1: + node = stack1.pop() + stack2.append(node) + stack1.extend(children_map.get(node, [])) + + # Process nodes in reverse-order (children before parents) + while stack2: + node = stack2.pop() + succ = children_map.get(node, []) + + if not succ: # This is a leaf + pos_node[node] = [ + leaf_x_positions[node], + ycenter - depths[node] * vert_gap + ] + elif len(succ) == 1: + # Single child: place directly above + pos_node[node] = [ + pos_node[succ[0]][0], + ycenter - depths[node] * vert_gap + ] + else: + # Multiple children: place at center of children + child_x_positions = [pos_node[child][0] for child in succ] + center_x = sum(child_x_positions) / len(child_x_positions) + pos_node[node] = [ + center_x, + ycenter - depths[node] * vert_gap + ] + + return pos_node + + def hierarchical_pos( lnks_tms: dict, root, width=1000, vert_gap=2, xcenter=0, ycenter=0 ) -> dict[int, list[float]] | None: - """Calculates the position of each node on the tree graph. + """Calculates the position of each node on the tree graph with uniform leaf spacing. Parameters ---------- @@ -88,42 +194,23 @@ def hierarchical_pos( ------- dict mapping int to list of float Provides a dictionary that contains the id of each node as keys and its 2-d position on the - tree graph as values. + tree graph as values. Leaves are uniformly spaced on the x-axis. If the root requested does not exists, None is then returned """ - to_do = [root] if root not in lnks_tms["times"]: return None - pos_node = {root: [xcenter, ycenter]} - prev_width = {root: width / 2} - while to_do: - curr = to_do.pop() - succ = lnks_tms["links"].get(curr, []) - if len(succ) == 0: - continue - elif len(succ) == 1: - pos_node[succ[0]] = [ - pos_node[curr][0], - pos_node[curr][1] - - lnks_tms["times"].get(curr, 0) - + min(vert_gap, lnks_tms["times"].get(curr, 0)), - ] - to_do.extend(succ) - prev_width[succ[0]] = prev_width[curr] - elif len(succ) == 2: - pos_node[succ[0]] = [ - pos_node[curr][0] - prev_width[curr] / 2, - pos_node[curr][1] - vert_gap, - ] - pos_node[succ[1]] = [ - pos_node[curr][0] + prev_width[curr] / 2, - pos_node[curr][1] - vert_gap, - ] - to_do.extend(succ) - prev_width[succ[0]], prev_width[succ[1]] = ( - prev_width[curr] / 2, - prev_width[curr] / 2, - ) + + # Find all leaves and calculate depths + leaves, depths = _find_leaves_and_depths_iterative(lnks_tms, root) + + # Calculate uniform x-positions for leaves + leaf_x_positions = _calculate_leaf_positions(leaves, width, xcenter) + + # Assign positions using iterative approach + pos_node = _assign_positions_iterative( + lnks_tms, root, depths, leaf_x_positions, vert_gap, ycenter + ) + return pos_node