From fa1f027868c86658b1304d9811855025741ce9ae Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 12:09:38 +0100 Subject: [PATCH 1/7] fix_tree_distance_graphs --- src/lineagetree/measure/uted.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 719db07..6801033 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -445,10 +445,21 @@ def plot_tree_distance_graphs( matched_right = [] matched_left = [] colors = {} + second_lT = lT.get_subtree(lT.get_subtree_nodes(n2)) + if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres1[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.successor[lT.predecessor[node_1][0]].index( + node_1 + ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + second_lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) + cyc1 = lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -457,7 +468,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = lT.get_chain_of_node(corres2[m._right]) + cyc2 = lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -487,6 +498,15 @@ def plot_tree_distance_graphs( if m._left != -1 and m._right != -1: node_1 = corres1[m._left] node_2 = corres2[m._right] + # node_first_tree = corres1[m._left] + # node_second_tree = corres1[m._right] + if node_1 not in lT.roots and node_2 not in lT.roots: + if lT.successor[lT.predecessor[node_1][0]].index( + node_1 + ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + second_lT._successor[lT.predecessor[node_2][0]] = list( + reversed(lT.successor[lT.predecessor[node_2][0]]) + ) if ( lT.get_chain_of_node(node_1)[0] == node_1 From 2b9831efb6d527ae0dc9e7946aeaa956b5b72d12 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 12:11:23 +0100 Subject: [PATCH 2/7] do not add pos if none specified for add_root --- src/lineagetree/_core/_modifier.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lineagetree/_core/_modifier.py b/src/lineagetree/_core/_modifier.py index 25303d5..4f2b1cb 100644 --- a/src/lineagetree/_core/_modifier.py +++ b/src/lineagetree/_core/_modifier.py @@ -96,7 +96,8 @@ def add_root(lT: LineageTree, t: int, pos: list | None = None) -> int: lT._successor[C_next] = () lT._predecessor[C_next] = () lT._time[C_next] = t - lT.pos[C_next] = pos if isinstance(pos, list) else [] + if isinstance(pos, (list, tuple)): + lT.pos[C_next] = list(pos) lT._changed_roots = True return C_next From 739cdc17f85fadd3628714fe8e6a74cd467b412c Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 14:24:04 +0100 Subject: [PATCH 3/7] reorder sibliings for minimal time cost, for one dataset --- src/lineagetree/measure/uted.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 6801033..dc2dd96 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import warnings from collections.abc import Callable from functools import partial @@ -445,20 +446,20 @@ def plot_tree_distance_graphs( matched_right = [] matched_left = [] colors = {} - second_lT = lT.get_subtree(lT.get_subtree_nodes(n2)) - + to_reverse_back = [] if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: node_1 = corres1[m._left] - node_2 = corres1[m._right] + node_2 = corres2[m._right] if node_1 not in lT.roots and node_2 not in lT.roots: if lT.successor[lT.predecessor[node_1][0]].index( node_1 ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): - second_lT._successor[lT.predecessor[node_2][0]] = list( + lT._successor[lT.predecessor[node_2][0]] = list( reversed(lT.successor[lT.predecessor[node_2][0]]) ) + to_reverse_back.append(lT.predecessor[node_2][0]) cyc1 = lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 @@ -498,15 +499,14 @@ def plot_tree_distance_graphs( if m._left != -1 and m._right != -1: node_1 = corres1[m._left] node_2 = corres2[m._right] - # node_first_tree = corres1[m._left] - # node_second_tree = corres1[m._right] if node_1 not in lT.roots and node_2 not in lT.roots: if lT.successor[lT.predecessor[node_1][0]].index( node_1 ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): - second_lT._successor[lT.predecessor[node_2][0]] = list( + lT._successor[lT.predecessor[node_2][0]] = list( reversed(lT.successor[lT.predecessor[node_2][0]]) ) + to_reverse_back.append(lT.predecessor[node_2][0]) if ( lT.get_chain_of_node(node_1)[0] == node_1 @@ -567,6 +567,8 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + lT._successor[node] = list(reversed(lT._successor[node])) return ax[0].get_figure(), ax From 0a0cc9d2551e17b7d97f5cc3d597267ded7a25b2 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Wed, 10 Dec 2025 15:57:28 +0100 Subject: [PATCH 4/7] reorder branches end --- src/lineagetree/lineage_tree_manager.py | 59 ++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/src/lineagetree/lineage_tree_manager.py b/src/lineagetree/lineage_tree_manager.py index 7b713df..6889cd7 100644 --- a/src/lineagetree/lineage_tree_manager.py +++ b/src/lineagetree/lineage_tree_manager.py @@ -617,10 +617,37 @@ def plot_tree_distance_graphs( matched_left = [] colors1 = {} colors2 = {} + to_reverse_back = [] + if style not in ("full", "downsampled"): for m in btrc: if m._left != -1 and m._right != -1: - cyc1 = tree1.lT.get_chain_of_node(corres1[m._left]) + node_1 = corres1[m._left] + node_2 = corres2[m._right] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index(node_1) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) + cyc1 = tree1.lT.get_chain_of_node(node_1) if len(cyc1) > 1: node_1, *_, l_node_1 = cyc1 matched_left.append(node_1) @@ -629,7 +656,7 @@ def plot_tree_distance_graphs( node_1 = l_node_1 = cyc1.pop() matched_left.append(node_1) - cyc2 = tree2.lT.get_chain_of_node(corres2[m._right]) + cyc2 = tree2.lT.get_chain_of_node(node_2) if len(cyc2) > 1: node_2, *_, l_node_2 = cyc2 matched_right.append(node_2) @@ -659,8 +686,32 @@ def plot_tree_distance_graphs( else: for m in btrc: if m._left != -1 and m._right != -1: + node_1 = tree1.lT.get_chain_of_node(corres1[m._left])[0] node_2 = tree2.lT.get_chain_of_node(corres2[m._right])[0] + if ( + node_1 not in tree1.lT.roots + and node_2 not in tree2.lT.roots + ): + if tree1.lT.successor[ + tree1.lT.predecessor[node_1][0] + ].index(node_1) != tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ].index( + node_2 + ): + tree2.lT._successor[ + tree2.lT.predecessor[node_2][0] + ] = list( + reversed( + tree2.lT.successor[ + tree2.lT.predecessor[node_2][0] + ] + ) + ) + to_reverse_back.append( + tree2.lT.predecessor[node_2][0] + ) if ( tree1.lT.get_chain_of_node(node_1)[0] == node_1 or tree2.lT.get_chain_of_node(node_2)[0] == node_2 @@ -733,6 +784,10 @@ def plot_tree_distance_graphs( lw=lw, ax=ax[1], ) + for node in to_reverse_back: + tree2.lT._successor[node] = list( + reversed(tree2.lT._successor[node]) + ) return ax[0].get_figure(), ax def labelled_mappings( From e20c7eebd3dc334ae0e4cc260127f20b067a52c2 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Thu, 11 Dec 2025 10:50:23 +0100 Subject: [PATCH 5/7] . --- src/lineagetree/measure/uted.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index dc2dd96..79db8da 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -7,6 +7,7 @@ from itertools import combinations from typing import TYPE_CHECKING, Literal +import matplotlib as mpl import numpy as np import matplotlib.colors as mcolors from edist import uted @@ -569,6 +570,14 @@ def plot_tree_distance_graphs( ) for node in to_reverse_back: lT._successor[node] = list(reversed(lT._successor[node])) + cax = fig.add_axes([0.25, 0.10, 0.50, 0.01]) + for x in ax: + for s in x.spines.values(): + s.set_visible(False) + sm = mpl.cm.ScalarMappable(cmap=cmap, norm=c_norm) + sm.set_array([]) + fig.colorbar(sm, cax=cax, orientation="horizontal") + return ax[0].get_figure(), ax From 3323c3315a3d61ea218678efa2f609e20f9fbc52 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Thu, 11 Dec 2025 14:57:01 +0100 Subject: [PATCH 6/7] fixed order --- src/lineagetree/measure/uted.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 79db8da..074aa19 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -454,9 +454,17 @@ def plot_tree_distance_graphs( node_1 = corres1[m._left] node_2 = corres2[m._right] if node_1 not in lT.roots and node_2 not in lT.roots: - if lT.successor[lT.predecessor[node_1][0]].index( + if lT.predecessor[node_2][ + 0 + ] not in to_reverse_back and lT.successor[ + lT.predecessor[node_1][0] + ].index( node_1 - ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + ) != lT.successor[ + lT.predecessor[node_2][0] + ].index( + node_2 + ): lT._successor[lT.predecessor[node_2][0]] = list( reversed(lT.successor[lT.predecessor[node_2][0]]) ) @@ -501,9 +509,17 @@ def plot_tree_distance_graphs( node_1 = corres1[m._left] node_2 = corres2[m._right] if node_1 not in lT.roots and node_2 not in lT.roots: - if lT.successor[lT.predecessor[node_1][0]].index( + if lT.predecessor[node_2][ + 0 + ] not in to_reverse_back and lT._successor[ + lT.predecessor[node_1][0] + ].index( node_1 - ) != lT.successor[lT.predecessor[node_2][0]].index(node_2): + ) != lT._successor[ + lT.predecessor[node_2][0] + ].index( + node_2 + ): lT._successor[lT.predecessor[node_2][0]] = list( reversed(lT.successor[lT.predecessor[node_2][0]]) ) @@ -570,6 +586,7 @@ def plot_tree_distance_graphs( ) for node in to_reverse_back: lT._successor[node] = list(reversed(lT._successor[node])) + cax = fig.add_axes([0.25, 0.10, 0.50, 0.01]) for x in ax: for s in x.spines.values(): From 17067ad60257b53c211662040826cc27ee8a47d1 Mon Sep 17 00:00:00 2001 From: BadPrograms Date: Thu, 11 Dec 2025 15:40:18 +0100 Subject: [PATCH 7/7] completely fix tree distance graphs --- src/lineagetree/lineage_tree_manager.py | 41 ++++++++++++++--------- src/lineagetree/measure/uted.py | 43 +++++++++++++------------ 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/lineagetree/lineage_tree_manager.py b/src/lineagetree/lineage_tree_manager.py index 6889cd7..9e8312e 100644 --- a/src/lineagetree/lineage_tree_manager.py +++ b/src/lineagetree/lineage_tree_manager.py @@ -628,9 +628,13 @@ def plot_tree_distance_graphs( node_1 not in tree1.lT.roots and node_2 not in tree2.lT.roots ): - if tree1.lT.successor[ + if tree2.lT.predecessor[ + node_2 + ] not in to_reverse_back and tree1.lT.successor[ tree1.lT.predecessor[node_1][0] - ].index(node_1) != tree2.lT.successor[ + ].index( + node_1 + ) != tree2.lT.successor[ tree2.lT.predecessor[node_2][0] ].index( node_2 @@ -693,9 +697,13 @@ def plot_tree_distance_graphs( node_1 not in tree1.lT.roots and node_2 not in tree2.lT.roots ): - if tree1.lT.successor[ + if tree2.lT.predecessor[ + node_2 + ] not in to_reverse_back and tree1.lT.successor[ tree1.lT.predecessor[node_1][0] - ].index(node_1) != tree2.lT.successor[ + ].index( + node_1 + ) != tree2.lT.successor[ tree2.lT.predecessor[node_2][0] ].index( node_2 @@ -712,10 +720,11 @@ def plot_tree_distance_graphs( to_reverse_back.append( tree2.lT.predecessor[node_2][0] ) - if ( - tree1.lT.get_chain_of_node(node_1)[0] == node_1 - or tree2.lT.get_chain_of_node(node_2)[0] == node_2 - and (node_1 not in colors1 or node_2 not in colors2) + + if not colors1.get( + tree1.lT.get_chain_of_node(node_1)[0] + ) or not colors2.get( + tree2.lT.get_chain_of_node(node_2)[0] ): matched_left.append(node_1) l_node_1 = tree1.lT.get_chain_of_node(node_1)[-1] @@ -738,6 +747,14 @@ def plot_tree_distance_graphs( tree2.get_norm(node_2), ) ) + for n1 in tree1.lT.get_chain_of_node(node_1): + colors1[n1] = colors1[ + tree1.lT.get_chain_of_node(node_1)[0] + ] + for n2 in tree2.lT.get_chain_of_node(node_2): + colors2[n2] = colors1[ + tree1.lT.get_chain_of_node(node_1)[0] + ] colors2[node_2] = colors1[node_1] colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = ( colors1[node_1] @@ -746,14 +763,6 @@ def plot_tree_distance_graphs( colors2[node_2] ) - if tree1.lT.get_chain_of_node(node_1)[-1] != node_1: - matched_left.append( - tree1.lT.get_chain_of_node(node_1)[-1] - ) - if tree2.lT.get_chain_of_node(node_2)[-1] != node_2: - matched_right.append( - tree2.lT.get_chain_of_node(node_2)[-1] - ) if ax is None: fig, ax = plt.subplots(nrows=1, ncols=2) cmap = colormaps[colormap] diff --git a/src/lineagetree/measure/uted.py b/src/lineagetree/measure/uted.py index 074aa19..9979129 100644 --- a/src/lineagetree/measure/uted.py +++ b/src/lineagetree/measure/uted.py @@ -506,8 +506,8 @@ def plot_tree_distance_graphs( else: for m in btrc: if m._left != -1 and m._right != -1: - node_1 = corres1[m._left] - node_2 = corres2[m._right] + node_1 = lT.get_chain_of_node(corres1[m._left])[0] + node_2 = lT.get_chain_of_node(corres2[m._right])[0] if node_1 not in lT.roots and node_2 not in lT.roots: if lT.predecessor[node_2][ 0 @@ -525,32 +525,33 @@ def plot_tree_distance_graphs( ) to_reverse_back.append(lT.predecessor[node_2][0]) - if ( - lT.get_chain_of_node(node_1)[0] == node_1 - or lT.get_chain_of_node(node_2)[0] == node_2 - and (node_1 not in colors or node_2 not in colors) - ): + if not colors.get( + lT.get_chain_of_node(node_1)[0] + ) or not colors.get(lT.get_chain_of_node(node_2)[0]): matched_left.append(node_1) l_node_1 = lT.get_chain_of_node(node_1)[-1] matched_left.append(l_node_1) matched_right.append(node_2) l_node_2 = lT.get_chain_of_node(node_2)[-1] matched_right.append(l_node_2) - colors[node_1] = __calculate_distance_of_sub_tree( - lT, - node_1, - node_2, - btrc, - corres1, - corres2, - delta_tmp, - lT.norm_dict[norm], - tree1.get_norm(node_1), - tree2.get_norm(node_2), + colors[lT.get_chain_of_node(node_1)[0]] = ( + __calculate_distance_of_sub_tree( + lT, + node_1, + node_2, + btrc, + corres1, + corres2, + delta_tmp, + lT.norm_dict[norm], + tree1.get_norm(node_1), + tree2.get_norm(node_2), + ) ) - colors[l_node_1] = colors[node_1] - colors[node_2] = colors[node_1] - colors[l_node_2] = colors[node_1] + for n1 in lT.get_chain_of_node(node_1): + colors[n1] = colors[lT.get_chain_of_node(node_1)[0]] + for n2 in lT.get_chain_of_node(node_2): + colors[n2] = colors[lT.get_chain_of_node(node_1)[0]] if ax is None: fig, ax = plt.subplots(nrows=1, ncols=2, sharey=True) cmap = colormaps[colormap]