Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions notebooks/Easy_clustermaps.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@
" names[n2][1],\n",
" names[n2][0],\n",
" 100,\n",
" node_lengths=(1, 3, 7), # optional for fragmented\n",
" style=\"simple\", # Best style option for speed and accuracy is \"fragmented\"\n",
" )"
]
Expand Down Expand Up @@ -290,7 +289,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "nap",
"display_name": "sandbox",
"language": "python",
"name": "python3"
},
Expand All @@ -304,7 +303,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ push = false

[tool.ruff]
line-length = 79
select = [
lint.select = [
"E", "F", "W", #flake8
"UP", # pyupgrade
"I", # isort
Expand All @@ -109,7 +109,7 @@ select = [
"T201", # print statements
"ERA001", #commented out code
]
ignore = [
lint.ignore = [
"E501", # line too long. let black handle this
"SIM300", # yoda conditions
]
Expand Down
32 changes: 14 additions & 18 deletions src/LineageTree/lineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ def unordered_tree_edit_distances_at_time_t(
end_time: int | None = None,
style: Literal["simple", "full", "downsampled"] = "simple",
downsample: int = 2,
norm: Literal["max", "sum"] | None = "max",
norm: Literal["max", "sum", None] = "max",
recompute: bool = False,
) -> dict[int, float]:
"""
Expand Down Expand Up @@ -1770,7 +1770,7 @@ def __unordereded_backtrace(
n1: int,
n2: int,
end_time: int | None = None,
norm: Literal["max", "sum"] | None = "max",
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
Expand Down Expand Up @@ -1871,17 +1871,17 @@ def plot_tree_distance_graphs(
self,
n1: int,
n2: int,
end_time: int = None,
norm: Literal["max", "sum"] | None = "max",
end_time: int | None = None,
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
downsample: int = 2,
colormap: str = "cool",
default_color: str = "black",
size: float = 10,
lw: float = 0.1,
ax: list[plt.Axes, plt.Axes] = None,
lw: float = 0.3,
ax: list[plt.Axes] | None = None,
) -> tuple[plt.figure, plt.Axes]:
"""
Plots the distance graphs of 2 nodes compared.
Expand Down Expand Up @@ -1940,9 +1940,7 @@ def plot_tree_distance_graphs(
times1=times1,
times2=times2,
)
norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
if norm is None:
norm = "None"
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
if norm not in norm_dict:
raise Warning(
"Select a viable normalization method (max, sum, None)"
Expand Down Expand Up @@ -2052,8 +2050,8 @@ def labelled_mappings(
self,
n1: int,
n2: int,
end_time: int = None,
norm: Literal["max", "sum"] | None = "max",
end_time: int | None = None,
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
Expand Down Expand Up @@ -2108,9 +2106,7 @@ def labelled_mappings(
*_,
corres2,
) = tree2.edist
norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
if norm is None:
norm = "None"
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
if norm not in norm_dict:
raise Warning(
"Select a viable normalization method (max, sum, None)"
Expand Down Expand Up @@ -2170,8 +2166,8 @@ def unordered_tree_edit_distance(
self,
n1: int,
n2: int,
end_time: int = None,
norm: Literal["max", "sum"] | None = "max",
end_time: int | None = None,
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
Expand Down Expand Up @@ -2279,7 +2275,7 @@ def __plot_nodes(
def __plot_edges(
hier: dict,
lnks_tms: dict,
selected_edges: set,
selected_edges: Iterable,
color: str | dict | list,
lw: float,
ax: plt.Axes,
Expand Down Expand Up @@ -3129,7 +3125,7 @@ def plot_dtw_trajectory(
fast: bool = False,
w: int = 0,
centered_band: bool = True,
projection: Literal["3d", "xy", "xz", "yz", "pca"] | None = None,
projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
alig: bool = False,
) -> tuple[float, plt.Figure]:
"""
Expand Down
64 changes: 30 additions & 34 deletions src/LineageTree/lineageTreeManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __cross_lineage_edit_backtrace(
embryo_2: str,
end_time2: int,
style="simple",
norm: Literal["max", "sum"] | None = "max",
norm: Literal["max", "sum", None] = "max",
downsample: int = 2,
registration=None, # will be added as a later feature
):
Expand Down Expand Up @@ -301,7 +301,7 @@ def cross_lineage_edit_distance(
n2: int,
embryo_2: str,
end_time2: int,
norm: tuple["max", "sum", "None"] | None = "max",
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
Expand Down Expand Up @@ -389,11 +389,11 @@ def cross_lineage_edit_distance(
corres2,
) = tree2.edist
if len(nodes1) == len(nodes2) == 0:
self._comparisons[hash(frozenset(parameters.values()))] = {
self._comparisons[hash(frozenset(parameters))] = {
"alignment": (),
"trees": (),
}
return self._comparisons[hash(frozenset(parameters.values()))]
return self._comparisons[hash(frozenset(parameters))]
delta_tmp = partial(
tree1.delta,
corres1=corres1,
Expand All @@ -420,7 +420,7 @@ def plot_tree_distance_graphs(
n2: int,
embryo_2,
end_time2,
norm: Literal["max", "sum"] | None = "max",
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
Expand All @@ -429,7 +429,7 @@ def plot_tree_distance_graphs(
default_color: str = "black",
size: float = 10,
lw: float = 0.3,
ax: list[plt.Axes, plt.Axes] = None,
ax: list[plt.Axes] | None = None,
) -> tuple[plt.figure, plt.Axes]:
"""
Plots the distance graphs of 2 nodes compared.
Expand Down Expand Up @@ -501,9 +501,7 @@ def plot_tree_distance_graphs(
times1=times1,
times2=times2,
)
norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
if norm is None:
norm = "None"
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
if norm not in norm_dict:
raise Warning(
"Select a viable normalization method (max, sum, None)"
Expand Down Expand Up @@ -567,28 +565,28 @@ def plot_tree_distance_graphs(
matched_right.append(node_2)
l_node_2 = tree2.lT.get_chain_of_node(node_2)[-1]
matched_right.append(l_node_2)
colors1[
node_1
] = self.__calculate_distance_of_sub_tree(
node_1,
tree1.lT,
node_2,
tree2.lT,
btrc,
corres1,
corres2,
delta_tmp,
norm_dict[norm],
tree1.get_norm(node_1),
tree2.get_norm(node_2),
colors1[node_1] = (
self.__calculate_distance_of_sub_tree(
node_1,
tree1.lT,
node_2,
tree2.lT,
btrc,
corres1,
corres2,
delta_tmp,
norm_dict[norm],
tree1.get_norm(node_1),
tree2.get_norm(node_2),
)
)
colors2[node_2] = colors1[node_1]
colors1[
tree1.lT.get_chain_of_node(node_1)[-1]
] = colors1[node_1]
colors2[
tree2.lT.get_chain_of_node(node_2)[-1]
] = colors2[node_2]
colors1[tree1.lT.get_chain_of_node(node_1)[-1]] = (
colors1[node_1]
)
colors2[tree2.lT.get_chain_of_node(node_2)[-1]] = (
colors2[node_2]
)

if tree1.lT.get_chain_of_node(node_1)[-1] != node_1:
matched_left.append(
Expand Down Expand Up @@ -634,15 +632,15 @@ def labelled_mappings(
n2: int,
embryo_2,
end_time2,
norm: Literal["max", "sum"] | None = "max",
norm: Literal["max", "sum", None] = "max",
style: Literal[
"simple", "normalized_simple", "full", "downsampled", "mini"
] = "simple",
downsample: int = 2,
colormap: str = "cool",
default_color: str = "black",
size: float = 10,
ax: list[plt.Axes, plt.Axes] = None,
ax: list[plt.Axes] | None = None,
) -> dict[str, list]:
"""
Plots the distance graphs of 2 nodes compared.
Expand Down Expand Up @@ -708,9 +706,7 @@ def labelled_mappings(
*_,
corres2,
) = tree2.edist
norm_dict = {"max": max, "sum": sum, "None": lambda x: 1}
if norm is None:
norm = "None"
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
if norm not in norm_dict:
raise Warning(
"Select a viable normalization method (max, sum, None)"
Expand Down
4 changes: 2 additions & 2 deletions src/LineageTree/test/test_lineageTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest

from LineageTree import (
lineageTree,
lineageTreeManager,
Expand Down Expand Up @@ -514,8 +515,7 @@ def test_compute_spatial_edges():
assert lT1.compute_spatial_edges()[129294] == {139162, 148358}


def test_main_axes():
...
def test_main_axes(): ...


def test_get_ancestor_at_t():
Expand Down
18 changes: 13 additions & 5 deletions src/LineageTree/tree_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
lT: lineageTree,
root: int,
downsample: int | None = None,
end_time: int = None,
end_time: int | None = None,
time_scale: int = 1,
):
self.lT: lineageTree = lT
Expand All @@ -37,7 +37,7 @@ def __init__(
self.tree: tuple = self.get_tree()
self.edist = self._edist_format(self.tree[0])

def get_next_id(self):
def get_next_id(self) -> int:
self.internal_ids += 1
return self.internal_ids

Expand Down Expand Up @@ -66,7 +66,15 @@ def get_tree(self) -> tuple[dict, dict]:
"""

@abstractmethod
def delta(self, x, y, corres1, corres2, times1, times2):
def delta(
self,
x: int,
y: int,
corres1: dict[int, int],
corres2: dict[int, int],
times1: dict[int, float],
times2: dict[int, float],
) -> int | float:
"""The distance of two nodes inside a tree. Behaves like a staticmethod.
The corres1/2 and time1/2 should always be provided and will be handled accordingly by the specific
delta of each tree style.
Expand All @@ -78,9 +86,9 @@ def delta(self, x, y, corres1, corres2, times1, times2):
y : int
The second node to compare, takes the names provided by the edist
corres1 : dict
Correspondance between node1 and its name in the real tree.
Dictionary mapping node1 ids to the corresponding id in the original tree.
corres2 : dict
Correspondance between node2 and its name in the real tree.
Dictionary mapping node2 ids to the corresponding id in the original tree.
times1 : dict
The dictionary of the branch lengths of the tree that n1 is spawned from.
times2 : dict
Expand Down
Loading