diff --git a/src/lineagetree/plot.py b/src/lineagetree/plot.py index f644b6f..3693336 100644 --- a/src/lineagetree/plot.py +++ b/src/lineagetree/plot.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import warnings from collections.abc import Iterable from typing import Literal, TYPE_CHECKING @@ -20,20 +21,24 @@ def __plot_nodes( color: str | dict | list, size: int | float, ax: plt.Axes, + leaves: set, default_color: str = "black", **kwargs, ) -> None: """ Private method that plots the nodes of the tree. """ - + hier_no_leaves = copy.copy(hier) + for leaf in leaves: + hier_no_leaves.pop(leaf, None) if isinstance(color, dict): - color = [color.get(k, default_color) for k in hier] + color = [color.get(k, default_color) for k in hier_no_leaves] elif isinstance(color, str | list): color = [ - color if node in selected_nodes else default_color for node in hier + color if node in selected_nodes else default_color + for node in hier_no_leaves ] - hier_pos = np.array(list(hier.values())) + hier_pos = np.array(list(hier_no_leaves.values())) ax.scatter(*hier_pos.T, s=size, zorder=10, color=color, **kwargs) @@ -142,6 +147,7 @@ def draw_tree_graph( size=size, ax=ax, default_color=default_color, + leaves={k for k, v in lnks_tms["links"].items() if not v}, **kwargs, ) if not color_of_edges: