Skip to content
3 changes: 2 additions & 1 deletion src/lineagetree/_core/_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def add_root(lT: LineageTree, t: int, pos: list | None = None) -> int:
lT._successor[C_next] = ()
lT._predecessor[C_next] = ()
lT._time[C_next] = t
lT.pos[C_next] = pos if isinstance(pos, list) else []
if isinstance(pos, (list, tuple)):
lT.pos[C_next] = list(pos)
lT._changed_roots = True
return C_next

Expand Down
90 changes: 77 additions & 13 deletions src/lineagetree/lineage_tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,10 +617,41 @@ def plot_tree_distance_graphs(
matched_left = []
colors1 = {}
colors2 = {}
to_reverse_back = []

if style not in ("full", "downsampled"):
for m in btrc:
if m._left != -1 and m._right != -1:
cyc1 = tree1.lT.get_chain_of_node(corres1[m._left])
node_1 = corres1[m._left]
node_2 = corres2[m._right]
if (
node_1 not in tree1.lT.roots
and node_2 not in tree2.lT.roots
):
if tree2.lT.predecessor[
node_2
] not in to_reverse_back and tree1.lT.successor[
tree1.lT.predecessor[node_1][0]
].index(
node_1
) != tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
].index(
node_2
):
tree2.lT._successor[
tree2.lT.predecessor[node_2][0]
] = list(
reversed(
tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
]
)
)
to_reverse_back.append(
tree2.lT.predecessor[node_2][0]
)
cyc1 = tree1.lT.get_chain_of_node(node_1)
if len(cyc1) > 1:
node_1, *_, l_node_1 = cyc1
matched_left.append(node_1)
Expand All @@ -629,7 +660,7 @@ def plot_tree_distance_graphs(
node_1 = l_node_1 = cyc1.pop()
matched_left.append(node_1)

cyc2 = tree2.lT.get_chain_of_node(corres2[m._right])
cyc2 = tree2.lT.get_chain_of_node(node_2)
if len(cyc2) > 1:
node_2, *_, l_node_2 = cyc2
matched_right.append(node_2)
Expand Down Expand Up @@ -659,12 +690,41 @@ def plot_tree_distance_graphs(
else:
for m in btrc:
if m._left != -1 and m._right != -1:

node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0]
node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0]
if (
tree1.lT.get_chain_of_node(node_1)[0] == node_1
or tree2.lT.get_chain_of_node(node_2)[0] == node_2
and (node_1 not in colors1 or node_2 not in colors2)
node_1 not in tree1.lT.roots
and node_2 not in tree2.lT.roots
):
if tree2.lT.predecessor[
node_2
] not in to_reverse_back and tree1.lT.successor[
tree1.lT.predecessor[node_1][0]
].index(
node_1
) != tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
].index(
node_2
):
tree2.lT._successor[
tree2.lT.predecessor[node_2][0]
] = list(
reversed(
tree2.lT.successor[
tree2.lT.predecessor[node_2][0]
]
)
)
to_reverse_back.append(
tree2.lT.predecessor[node_2][0]
)

