diff --git a/src/lineagetree/__init__.py b/src/lineagetree/__init__.py index 99948d6..8c53c0e 100644 --- a/src/lineagetree/__init__.py +++ b/src/lineagetree/__init__.py @@ -31,4 +31,4 @@ "read_from_txt_for_celegans", "read_from_txt_for_celegans_CAO", "LOADERS", -) \ No newline at end of file +) diff --git a/src/lineagetree/_core/_navigation.py b/src/lineagetree/_core/_navigation.py index 3918e84..305a8c2 100644 --- a/src/lineagetree/_core/_navigation.py +++ b/src/lineagetree/_core/_navigation.py @@ -380,7 +380,9 @@ def nodes_at_t( list of int list of ids of the nodes at time `t` spawned by `r` """ - if isinstance(r, int): + if isinstance(r, Iterable): + r = list(r) + else: r = [r] if t is None: t = lT.t_e diff --git a/src/lineagetree/_core/utils.py b/src/lineagetree/_core/utils.py index 26c9db0..c908b6d 100644 --- a/src/lineagetree/_core/utils.py +++ b/src/lineagetree/_core/utils.py @@ -64,9 +64,11 @@ def create_links_and_chains( return {"links": links, "times": times, "root": roots} -def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[int], dict[int, int]]: +def _find_leaves_and_depths_iterative( + lnks_tms: dict, root: int +) -> tuple[list[int], dict[int, int]]: """Find all leaves and calculate depths for all nodes using iterative approach. - + Parameters ---------- lnks_tms : dict @@ -86,91 +88,92 @@ def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[i times = lnks_tms["times"] links = lnks_tms["links"] - + # Stack for DFS: (node, parent_depth) stack = [(root, 0)] - + while stack: parent_node, parent_depth = stack.pop() depths[parent_node] = parent_depth succ = links.get(parent_node, []) - + if not succ: # This is a leaf leaves.append(parent_node) else: - if len(succ) == 1: # in this case, times[parent_node] is equal to the length of the chain + if ( + len(succ) == 1 + ): # in this case, times[parent_node] is equal to the length of the chain child_depth = parent_depth + times[parent_node] - 1 - else: # in this case, times[parent_node] is 0 + else: # in this case, times[parent_node] is 0 child_depth = parent_depth + 1 # Add children to stack (reverse order to maintain left-to-right traversal) for child in reversed(succ): stack.append((child, child_depth)) - + return leaves, depths -def _calculate_leaf_positions(leaves: list[int], width: int, xcenter: int) -> dict[int, float]: +def _calculate_leaf_positions( + leaves: list[int], width: int, xcenter: int +) -> dict[int, float]: """Calculate uniform x-positions for leaves.""" num_leaves = len(leaves) if num_leaves == 1: return {leaves[0]: xcenter} - + leaf_spacing = width / (num_leaves - 1) return { - leaf: xcenter - width/2 + i * leaf_spacing + leaf: xcenter - width / 2 + i * leaf_spacing for i, leaf in enumerate(leaves) } def _assign_positions_iterative( - lnks_tms: dict, - root: int, - depths: dict[int, int], + lnks_tms: dict, + root: int, + depths: dict[int, int], leaf_x_positions: dict[int, float], vert_gap: int, - ycenter: int + ycenter: int, ) -> dict[int, list[float]]: """Assign positions to nodes using iterative post-order traversal.""" pos_node = {} - + # First pass: build parent-child relationships and find processing order - children_map = lnks_tms["links"] + children_map = lnks_tms["links"] # Reverse-order traversal using two stacks stack1 = [root] stack2 = [] - + # This while loop stores nodes in stack2 so that children are processed before parents while stack1: node = stack1.pop() stack2.append(node) stack1.extend(children_map.get(node, [])) - + # Process nodes in reverse-order (children before parents) while stack2: node = stack2.pop() succ = children_map.get(node, []) - - if not succ: # This is a leaf + + if not succ: # This is a leaf pos_node[node] = [ - leaf_x_positions[node], - ycenter - depths[node] * vert_gap + leaf_x_positions[node], + ycenter - depths[node] * vert_gap, ] elif len(succ) == 1: # Single child: place directly above pos_node[node] = [ pos_node[succ[0]][0], - ycenter - depths[node] * vert_gap + ycenter - depths[node] * vert_gap, ] else: # Multiple children: place at center of children child_x_positions = [pos_node[child][0] for child in succ] center_x = sum(child_x_positions) / len(child_x_positions) - pos_node[node] = [ - center_x, - ycenter - depths[node] * vert_gap - ] - + pos_node[node] = [center_x, ycenter - depths[node] * vert_gap] + return pos_node @@ -203,18 +206,18 @@ def hierarchical_pos( """ if root not in lnks_tms["times"]: return None - + # Find all leaves and calculate depths leaves, depths = _find_leaves_and_depths_iterative(lnks_tms, root) - + # Calculate uniform x-positions for leaves leaf_x_positions = _calculate_leaf_positions(leaves, width, xcenter) - + # Assign positions using iterative approach pos_node = _assign_positions_iterative( lnks_tms, root, depths, leaf_x_positions, vert_gap, ycenter ) - + return pos_node diff --git a/src/lineagetree/_io/_loaders.py b/src/lineagetree/_io/_loaders.py index 3e8ff65..bd9546e 100644 --- a/src/lineagetree/_io/_loaders.py +++ b/src/lineagetree/_io/_loaders.py @@ -106,10 +106,10 @@ def _load_meshdict_from_bmfmesh(bmfmesh, pos_multipliers, translation): translation = np.array(translation, dtype=float) vertices = vertices * pos_multipliers + translation - return { # could be a class - 'vertices': vertices, - 'faces': faces, - 'center_mass': np.mean(vertices, axis=0), + return { # could be a class + "vertices": vertices, + "faces": faces, + "center_mass": np.mean(vertices, axis=0), } @@ -150,9 +150,11 @@ def read_from_bmf( for track in tracks: pred = None for t, mesh in track.meshes.items(): - mesh = _load_meshdict_from_bmfmesh(mesh, pos_multipliers, translation) + mesh = _load_meshdict_from_bmfmesh( + mesh, pos_multipliers, translation + ) pos[cell_id] = mesh["center_mass"] - + if store_meshes: lT_mesh[cell_id] = mesh