diff --git a/src/LineageTree/lineageTree.py b/src/LineageTree/lineageTree.py index 9b71970..409cc6a 100644 --- a/src/LineageTree/lineageTree.py +++ b/src/LineageTree/lineageTree.py @@ -3349,6 +3349,25 @@ def plot_dtw_trajectory( return distance, fig + def get_subtree(self, node_list: set[int]) -> lineageTree: + new_successors = { + n: tuple(vi for vi in self.successor[n] if vi in node_list) + for n in node_list + } + return lineageTree( + successor=new_successors, + time=self._time, + pos=self.pos, + name=self.name, + root_leaf_value=[ + (), + ], + **{ + name: self.__dict__[name] + for name in self._custom_property_list + }, + ) + def __init__( self, *, @@ -3455,7 +3474,7 @@ def __init__( "Cycles were found in the tree, there should not be any." ) - if pos is None: + if pos is None or len(pos) == 0: self.pos = {} else: if self.nodes.difference(pos) != set(): @@ -3511,6 +3530,7 @@ def __init__( "Provided times are not strictly increasing. Setting times to default." ) # custom properties + self._custom_property_list = [] for name, d in kwargs.items(): if name in self.__dict__: warnings.warn( @@ -3518,5 +3538,6 @@ def __init__( ) continue setattr(self, name, d) + self._custom_property_list.append(name) if not hasattr(self, "_comparisons"): self._comparisons = {} diff --git a/src/LineageTree/lineageTreeManager.py b/src/LineageTree/lineageTreeManager.py index 0f1564d..df3014d 100644 --- a/src/LineageTree/lineageTreeManager.py +++ b/src/LineageTree/lineageTreeManager.py @@ -3,7 +3,7 @@ import os import pickle as pkl import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterable from functools import partial from typing import TYPE_CHECKING, Literal @@ -35,11 +35,21 @@ class lineageTreeManager: norm_dict = {"max": max, "sum": sum, None: lambda x: 1} - def __init__(self): + def __init__(self, lineagetree_list: Iterable[lineageTree] = ()): + """Creates a lineageTreeManager + :TODO: write the docstring + + Parameters + ---------- + lineagetree_list: Iterable of lineageTree + List of lineage trees to be in the lineageTreeManager + """ self.lineagetrees = {} self.lineageTree_counter = 0 self.registered = {} self._comparisons = {} + for lT in lineagetree_list: + self.add(lT) def __next__(self): self.lineageTree_counter += 1 @@ -55,9 +65,7 @@ def __len__(self): """ return len(self.lineagetrees) - def __iter__( - self, - ): + def __iter__(self): yield from self.lineagetrees.items() def __getitem__(self, key): @@ -657,28 +665,28 @@ def plot_tree_distance_graphs( matched_right.append(node_2) l_node_2 = tree2.lT.get_chain_of_node(node_2)[-1] matched_right.append(l_node_2) - colors1[ - node_1 - ] = self.__calculate_distance_of_sub_tree( - node_1, - tree1.lT, - node_2, - tree2.lT, - btrc, - corres1, - corres2, - delta_tmp, - self.norm_dict[norm], - tree1.get_norm(node_1), - tree2.get_norm(node_2), + colors1[node_1] = ( + self.__calculate_distance_of_sub_tree( + node_1, + tree1.lT, + node_2, + tree2.lT, + btrc, + corres1, + corres2, + delta_tmp, + self.norm_dict[norm], + tree1.get_norm(node_1), + tree2.get_norm(node_2), + ) ) colors2[node_2] = colors1[node_1] - colors1[ - tree1.lT.get_chain_of_node(node_1)[-1] - ] = colors1[node_1] - colors2[ - tree2.lT.get_chain_of_node(node_2)[-1] - ] = colors2[node_2] + colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = ( + colors1[node_1] + ) + colors2[tree2.lT.get_chain_of_node(node_2)[-1]] = ( + colors2[node_2] + ) if tree1.lT.get_chain_of_node(node_1)[-1] != node_1: matched_left.append(