if not colors1.get(
tree1.lT.get_chain_of_node(node_1)[0]
) or not colors2.get(
tree2.lT.get_chain_of_node(node_2)[0]
):
matched_left.append(node_1)
l_node_1 = tree1.lT.get_chain_of_node(node_1)[-1]
Expand All @@ -687,6 +747,14 @@ def plot_tree_distance_graphs(
tree2.get_norm(node_2),
)
)
for n1 in tree1.lT.get_chain_of_node(node_1):
colors1[n1] = colors1[
tree1.lT.get_chain_of_node(node_1)[0]
]
for n2 in tree2.lT.get_chain_of_node(node_2):
colors2[n2] = colors1[
tree1.lT.get_chain_of_node(node_1)[0]
]
colors2[node_2] = colors1[node_1]
colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = (
colors1[node_1]
Expand All @@ -695,14 +763,6 @@ def plot_tree_distance_graphs(
colors2[node_2]
)

if tree1.lT.get_chain_of_node(node_1)[-1] != node_1:
matched_left.append(
tree1.lT.get_chain_of_node(node_1)[-1]
)
if tree2.lT.get_chain_of_node(node_2)[-1] != node_2:
matched_right.append(
tree2.lT.get_chain_of_node(node_2)[-1]
)
if ax is None:
fig, ax = plt.subplots(nrows=1, ncols=2)
cmap = colormaps[colormap]
Expand Down Expand Up @@ -733,6 +793,10 @@ def plot_tree_distance_graphs(
lw=lw,
ax=ax[1],
)
for node in to_reverse_back:
tree2.lT._successor[node] = list(
reversed(tree2.lT._successor[node])
)
return ax[0].get_figure(), ax

def labelled_mappings(
Expand Down
97 changes: 73 additions & 24 deletions src/lineagetree/measure/uted.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import copy
import warnings
from collections.abc import Callable
from functools import partial
from itertools import combinations
from typing import TYPE_CHECKING, Literal

import matplotlib as mpl
import numpy as np
import matplotlib.colors as mcolors
from edist import uted
Expand Down Expand Up @@ -445,10 +447,29 @@ def plot_tree_distance_graphs(
matched_right = []
matched_left = []
colors = {}
to_reverse_back = []
if style not in ("full", "downsampled"):
for m in btrc:
if m._left != -1 and m._right != -1:
cyc1 = lT.get_chain_of_node(corres1[m._left])
node_1 = corres1[m._left]
node_2 = corres2[m._right]
if node_1 not in lT.roots and node_2 not in lT.roots:
if lT.predecessor[node_2][
0
] not in to_reverse_back and lT.successor[
lT.predecessor[node_1][0]
].index(
node_1
) != lT.successor[
lT.predecessor[node_2][0]
].index(
node_2
):
lT._successor[lT.predecessor[node_2][0]] = list(
reversed(lT.successor[lT.predecessor[node_2][0]])
)
to_reverse_back.append(lT.predecessor[node_2][0])
cyc1 = lT.get_chain_of_node(node_1)
if len(cyc1) > 1:
node_1, *_, l_node_1 = cyc1
matched_left.append(node_1)
Expand All @@ -457,7 +478,7 @@ def plot_tree_distance_graphs(
node_1 = l_node_1 = cyc1.pop()
matched_left.append(node_1)

cyc2 = lT.get_chain_of_node(corres2[m._right])
cyc2 = lT.get_chain_of_node(node_2)
if len(cyc2) > 1:
node_2, *_, l_node_2 = cyc2
matched_right.append(node_2)
Expand Down Expand Up @@ -485,35 +506,52 @@ def plot_tree_distance_graphs(
else:
for m in btrc:
if m._left != -1 and m._right != -1:
node_1 = corres1[m._left]
node_2 = corres2[m._right]

if (
lT.get_chain_of_node(node_1)[0] == node_1
or lT.get_chain_of_node(node_2)[0] == node_2
and (node_1 not in colors or node_2 not in colors)
):
node_1 = lT.get_chain_of_node(corres1[m._left])[0]
node_2 = lT.get_chain_of_node(corres2[m._right])[0]
if node_1 not in lT.roots and node_2 not in lT.roots:
if lT.predecessor[node_2][
0
] not in to_reverse_back and lT._successor[
lT.predecessor[node_1][0]
].index(
node_1
) != lT._successor[
lT.predecessor[node_2][0]
].index(
node_2
):
lT._successor[lT.predecessor[node_2][0]] = list(
reversed(lT.successor[lT.predecessor[node_2][0]])
)
to_reverse_back.append(lT.predecessor[node_2][0])

if not colors.get(
lT.get_chain_of_node(node_1)[0]
) or not colors.get(lT.get_chain_of_node(node_2)[0]):
matched_left.append(node_1)
l_node_1 = lT.get_chain_of_node(node_1)[-1]
matched_left.append(l_node_1)
matched_right.append(node_2)
l_node_2 = lT.get_chain_of_node(node_2)[-1]
matched_right.append(l_node_2)
colors[node_1] = __calculate_distance_of_sub_tree(
lT,
node_1,
node_2,
btrc,
corres1,
corres2,
delta_tmp,
lT.norm_dict[norm],
tree1.get_norm(node_1),
tree2.get_norm(node_2),
colors[lT.get_chain_of_node(node_1)[0]] = (
__calculate_distance_of_sub_tree(
lT,
node_1,
node_2,
btrc,
corres1,
corres2,
delta_tmp,
lT.norm_dict[norm],
tree1.get_norm(node_1),
tree2.get_norm(node_2),
)
)
colors[l_node_1] = colors[node_1]
colors[node_2] = colors[node_1]
colors[l_node_2] = colors[node_1]
for n1 in lT.get_chain_of_node(node_1):
colors[n1] = colors[lT.get_chain_of_node(node_1)[0]]
for n2 in lT.get_chain_of_node(node_2):
colors[n2] = colors[lT.get_chain_of_node(node_1)[0]]
if ax is None:
fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True)
cmap = colormaps[colormap]
Expand Down Expand Up @@ -547,6 +585,17 @@ def plot_tree_distance_graphs(
lw=lw,
ax=ax[1],
)
for node in to_reverse_back:
lT._successor[node] = list(reversed(lT._successor[node]))

cax = fig.add_axes([0.25, 0.10, 0.50, 0.01])
for x in ax:
for s in x.spines.values():
s.set_visible(False)
sm = mpl.cm.ScalarMappable(cmap=cmap, norm=c_norm)
sm.set_array([])
fig.colorbar(sm, cax=cax, orientation="horizontal")

return ax[0].get_figure(), ax


Expand Down
Loading