From b75cb2e0fa00c13ec5ee77bde88687961332b90b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Guignard?= Date: Tue, 20 May 2025 12:55:27 +0200 Subject: [PATCH] helping a bit hopefully --- notebooks/Easy_clustermaps.ipynb | 5 +- pyproject.toml | 4 +- src/LineageTree/lineageTree.py | 30 +++++------ src/LineageTree/lineageTreeManager.py | 64 +++++++++++------------- src/LineageTree/test/test_lineageTree.py | 4 +- src/LineageTree/tree_styles.py | 18 +++++-- 6 files changed, 62 insertions(+), 63 deletions(-) diff --git a/notebooks/Easy_clustermaps.ipynb b/notebooks/Easy_clustermaps.ipynb index 64dfead..ec2bf60 100644 --- a/notebooks/Easy_clustermaps.ipynb +++ b/notebooks/Easy_clustermaps.ipynb @@ -203,7 +203,6 @@ " names[n2][1],\n", " names[n2][0],\n", " 100,\n", - " node_lengths=(1, 3, 7), # optional for fragmented\n", " style=\"simple\", # Best style option for speed and accuracy is \"fragmented\"\n", " )" ] @@ -290,7 +289,7 @@ ], "metadata": { "kernelspec": { - "display_name": "nap", + "display_name": "sandbox", "language": "python", "name": "python3" }, @@ -304,7 +303,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 1f8f0d8..7f7afb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,7 @@ push = false [tool.ruff] line-length = 79 -select = [ +lint.select = [ "E", "F", "W", #flake8 "UP", # pyupgrade "I", # isort @@ -109,7 +109,7 @@ select = [ "T201", # print statements "ERA001", #commented out code ] -ignore = [ +lint.ignore = [ "E501", # line too long. let black handle this "SIM300", # yoda conditions ] diff --git a/src/LineageTree/lineageTree.py b/src/LineageTree/lineageTree.py index e3be6ae..7947560 100644 --- a/src/LineageTree/lineageTree.py +++ b/src/LineageTree/lineageTree.py @@ -1698,7 +1698,7 @@ def unordered_tree_edit_distances_at_time_t( end_time: int | None = None, style: Literal["simple", "full", "downsampled"] = "simple", downsample: int = 2, - norm: Literal["max", "sum"] | None = "max", + norm: Literal["max", "sum", None] = "max", recompute: bool = False, ) -> dict[int, float]: """ @@ -1780,7 +1780,7 @@ def __unordereded_backtrace( n1: int, n2: int, end_time: int | None = None, - norm: Literal["max", "sum"] | None = "max", + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -1879,8 +1879,8 @@ def plot_tree_distance_graphs( self, n1: int, n2: int, - end_time: int = None, - norm: Literal["max", "sum"] | None = "max", + end_time: int | None = None, + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -1889,7 +1889,7 @@ def plot_tree_distance_graphs( default_color: str = "black", size: float = 10, lw: float = 0.3, - ax: list[plt.Axes, plt.Axes] = None, + ax: list[plt.Axes] | None = None, ) -> tuple[plt.figure, plt.Axes]: """ Plots the distance graphs of 2 nodes compared. @@ -1948,9 +1948,7 @@ def plot_tree_distance_graphs( times1=times1, times2=times2, ) - norm_dict = {"max": max, "sum": sum, "None": lambda x: 1} - if norm is None: - norm = "None" + norm_dict = {"max": max, "sum": sum, None: lambda x: 1} if norm not in norm_dict: raise Warning( "Select a viable normalization method (max, sum, None)" @@ -2060,8 +2058,8 @@ def labelled_mappings( self, n1: int, n2: int, - end_time: int = None, - norm: Literal["max", "sum"] | None = "max", + end_time: int | None = None, + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -2116,9 +2114,7 @@ def labelled_mappings( *_, corres2, ) = tree2.edist - norm_dict = {"max": max, "sum": sum, "None": lambda x: 1} - if norm is None: - norm = "None" + norm_dict = {"max": max, "sum": sum, None: lambda x: 1} if norm not in norm_dict: raise Warning( "Select a viable normalization method (max, sum, None)" @@ -2178,8 +2174,8 @@ def unordered_tree_edit_distance( self, n1: int, n2: int, - end_time: int = None, - norm: Literal["max", "sum"] | None = "max", + end_time: int | None = None, + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -2287,7 +2283,7 @@ def __plot_nodes( def __plot_edges( hier: dict, lnks_tms: dict, - selected_edges: set, + selected_edges: Iterable, color: str | dict | list, lw: float, ax: plt.Axes, @@ -3137,7 +3133,7 @@ def plot_dtw_trajectory( fast: bool = False, w: int = 0, centered_band: bool = True, - projection: Literal["3d", "xy", "xz", "yz", "pca"] | None = None, + projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None, alig: bool = False, ) -> tuple[float, plt.Figure]: """ diff --git a/src/LineageTree/lineageTreeManager.py b/src/LineageTree/lineageTreeManager.py index 02a7ffa..d7e01e5 100644 --- a/src/LineageTree/lineageTreeManager.py +++ b/src/LineageTree/lineageTreeManager.py @@ -160,7 +160,7 @@ def __cross_lineage_edit_backtrace( embryo_2: str, end_time2: int, style="simple", - norm: Literal["max", "sum"] | None = "max", + norm: Literal["max", "sum", None] = "max", downsample: int = 2, registration=None, # will be added as a later feature ): @@ -301,7 +301,7 @@ def cross_lineage_edit_distance( n2: int, embryo_2: str, end_time2: int, - norm: tuple["max", "sum", "None"] | None = "max", + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -389,11 +389,11 @@ def cross_lineage_edit_distance( corres2, ) = tree2.edist if len(nodes1) == len(nodes2) == 0: - self._comparisons[hash(frozenset(parameters.values()))] = { + self._comparisons[hash(frozenset(parameters))] = { "alignment": (), "trees": (), } - return self._comparisons[hash(frozenset(parameters.values()))] + return self._comparisons[hash(frozenset(parameters))] delta_tmp = partial( tree1.delta, corres1=corres1, @@ -420,7 +420,7 @@ def plot_tree_distance_graphs( n2: int, embryo_2, end_time2, - norm: Literal["max", "sum"] | None = "max", + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -429,7 +429,7 @@ def plot_tree_distance_graphs( default_color: str = "black", size: float = 10, lw: float = 0.3, - ax: list[plt.Axes, plt.Axes] = None, + ax: list[plt.Axes] | None = None, ) -> tuple[plt.figure, plt.Axes]: """ Plots the distance graphs of 2 nodes compared. @@ -501,9 +501,7 @@ def plot_tree_distance_graphs( times1=times1, times2=times2, ) - norm_dict = {"max": max, "sum": sum, "None": lambda x: 1} - if norm is None: - norm = "None" + norm_dict = {"max": max, "sum": sum, None: lambda x: 1} if norm not in norm_dict: raise Warning( "Select a viable normalization method (max, sum, None)" @@ -567,28 +565,28 @@ def plot_tree_distance_graphs( matched_right.append(node_2) l_node_2 = tree2.lT.get_chain_of_node(node_2)[-1] matched_right.append(l_node_2) - colors1[ - node_1 - ] = self.__calculate_distance_of_sub_tree( - node_1, - tree1.lT, - node_2, - tree2.lT, - btrc, - corres1, - corres2, - delta_tmp, - norm_dict[norm], - tree1.get_norm(node_1), - tree2.get_norm(node_2), + colors1[node_1] = ( + self.__calculate_distance_of_sub_tree( + node_1, + tree1.lT, + node_2, + tree2.lT, + btrc, + corres1, + corres2, + delta_tmp, + norm_dict[norm], + tree1.get_norm(node_1), + tree2.get_norm(node_2), + ) ) colors2[node_2] = colors1[node_1] - colors1[ - tree1.lT.get_chain_of_node(node_1)[-1] - ] = colors1[node_1] - colors2[ - tree2.lT.get_chain_of_node(node_2)[-1] - ] = colors2[node_2] + colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = ( + colors1[node_1] + ) + colors2[tree2.lT.get_chain_of_node(node_2)[-1]] = ( + colors2[node_2] + ) if tree1.lT.get_chain_of_node(node_1)[-1] != node_1: matched_left.append( @@ -634,7 +632,7 @@ def labelled_mappings( n2: int, embryo_2, end_time2, - norm: Literal["max", "sum"] | None = "max", + norm: Literal["max", "sum", None] = "max", style: Literal[ "simple", "normalized_simple", "full", "downsampled", "mini" ] = "simple", @@ -642,7 +640,7 @@ def labelled_mappings( colormap: str = "cool", default_color: str = "black", size: float = 10, - ax: list[plt.Axes, plt.Axes] = None, + ax: list[plt.Axes] | None = None, ) -> dict[str, list]: """ Plots the distance graphs of 2 nodes compared. @@ -708,9 +706,7 @@ def labelled_mappings( *_, corres2, ) = tree2.edist - norm_dict = {"max": max, "sum": sum, "None": lambda x: 1} - if norm is None: - norm = "None" + norm_dict = {"max": max, "sum": sum, None: lambda x: 1} if norm not in norm_dict: raise Warning( "Select a viable normalization method (max, sum, None)" diff --git a/src/LineageTree/test/test_lineageTree.py b/src/LineageTree/test/test_lineageTree.py index 8ff3e78..bc16f9e 100644 --- a/src/LineageTree/test/test_lineageTree.py +++ b/src/LineageTree/test/test_lineageTree.py @@ -2,6 +2,7 @@ import numpy as np import pytest + from LineageTree import ( lineageTree, lineageTreeManager, @@ -514,8 +515,7 @@ def test_compute_spatial_edges(): assert lT1.compute_spatial_edges()[129294] == {139162, 148358} -def test_main_axes(): - ... +def test_main_axes(): ... def test_get_ancestor_at_t(): diff --git a/src/LineageTree/tree_styles.py b/src/LineageTree/tree_styles.py index fd56b49..6738cb9 100644 --- a/src/LineageTree/tree_styles.py +++ b/src/LineageTree/tree_styles.py @@ -23,7 +23,7 @@ def __init__( lT: lineageTree, root: int, downsample: int | None = None, - end_time: int = None, + end_time: int | None = None, time_scale: int = 1, ): self.lT: lineageTree = lT @@ -37,7 +37,7 @@ def __init__( self.tree: tuple = self.get_tree() self.edist = self._edist_format(self.tree[0]) - def get_next_id(self): + def get_next_id(self) -> int: self.internal_ids += 1 return self.internal_ids @@ -66,7 +66,15 @@ def get_tree(self) -> tuple[dict, dict]: """ @abstractmethod - def delta(self, x, y, corres1, corres2, times1, times2): + def delta( + self, + x: int, + y: int, + corres1: dict[int, int], + corres2: dict[int, int], + times1: dict[int, float], + times2: dict[int, float], + ) -> int | float: """The distance of two nodes inside a tree. Behaves like a staticmethod. The corres1/2 and time1/2 should always be provided and will be handled accordingly by the specific delta of each tree style. @@ -78,9 +86,9 @@ def delta(self, x, y, corres1, corres2, times1, times2): y : int The second node to compare, takes the names provided by the edist corres1 : dict - Correspondance between node1 and its name in the real tree. + Dictionary mapping node1 ids to the corresponding id in the original tree. corres2 : dict - Correspondance between node2 and its name in the real tree. + Dictionary mapping node2 ids to the corresponding id in the original tree. times1 : dict The dictionary of the branch lengths of the tree that n1 is spawned from. times2 : dict