diff --git a/src/lineagetree/_core/_modifier.py b/src/lineagetree/_core/_modifier.py index 25303d5..4f2b1cb 100644 --- a/src/lineagetree/_core/_modifier.py +++ b/src/lineagetree/_core/_modifier.py @@ -96,7 +96,8 @@ def add_root(lT: LineageTree, t: int, pos: list | None = None) -> int: lT._successor[C_next] = () lT._predecessor[C_next] = () lT._time[C_next] = t - lT.pos[C_next] = pos if isinstance(pos, list) else [] + if isinstance(pos, (list, tuple)): + lT.pos[C_next] = list(pos) lT._changed_roots = True return C_next diff --git a/src/lineagetree/lineage_tree_manager.py b/src/lineagetree/lineage_tree_manager.py index 7b713df..9e8312e 100644 --- a/src/lineagetree/lineage_tree_manager.py +++ b/src/lineagetree/lineage_tree_manager.py @@ -617,10 +617,41 @@ def plot_tree_distance_graphs( matched_left = [] colors1 = {} colors2 = {} + to_reverse_back = [] + if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = tree1.lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres2[m._right] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree2.lT.predecessor[ + node_2 + ] not in to_reverse_back and tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index( + node_1 + ) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) + cyc1 = tree1.lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -629,7 +660,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = tree2.lT.get_chain_of_node(corres2[m._right]) + cyc2 = tree2.lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -659,12 +690,41 @@ def plot_tree_distance_graphs( else: for m in btrc: if m._left != -1 and m._right != -1: + node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0] node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0] if ( - tree1.lT.get_chain_of_node(node_1)[0] == node_1 - or tree2.lT.get_chain_of_node(node_2)[0] == node_2 - and (node_1 not in colors1 or node_2 not in colors2) + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree2.lT.predecessor[ + node_2 + ] not in to_reverse_back and tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index( + node_1 + ) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) + + if not colors1.get( + tree1.lT.get_chain_of_node(node_1)[0] + ) or not colors2.get( + tree2.lT.get_chain_of_node(node_2)[0] ): matched_left.append(node_1) l_node_1 = tree1.lT.get_chain_of_node(node_1)[-1] @@ -687,6 +747,14 @@ def plot_tree_distance_graphs( tree2.get_norm(node_2), ) ) + for n1 in tree1.lT.get_chain_of_node(node_1): + colors1[n1] = colors1[ + tree1.lT.get_chain_of_node(node_1)[0] + ] + for n2 in tree2.lT.get_chain_of_node(node_2): + colors2[n2] = colors1[ + tree1.lT.get_chain_of_node(node_1)[0] + ] colors2[node_2] = colors1[node_1] colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = ( colors1[node_1] @@ -695,14 +763,6 @@ def plot_tree_distance_graphs( colors2[node_2] ) - if tree1.lT.get_chain_of_node(node_1)[-1] != node_1: - matched_left.append( - tree1.lT.get_chain_of_node(node_1)[-1] - ) - if tree2.lT.get_chain_of_node(node_2)[-1] != node_2: - matched_right.append( - tree2.lT.get_chain_of_node(node_2)[-1] - ) if ax is None: fig, ax = plt.subplots(nrows=1, ncols=2) cmap = colormaps[colormap] @@ -733,6 +793,10 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + tree2.lT._successor[node] = list( + reversed(tree2.lT._successor[node]) + ) return ax[0].get_figure(), ax def labelled_mappings( diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 719db07..9979129 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -1,11 +1,13 @@ from __future__ import annotations +import copy import warnings from collections.abc import Callable from functools import partial from itertools import combinations from typing import TYPE_CHECKING, Literal +import matplotlib as mpl import numpy as np import matplotlib.colors as mcolors from edist import uted @@ -445,10 +447,29 @@ def plot_tree_distance_graphs( matched_right = [] matched_left = [] colors = {} + to_reverse_back = [] if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres2[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.predecessor[node_2][ + 0 + ] not in to_reverse_back and lT.successor[ + lT.predecessor[node_1][0] + ].index( + node_1 + ) != lT.successor[ + lT.predecessor[node_2][0] + ].index( + node_2 + ): + lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) + to_reverse_back.append(lT.predecessor[node_2][0]) + cyc1 = lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -457,7 +478,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = lT.get_chain_of_node(corres2[m._right]) + cyc2 = lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -485,35 +506,52 @@ def plot_tree_distance_graphs( else: for m in btrc: if m._left != -1 and m._right != -1: - node_1 = corres1[m._left] - node_2 = corres2[m._right] - - if ( - lT.get_chain_of_node(node_1)[0] == node_1 - or lT.get_chain_of_node(node_2)[0] == node_2 - and (node_1 not in colors or node_2 not in colors) - ): + node_1 = lT.get_chain_of_node(corres1[m._left])[0] + node_2 = lT.get_chain_of_node(corres2[m._right])[0] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.predecessor[node_2][ + 0 + ] not in to_reverse_back and lT._successor[ + lT.predecessor[node_1][0] + ].index( + node_1 + ) != lT._successor[ + lT.predecessor[node_2][0] + ].index( + node_2 + ): + lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) + to_reverse_back.append(lT.predecessor[node_2][0]) + + if not colors.get( + lT.get_chain_of_node(node_1)[0] + ) or not colors.get(lT.get_chain_of_node(node_2)[0]): matched_left.append(node_1) l_node_1 = lT.get_chain_of_node(node_1)[-1] matched_left.append(l_node_1) matched_right.append(node_2) l_node_2 = lT.get_chain_of_node(node_2)[-1] matched_right.append(l_node_2) - colors[node_1] = __calculate_distance_of_sub_tree( - lT, - node_1, - node_2, - btrc, - corres1, - corres2, - delta_tmp, - lT.norm_dict[norm], - tree1.get_norm(node_1), - tree2.get_norm(node_2), + colors[lT.get_chain_of_node(node_1)[0]] = ( + __calculate_distance_of_sub_tree( + lT, + node_1, + node_2, + btrc, + corres1, + corres2, + delta_tmp, + lT.norm_dict[norm], + tree1.get_norm(node_1), + tree2.get_norm(node_2), + ) ) - colors[l_node_1] = colors[node_1] - colors[node_2] = colors[node_1] - colors[l_node_2] = colors[node_1] + for n1 in lT.get_chain_of_node(node_1): + colors[n1] = colors[lT.get_chain_of_node(node_1)[0]] + for n2 in lT.get_chain_of_node(node_2): + colors[n2] = colors[lT.get_chain_of_node(node_1)[0]] if ax is None: fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True) cmap = colormaps[colormap] @@ -547,6 +585,17 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + lT._successor[node] = list(reversed(lT._successor[node])) + + cax = fig.add_axes([0.25, 0.10, 0.50, 0.01]) + for x in ax: + for s in x.spines.values(): + s.set_visible(False) + sm = mpl.cm.ScalarMappable(cmap=cmap, norm=c_norm) + sm.set_array([]) + fig.colorbar(sm, cax=cax, orientation="horizontal") + return ax[0].get_figure(), ax