Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lineagetree/_core/_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def add_root(lT: LineageTree, t: int, pos: list | None = None) -> int:
lT._successor[C_next] = ()
lT._predecessor[C_next] = ()
lT._time[C_next] = t
lT.pos[C_next] = pos if isinstance(pos, list) else []
if isinstance(pos, (list, tuple)):
lT.pos[C_next] = list(pos)
lT._changed_roots = True
return C_next

Expand Down
59 changes: 57 additions & 2 deletions src/lineagetree/lineage_tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,10 +617,37 @@ def plot_tree_distance_graphs(
matched_left = []
colors1 = {}
colors2 = {}
to_reverse_back = []

if style not in ("full", "downsampled"):
for m in btrc:
if m._left != -1 and m._right != -1:
cyc1 = tree1.lT.get_chain_of_node(corres1[m._left])
node_1 = corres1[m._left]
node_2 = corres2[m._right]
if (
node_1 not in tree1.lT.roots
and node_2 not in tree2.lT.roots
):
if tree1.lT.successor[
tree1.lT.predecessor[node_1][0]
].index(node_1) != tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
].index(
node_2
):
tree2.lT._successor[
tree2.lT.predecessor[node_2][0]
] = list(
reversed(
tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
]
)
)
to_reverse_back.append(
tree2.lT.predecessor[node_2][0]
)
cyc1 = tree1.lT.get_chain_of_node(node_1)
if len(cyc1) > 1:
node_1, *_, l_node_1 = cyc1
matched_left.append(node_1)
Expand All @@ -629,7 +656,7 @@ def plot_tree_distance_graphs(
node_1 = l_node_1 = cyc1.pop()
matched_left.append(node_1)

cyc2 = tree2.lT.get_chain_of_node(corres2[m._right])
cyc2 = tree2.lT.get_chain_of_node(node_2)
if len(cyc2) > 1:
node_2, *_, l_node_2 = cyc2
matched_right.append(node_2)
Expand Down Expand Up @@ -659,8 +686,32 @@ def plot_tree_distance_graphs(
else:
for m in btrc:
if m._left != -1 and m._right != -1:

node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0]
node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0]
if (
node_1 not in tree1.lT.roots
and node_2 not in tree2.lT.roots
):
if tree1.lT.successor[
tree1.lT.predecessor[node_1][0]
].index(node_1) != tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
].index(
node_2
):
tree2.lT._successor[
tree2.lT.predecessor[node_2][0]
] = list(
reversed(
tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
]
)
)
to_reverse_back.append(
tree2.lT.predecessor[node_2][0]
)
if (
tree1.lT.get_chain_of_node(node_1)[0] == node_1
or tree2.lT.get_chain_of_node(node_2)[0] == node_2
Expand Down Expand Up @@ -733,6 +784,10 @@ def plot_tree_distance_graphs(
lw=lw,
ax=ax[1],
)
for node in to_reverse_back:
tree2.lT._successor[node] = list(
reversed(tree2.lT._successor[node])
)
return ax[0].get_figure(), ax

def labelled_mappings(
Expand Down
26 changes: 24 additions & 2 deletions src/lineagetree/measure/uted.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import warnings
from collections.abc import Callable
from functools import partial
Expand Down Expand Up @@ -445,10 +446,21 @@ def plot_tree_distance_graphs(
matched_right = []
matched_left = []
colors = {}
to_reverse_back = []
if style not in ("full", "downsampled"):
for m in btrc:
if m._left != -1 and m._right != -1:
cyc1 = lT.get_chain_of_node(corres1[m._left])
node_1 = corres1[m._left]
node_2 = corres2[m._right]
if node_1 not in lT.roots and node_2 not in lT.roots:
if lT.successor[lT.predecessor[node_1][0]].index(
node_1
) != lT.successor[lT.predecessor[node_2][0]].index(node_2):
lT._successor[lT.predecessor[node_2][0]] = list(
reversed(lT.successor[lT.predecessor[node_2][0]])
)
to_reverse_back.append(lT.predecessor[node_2][0])
cyc1 = lT.get_chain_of_node(node_1)
if len(cyc1) > 1:
node_1, *_, l_node_1 = cyc1
matched_left.append(node_1)
Expand All @@ -457,7 +469,7 @@ def plot_tree_distance_graphs(
node_1 = l_node_1 = cyc1.pop()
matched_left.append(node_1)

cyc2 = lT.get_chain_of_node(corres2[m._right])
cyc2 = lT.get_chain_of_node(node_2)
if len(cyc2) > 1:
node_2, *_, l_node_2 = cyc2
matched_right.append(node_2)
Expand Down Expand Up @@ -487,6 +499,14 @@ def plot_tree_distance_graphs(
if m._left != -1 and m._right != -1:
node_1 = corres1[m._left]
node_2 = corres2[m._right]
if node_1 not in lT.roots and node_2 not in lT.roots:
if lT.successor[lT.predecessor[node_1][0]].index(
node_1
) != lT.successor[lT.predecessor[node_2][0]].index(node_2):
lT._successor[lT.predecessor[node_2][0]] = list(
reversed(lT.successor[lT.predecessor[node_2][0]])
)
to_reverse_back.append(lT.predecessor[node_2][0])

if (
lT.get_chain_of_node(node_1)[0] == node_1
Expand Down Expand Up @@ -547,6 +567,8 @@ def plot_tree_distance_graphs(
lw=lw,
ax=ax[1],
)
for node in to_reverse_back:
lT._successor[node] = list(reversed(lT._successor[node]))
return ax[0].get_figure(), ax


Expand Down
14 changes: 10 additions & 4 deletions src/lineagetree/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import warnings
from collections.abc import Iterable
from typing import Literal, TYPE_CHECKING
Expand All @@ -20,20 +21,24 @@ def __plot_nodes(
color: str | dict | list,
size: int | float,
ax: plt.Axes,
leaves: set,
default_color: str = "black",
**kwargs,
) -> None:
"""
Private method that plots the nodes of the tree.
"""

hier_no_leaves = copy.copy(hier)
for leaf in leaves:
hier_no_leaves.pop(leaf, None)
if isinstance(color, dict):
color = [color.get(k, default_color) for k in hier]
color = [color.get(k, default_color) for k in hier_no_leaves]
elif isinstance(color, str | list):
color = [
color if node in selected_nodes else default_color for node in hier
color if node in selected_nodes else default_color
for node in hier_no_leaves
]
hier_pos = np.array(list(hier.values()))
hier_pos = np.array(list(hier_no_leaves.values()))
ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs)


Expand Down Expand Up @@ -142,6 +147,7 @@ def draw_tree_graph(
size=size,
ax=ax,
default_color=default_color,
leaves={k for k, v in lnks_tms["links"].items() if not v},
**kwargs,
)
if not color_of_edges:
Expand Down
Loading