diff --git a/src/LineageTree/lineageTree.py b/src/LineageTree/lineageTree.py index 409cc6a..e8629ca 100644 --- a/src/LineageTree/lineageTree.py +++ b/src/LineageTree/lineageTree.py @@ -2668,7 +2668,7 @@ def nodes_at_t( self, t: int, r: int | Iterable[int] | None = None, - ) -> list: + ) -> list[int]: """ Returns the list of nodes at time `t` that are spawn by the node(s) `r`. @@ -2681,8 +2681,8 @@ def nodes_at_t( Returns ------- - list - list of nodes at time `t` spawned by `r` + list of int + list of ids of the nodes at time `t` spawned by `r` """ if not r and r != 0: r = {root for root in self.roots if self.time[root] <= t} diff --git a/src/LineageTree/lineageTreeManager.py b/src/LineageTree/lineageTreeManager.py index 7847059..41c3677 100644 --- a/src/LineageTree/lineageTreeManager.py +++ b/src/LineageTree/lineageTreeManager.py @@ -3,7 +3,7 @@ import os import pickle as pkl import warnings -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Generator from functools import partial from typing import TYPE_CHECKING, Literal @@ -44,18 +44,17 @@ def __init__(self, lineagetree_list: Iterable[lineageTree] = ()): lineagetree_list: Iterable of lineageTree List of lineage trees to be in the lineageTreeManager """ - self.lineagetrees = {} - self.lineageTree_counter = 0 - self.registered = {} - self._comparisons = {} + self.lineagetrees: dict[str, lineageTree] = {} + self.lineageTree_counter: int = 0 + self._comparisons: dict = {} for lT in lineagetree_list: self.add(lT) - def __next__(self): + def __next__(self) -> int: self.lineageTree_counter += 1 return self.lineageTree_counter - 1 - def __len__(self): + def __len__(self) -> int: """Returns how many lineagetrees are in the manager. Returns @@ -65,10 +64,10 @@ def __len__(self): """ return len(self.lineagetrees) - def __iter__(self): + def __iter__(self) -> Generator[tuple[str, lineageTree]]: yield from self.lineagetrees.items() - def __getitem__(self, key): + def __getitem__(self, key: str) -> lineageTree: if key in self.lineagetrees: return self.lineagetrees[key] else: @@ -130,7 +129,7 @@ def add(self, other_tree: lineageTree, name: str = ""): "Please add a LineageTree object or add time resolution to the LineageTree added." ) - def __add__(self, other): + def __add__(self, other: lineageTree): self.add(other) def write(self, fname: str): @@ -154,7 +153,7 @@ def write(self, fname: str): pkl.dump(self, f) f.close() - def remove_embryo(self, key): + def remove_embryo(self, key: str): """Removes the embryo from the manager. Parameters @@ -429,13 +428,7 @@ def cross_lineage_edit_distance( ------- Alignment The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.` - -- - ΟΡ - -- - - Alignment - The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.` - tuple(tree,tree) + tuple(tree,tree), optional The two trees that have been mapped to each other. """