diff --git a/src/lineagetree/_core/_modifier.py b/src/lineagetree/_core/_modifier.py index 25303d5..4f2b1cb 100644 --- a/src/lineagetree/_core/_modifier.py +++ b/src/lineagetree/_core/_modifier.py @@ -96,7 +96,8 @@ def add_root(lT: LineageTree, t: int, pos: list | None = None) -> int: lT._successor[C_next] = () lT._predecessor[C_next] = () lT._time[C_next] = t - lT.pos[C_next] = pos if isinstance(pos, list) else [] + if isinstance(pos, (list, tuple)): + lT.pos[C_next] = list(pos) lT._changed_roots = True return C_next diff --git a/src/lineagetree/lineage_tree_manager.py b/src/lineagetree/lineage_tree_manager.py index 7b713df..6889cd7 100644 --- a/src/lineagetree/lineage_tree_manager.py +++ b/src/lineagetree/lineage_tree_manager.py @@ -617,10 +617,37 @@ def plot_tree_distance_graphs( matched_left = [] colors1 = {} colors2 = {} + to_reverse_back = [] + if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = tree1.lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres2[m._right] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index(node_1) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) + cyc1 = tree1.lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -629,7 +656,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = tree2.lT.get_chain_of_node(corres2[m._right]) + cyc2 = tree2.lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -659,8 +686,32 @@ def plot_tree_distance_graphs( else: for m in btrc: if m._left != -1 and m._right != -1: + node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0] node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index(node_1) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) if ( tree1.lT.get_chain_of_node(node_1)[0] == node_1 or tree2.lT.get_chain_of_node(node_2)[0] == node_2 @@ -733,6 +784,10 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + tree2.lT._successor[node] = list( + reversed(tree2.lT._successor[node]) + ) return ax[0].get_figure(), ax def labelled_mappings( diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 719db07..dc2dd96 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import warnings from collections.abc import Callable from functools import partial @@ -445,10 +446,21 @@ def plot_tree_distance_graphs( matched_right = [] matched_left = [] colors = {} + to_reverse_back = [] if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres2[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.successor[lT.predecessor[node_1][0]].index( + node_1 + ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) + to_reverse_back.append(lT.predecessor[node_2][0]) + cyc1 = lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -457,7 +469,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = lT.get_chain_of_node(corres2[m._right]) + cyc2 = lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -487,6 +499,14 @@ def plot_tree_distance_graphs( if m._left != -1 and m._right != -1: node_1 = corres1[m._left] node_2 = corres2[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.successor[lT.predecessor[node_1][0]].index( + node_1 + ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) + to_reverse_back.append(lT.predecessor[node_2][0]) if ( lT.get_chain_of_node(node_1)[0] == node_1 @@ -547,6 +567,8 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + lT._successor[node] = list(reversed(lT._successor[node])) return ax[0].get_figure(), ax diff --git a/src/lineagetree/plot.py b/src/lineagetree/plot.py index f644b6f..3693336 100644 --- a/src/lineagetree/plot.py +++ b/src/lineagetree/plot.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import warnings from collections.abc import Iterable from typing import Literal, TYPE_CHECKING @@ -20,20 +21,24 @@ def __plot_nodes( color: str | dict | list, size: int | float, ax: plt.Axes, + leaves: set, default_color: str = "black", **kwargs, ) -> None: """ Private method that plots the nodes of the tree. """ - + hier_no_leaves = copy.copy(hier) + for leaf in leaves: + hier_no_leaves.pop(leaf, None) if isinstance(color, dict): - color = [color.get(k, default_color) for k in hier] + color = [color.get(k, default_color) for k in hier_no_leaves] elif isinstance(color, str | list): color = [ - color if node in selected_nodes else default_color for node in hier + color if node in selected_nodes else default_color + for node in hier_no_leaves ] - hier_pos = np.array(list(hier.values())) + hier_pos = np.array(list(hier_no_leaves.values())) ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs) @@ -142,6 +147,7 @@ def draw_tree_graph( size=size, ax=ax, default_color=default_color, + leaves={k for k, v in lnks_tms["links"].items() if not v}, **kwargs, ) if not color_of_edges: