Skip to content

Commit 518aba5

Browse files
committed
update_type_hints_and_drop_all_branch_mentions
1 parent ad0f25d commit 518aba5

File tree

5 files changed

+140
-150
lines changed

5 files changed

+140
-150
lines changed

src/LineageTree/lineageTree.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from .tree_styles import abstract_trees, tree_style
3333
from .utils import (
3434
convert_style_to_number,
35-
create_links_and_cycles,
35+
create_links_and_chains,
3636
hierarchical_pos,
3737
)
3838

@@ -67,6 +67,9 @@ def __get__(self, instance, owner):
6767

6868

6969
class lineageTree:
70+
71+
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
72+
7073
def modifier(wrapped_func):
7174
@wraps(wrapped_func)
7275
def raising_flag(self, *args, **kwargs):
@@ -1469,7 +1472,7 @@ def compute_spatial_density(
14691472
14701473
Returns
14711474
-------
1472-
dict of int to float
1475+
dict mapping int to float
14731476
dictionary that maps a node id to its spatial density
14741477
"""
14751478
if not hasattr(self, "spatial_density"):
@@ -1501,7 +1504,7 @@ def compute_k_nearest_neighbours(self, k: int = 10) -> dict[int, set[int]]:
15011504
15021505
Returns
15031506
-------
1504-
dict of int to set of int
1507+
dict mapping int to set of int
15051508
dictionary that maps
15061509
a node id to its `k` nearest neighbors
15071510
"""
@@ -1538,7 +1541,7 @@ def compute_spatial_edges(self, th: int = 50) -> dict[int, set[int]]:
15381541
15391542
Returns
15401543
-------
1541-
dict of int to set of int
1544+
dict mapping int to set of int
15421545
dictionary that maps a node id to its neighbors at a distance `th`
15431546
"""
15441547
self.th_edges = {}
@@ -1834,8 +1837,7 @@ def __unordereded_backtrace(
18341837
18351838
Returns
18361839
-------
1837-
dict
1838-
Dictionary containing:
1840+
dict mapping str to Alignment or tuple of [abstract_trees, abstract_trees]
18391841
- 'alignment'
18401842
The alignment between the nodes by the subtrees spawned by the nodes n1,n2 and the normalization function.`
18411843
- 'trees'
@@ -1988,8 +1990,8 @@ def plot_tree_distance_graphs(
19881990
times1=times1,
19891991
times2=times2,
19901992
)
1991-
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
1992-
if norm not in norm_dict:
1993+
1994+
if norm not in self.norm_dict:
19931995
raise Warning(
19941996
"Select a viable normalization method (max, sum, None)"
19951997
)
@@ -2025,7 +2027,7 @@ def plot_tree_distance_graphs(
20252027
corres1,
20262028
corres2,
20272029
delta_tmp,
2028-
norm_dict[norm],
2030+
self.norm_dict[norm],
20292031
tree1.get_norm(node_1),
20302032
tree2.get_norm(node_2),
20312033
)
@@ -2056,7 +2058,7 @@ def plot_tree_distance_graphs(
20562058
corres1,
20572059
corres2,
20582060
delta_tmp,
2059-
norm_dict[norm],
2061+
self.norm_dict[norm],
20602062
tree1.get_norm(node_1),
20612063
tree2.get_norm(node_2),
20622064
)
@@ -2104,7 +2106,7 @@ def labelled_mappings(
21042106
"simple", "normalized_simple", "full", "downsampled", "mini"
21052107
] = "simple",
21062108
downsample: int = 2,
2107-
) -> dict[str, list]:
2109+
) -> dict[str, list[str]]:
21082110
"""
21092111
Returns the labels or IDs of all the nodes in the subtrees compared.
21102112
@@ -2128,8 +2130,9 @@ def labelled_mappings(
21282130
21292131
Returns
21302132
-------
2131-
Alignment
2132-
The alignment between the nodes of of the subtrees spawned by the nodes n1,n2
2133+
dict mapping str to list[str]
2134+
- 'matched' The labels of the matched nodes of the alignment.
2135+
- 'unmatched' The labels of the unmatched nodes of the alginment.
21332136
"""
21342137
parameters = (
21352138
end_time,
@@ -2154,8 +2157,8 @@ def labelled_mappings(
21542157
*_,
21552158
corres2,
21562159
) = tree2.edist
2157-
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
2158-
if norm not in norm_dict:
2160+
2161+
if norm not in self.norm_dict:
21592162
raise Warning(
21602163
"Select a viable normalization method (max, sum, None)"
21612164
)
@@ -2249,7 +2252,7 @@ def unordered_tree_edit_distance(
22492252
Returns
22502253
-------
22512254
float
2252-
The normed unordered tree edit distance between `n1` and `n2`
2255+
The normalized unordered tree edit distance between `n1` and `n2`
22532256
"""
22542257
parameters = (
22552258
end_time,
@@ -2284,16 +2287,16 @@ def unordered_tree_edit_distance(
22842287
times1=times1,
22852288
times2=times2,
22862289
)
2287-
norm_dict = {"max": max, "sum": sum, None: lambda x: 1}
2288-
if norm not in norm_dict:
2290+
2291+
if norm not in self.norm_dict:
22892292
raise ValueError(
22902293
"Select a viable normalization method (max, sum, None)"
22912294
)
22922295
cost = btrc.cost(nodes1, nodes2, delta_tmp)
22932296
norm_values = (tree1.get_norm(n1), tree2.get_norm(n2))
22942297
if return_norms:
22952298
return cost, norm_values
2296-
return cost / norm_dict[norm](norm_values)
2299+
return cost / self.norm_dict[norm](norm_values)
22972300

22982301
@staticmethod
22992302
def __plot_nodes(
@@ -2358,7 +2361,7 @@ def __plot_edges(
23582361
def draw_tree_graph(
23592362
self,
23602363
hier: dict[int, tuple[int, int]],
2361-
lnks_tms: dict,
2364+
lnks_tms: dict[str, dict[int, list | int]],
23622365
selected_nodes: list | set | None = None,
23632366
selected_edges: list | set | None = None,
23642367
color_of_nodes: str | dict = "magenta",
@@ -2375,10 +2378,9 @@ def draw_tree_graph(
23752378
----------
23762379
hier : dict mapping int to tuple of int
23772380
Dictionary that contains the positions of all nodes.
2378-
lnks_tms : dict, dict
2379-
2 dictionaries: 1 contains all links from start of life cycle to end of life cycle and
2380-
the succesors of each node.
2381-
1 contains the length of each life cycle.
2381+
lnks_tms : dict mapping string to dictionaries mapping int to list or int
2382+
- 'links' : conatains the hierarchy of the nodes (only start and end of each chain)
2383+
- 'times' : contains the distance between the start and the end of each chain.
23822384
selected_nodes : list or set, optional
23832385
Which nodes are to be selected (Painted with a different color)
23842386
selected_edges : list or set, optional
@@ -2389,7 +2391,7 @@ def draw_tree_graph(
23892391
Color of selected edges
23902392
size : int, default=10
23912393
Size of the nodes
2392-
lw : float, defaults to 0.3
2394+
lw : float, default=0.3
23932395
The width of the edges of the tree graph, by default 0.1
23942396
ax : plt.Axes, optional
23952397
Plot the graph on existing ax. Defaults to None.
@@ -2450,11 +2452,11 @@ def _create_dict_of_plots(
24502452
end_time: int | None = None,
24512453
) -> dict[int, dict]:
24522454
"""Generates a dictionary of graphs where the keys are the index of the graph and
2453-
the values are the graphs themselves which are produced by `create_links_and_cycles`
2455+
the values are the graphs themselves which are produced by `create_links_and_chains`
24542456
24552457
Parameters
24562458
----------
2457-
node : int|Iterable[int], optional
2459+
node : int or Iterable of int, optional
24582460
The id of the node/nodes to produce the simple graphs
24592461
start_time : int, optional
24602462
Important only if there are no nodes it will produce the graph of every
@@ -2465,7 +2467,7 @@ def _create_dict_of_plots(
24652467
24662468
Returns
24672469
-------
2468-
dict of int to dict
2470+
dict mapping int to dict
24692471
The keys are just index values 0-n and the values are the graphs produced.
24702472
"""
24712473
if start_time is None:
@@ -2481,7 +2483,7 @@ def _create_dict_of_plots(
24812483
else:
24822484
mothers = [node]
24832485
return {
2484-
i: create_links_and_cycles(self, mother, end_time=end_time)
2486+
i: create_links_and_chains(self, mother, end_time=end_time)
24852487
for i, mother in enumerate(mothers)
24862488
}
24872489

@@ -2521,7 +2523,7 @@ def plot_all_lineages(
25212523
vert_gap : int, default=1
25222524
space between the nodes.
25232525
**kwargs:
2524-
args accepted by matplotlib
2526+
kwargs accepted by matplotlib
25252527
25262528
Returns
25272529
-------
@@ -2656,10 +2658,12 @@ def plot_subtree(
26562658
26572659
Returns
26582660
-------
2659-
plt.Figure
2660-
The figure
2661-
plt.Axes
2662-
The axes
2661+
tuple (plt.Figure, plt.Axes)
2662+
tuple with:
2663+
plt.Figure
2664+
The matplotlib figure
2665+
plt.Axes
2666+
The matplotlib axes
26632667
Raises
26642668
------
26652669
Warning
@@ -3213,7 +3217,7 @@ def plot_dtw_trajectory(
32133217
-------
32143218
float
32153219
DTW distance
3216-
figue
3220+
figure
32173221
Trajectories Plot
32183222
"""
32193223
(

0 commit comments

Comments
 (0)