From fa1f027868c86658b1304d9811855025741ce9ae Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 12:09:38 +0100 Subject: [PATCH 1/7] fix_tree_distance_graphs --- src/lineagetree/measure/uted.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 719db07..6801033 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -445,10 +445,21 @@ def plot_tree_distance_graphs( matched_right = [] matched_left = [] colors = {} + second_lT = lT.get_subtree(lT.get_subtree_nodes(n2)) + if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres1[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.successor[lT.predecessor[node_1][0]].index( + node_1 + ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + second_lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) + cyc1 = lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -457,7 +468,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = lT.get_chain_of_node(corres2[m._right]) + cyc2 = lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -487,6 +498,15 @@ def plot_tree_distance_graphs( if m._left != -1 and m._right != -1: node_1 = corres1[m._left] node_2 = corres2[m._right] + # node_first_tree = corres1[m._left] + # node_second_tree = corres1[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.successor[lT.predecessor[node_1][0]].index( + node_1 + ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + second_lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) if ( lT.get_chain_of_node(node_1)[0] == node_1 From 2b9831efb6d527ae0dc9e7946aeaa956b5b72d12 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 12:11:23 +0100 Subject: [PATCH 2/7] do not add pos if none specified for add_root --- src/lineagetree/_core/_modifier.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lineagetree/_core/_modifier.py b/src/lineagetree/_core/_modifier.py index 25303d5..4f2b1cb 100644 --- a/src/lineagetree/_core/_modifier.py +++ b/src/lineagetree/_core/_modifier.py @@ -96,7 +96,8 @@ def add_root(lT: LineageTree, t: int, pos: list | None = None) -> int: lT._successor[C_next] = () lT._predecessor[C_next] = () lT._time[C_next] = t - lT.pos[C_next] = pos if isinstance(pos, list) else [] + if isinstance(pos, (list, tuple)): + lT.pos[C_next] = list(pos) lT._changed_roots = True return C_next From 739cdc17f85fadd3628714fe8e6a74cd467b412c Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 14:24:04 +0100 Subject: [PATCH 3/7] reorder sibliings for minimal time cost, for one dataset --- src/lineagetree/measure/uted.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 6801033..dc2dd96 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import warnings from collections.abc import Callable from functools import partial @@ -445,20 +446,20 @@ def plot_tree_distance_graphs( matched_right = [] matched_left = [] colors = {} - second_lT = lT.get_subtree(lT.get_subtree_nodes(n2)) - + to_reverse_back = [] if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: node_1 = corres1[m._left] - node_2 = corres1[m._right] + node_2 = corres2[m._right] if node_1 not in lT.roots and node_2 not in lT.roots: if lT.successor[lT.predecessor[node_1][0]].index( node_1 ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): - second_lT._successor[lT.predecessor[node_2][0]] = list( + lT._successor[lT.predecessor[node_2][0]] = list( reversed(lT.successor[lT.predecessor[node_2][0]]) ) + to_reverse_back.append(lT.predecessor[node_2][0]) cyc1 = lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 @@ -498,15 +499,14 @@ def plot_tree_distance_graphs( if m._left != -1 and m._right != -1: node_1 = corres1[m._left] node_2 = corres2[m._right] - # node_first_tree = corres1[m._left] - # node_second_tree = corres1[m._right] if node_1 not in lT.roots and node_2 not in lT.roots: if lT.successor[lT.predecessor[node_1][0]].index( node_1 ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): - second_lT._successor[lT.predecessor[node_2][0]] = list( + lT._successor[lT.predecessor[node_2][0]] = list( reversed(lT.successor[lT.predecessor[node_2][0]]) ) + to_reverse_back.append(lT.predecessor[node_2][0]) if ( lT.get_chain_of_node(node_1)[0] == node_1 @@ -567,6 +567,8 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + lT._successor[node] = list(reversed(lT._successor[node])) return ax[0].get_figure(), ax From 0a0cc9d2551e17b7d97f5cc3d597267ded7a25b2 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 15:57:28 +0100 Subject: [PATCH 4/7] reorder branches end --- src/lineagetree/lineage_tree_manager.py | 59 ++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/src/lineagetree/lineage_tree_manager.py b/src/lineagetree/lineage_tree_manager.py index 7b713df..6889cd7 100644 --- a/src/lineagetree/lineage_tree_manager.py +++ b/src/lineagetree/lineage_tree_manager.py @@ -617,10 +617,37 @@ def plot_tree_distance_graphs( matched_left = [] colors1 = {} colors2 = {} + to_reverse_back = [] + if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = tree1.lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres2[m._right] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index(node_1) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) + cyc1 = tree1.lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -629,7 +656,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = tree2.lT.get_chain_of_node(corres2[m._right]) + cyc2 = tree2.lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -659,8 +686,32 @@ def plot_tree_distance_graphs( else: for m in btrc: if m._left != -1 and m._right != -1: + node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0] node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index(node_1) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) if ( tree1.lT.get_chain_of_node(node_1)[0] == node_1 or tree2.lT.get_chain_of_node(node_2)[0] == node_2 @@ -733,6 +784,10 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + tree2.lT._successor[node] = list( + reversed(tree2.lT._successor[node]) + ) return ax[0].get_figure(), ax def labelled_mappings( From f1b58c16171b06926c335b9e2d458a06887fd353 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 16:51:03 +0100 Subject: [PATCH 5/7] no_leaves --- src/lineagetree/plot.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lineagetree/plot.py b/src/lineagetree/plot.py index f644b6f..1616a7f 100644 --- a/src/lineagetree/plot.py +++ b/src/lineagetree/plot.py @@ -20,13 +20,14 @@ def __plot_nodes( color: str | dict | list, size: int | float, ax: plt.Axes, + leaves: set, default_color: str = "black", **kwargs, ) -> None: """ Private method that plots the nodes of the tree. """ - + hier = {k: v for k, v in hier.items() if k not in leaves} if isinstance(color, dict): color = [color.get(k, default_color) for k in hier] elif isinstance(color, str | list): @@ -142,6 +143,7 @@ def draw_tree_graph( size=size, ax=ax, default_color=default_color, + leaves={k for k, v in lnks_tms["links"].items() if not v}, **kwargs, ) if not color_of_edges: From 77941ddd46ad5809c3052b8abde1f6f965449f17 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Thu, 11 Dec 2025 10:16:18 +0100 Subject: [PATCH 6/7] slightly faster --- src/lineagetree/plot.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/lineagetree/plot.py b/src/lineagetree/plot.py index 1616a7f..63b4c89 100644 --- a/src/lineagetree/plot.py +++ b/src/lineagetree/plot.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import warnings from collections.abc import Iterable from typing import Literal, TYPE_CHECKING @@ -27,14 +28,17 @@ def __plot_nodes( """ Private method that plots the nodes of the tree. """ - hier = {k: v for k, v in hier.items() if k not in leaves} + hier_no_leaves = copy.copy(hier) + for leaf in leaves: + hier_no_leaves.pop(key, None) if isinstance(color, dict): - color = [color.get(k, default_color) for k in hier] + color = [color.get(k, default_color) for k in hier_no_leaves] elif isinstance(color, str | list): color = [ - color if node in selected_nodes else default_color for node in hier + color if node in selected_nodes else default_color + for node in hier_no_leaves ] - hier_pos = np.array(list(hier.values())) + hier_pos = np.array(list(hier_no_leaves.values())) ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs) From d579744e3677e55f81fd2d055a9da3bba68cb81a Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Thu, 11 Dec 2025 10:25:15 +0100 Subject: [PATCH 7/7] fix --- src/lineagetree/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lineagetree/plot.py b/src/lineagetree/plot.py index 63b4c89..3693336 100644 --- a/src/lineagetree/plot.py +++ b/src/lineagetree/plot.py @@ -30,7 +30,7 @@ def __plot_nodes( """ hier_no_leaves = copy.copy(hier) for leaf in leaves: - hier_no_leaves.pop(key, None) + hier_no_leaves.pop(leaf, None) if isinstance(color, dict): color = [color.get(k, default_color) for k in hier_no_leaves] elif isinstance(color, str | list):