Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lineagetree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@
"read_from_txt_for_celegans",
"read_from_txt_for_celegans_CAO",
"LOADERS",
)
)
4 changes: 3 additions & 1 deletion src/lineagetree/_core/_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def nodes_at_t(
list of int
list of ids of the nodes at time `t` spawned by `r`
"""
if isinstance(r, int):
if isinstance(r, Iterable):
r = list(r)
else:
r = [r]
if t is None:
t = lT.t_e
Expand Down
69 changes: 36 additions & 33 deletions src/lineagetree/_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def create_links_and_chains(
return {"links": links, "times": times, "root": roots}


def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[int], dict[int, int]]:
def _find_leaves_and_depths_iterative(
lnks_tms: dict, root: int
) -> tuple[list[int], dict[int, int]]:
"""Find all leaves and calculate depths for all nodes using iterative approach.

Parameters
----------
lnks_tms : dict
Expand All @@ -86,91 +88,92 @@ def _find_leaves_and_depths_iterative(lnks_tms: dict, root: int) -> tuple[list[i

times = lnks_tms["times"]
links = lnks_tms["links"]

# Stack for DFS: (node, parent_depth)
stack = [(root, 0)]

while stack:
parent_node, parent_depth = stack.pop()
depths[parent_node] = parent_depth
succ = links.get(parent_node, [])

if not succ: # This is a leaf
leaves.append(parent_node)
else:
if len(succ) == 1: # in this case, times[parent_node] is equal to the length of the chain
if (
len(succ) == 1
): # in this case, times[parent_node] is equal to the length of the chain
child_depth = parent_depth + times[parent_node] - 1
else: # in this case, times[parent_node] is 0
else: # in this case, times[parent_node] is 0
child_depth = parent_depth + 1
# Add children to stack (reverse order to maintain left-to-right traversal)
for child in reversed(succ):
stack.append((child, child_depth))

return leaves, depths


def _calculate_leaf_positions(leaves: list[int], width: int, xcenter: int) -> dict[int, float]:
def _calculate_leaf_positions(
leaves: list[int], width: int, xcenter: int
) -> dict[int, float]:
"""Calculate uniform x-positions for leaves."""
num_leaves = len(leaves)
if num_leaves == 1:
return {leaves[0]: xcenter}

leaf_spacing = width / (num_leaves - 1)
return {
leaf: xcenter - width/2 + i * leaf_spacing
leaf: xcenter - width / 2 + i * leaf_spacing
for i, leaf in enumerate(leaves)
}


def _assign_positions_iterative(
lnks_tms: dict,
root: int,
depths: dict[int, int],
lnks_tms: dict,
root: int,
depths: dict[int, int],
leaf_x_positions: dict[int, float],
vert_gap: int,
ycenter: int
ycenter: int,
) -> dict[int, list[float]]:
"""Assign positions to nodes using iterative post-order traversal."""
pos_node = {}

# First pass: build parent-child relationships and find processing order
children_map = lnks_tms["links"]
children_map = lnks_tms["links"]

# Reverse-order traversal using two stacks
stack1 = [root]
stack2 = []

# This while loop stores nodes in stack2 so that children are processed before parents
while stack1:
node = stack1.pop()
stack2.append(node)
stack1.extend(children_map.get(node, []))

# Process nodes in reverse-order (children before parents)
while stack2:
node = stack2.pop()
succ = children_map.get(node, [])
if not succ: # This is a leaf

if not succ: # This is a leaf
pos_node[node] = [
leaf_x_positions[node],
ycenter - depths[node] * vert_gap
leaf_x_positions[node],
ycenter - depths[node] * vert_gap,
]
elif len(succ) == 1:
# Single child: place directly above
pos_node[node] = [
pos_node[succ[0]][0],
ycenter - depths[node] * vert_gap
ycenter - depths[node] * vert_gap,
]
else:
# Multiple children: place at center of children
child_x_positions = [pos_node[child][0] for child in succ]
center_x = sum(child_x_positions) / len(child_x_positions)
pos_node[node] = [
center_x,
ycenter - depths[node] * vert_gap
]

pos_node[node] = [center_x, ycenter - depths[node] * vert_gap]

return pos_node


Expand Down Expand Up @@ -203,18 +206,18 @@ def hierarchical_pos(
"""
if root not in lnks_tms["times"]:
return None

# Find all leaves and calculate depths
leaves, depths = _find_leaves_and_depths_iterative(lnks_tms, root)

# Calculate uniform x-positions for leaves
leaf_x_positions = _calculate_leaf_positions(leaves, width, xcenter)

# Assign positions using iterative approach
pos_node = _assign_positions_iterative(
lnks_tms, root, depths, leaf_x_positions, vert_gap, ycenter
)

return pos_node


Expand Down
14 changes: 8 additions & 6 deletions src/lineagetree/_io/_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def _load_meshdict_from_bmfmesh(bmfmesh, pos_multipliers, translation):
translation = np.array(translation, dtype=float)
vertices = vertices * pos_multipliers + translation

return { # could be a class
'vertices': vertices,
'faces': faces,
'center_mass': np.mean(vertices, axis=0),
return { # could be a class
"vertices": vertices,
"faces": faces,
"center_mass": np.mean(vertices, axis=0),
}


Expand Down Expand Up @@ -150,9 +150,11 @@ def read_from_bmf(
for track in tracks:
pred = None
for t, mesh in track.meshes.items():
mesh = _load_meshdict_from_bmfmesh(mesh, pos_multipliers, translation)
mesh = _load_meshdict_from_bmfmesh(
mesh, pos_multipliers, translation
)
pos[cell_id] = mesh["center_mass"]

if store_meshes:
lT_mesh[cell_id] = mesh

Expand Down
Loading