Skip to content

Commit 499b26f

Browse files
authored
Merge pull request #15 from funkelab/add_motile_plots
Add plotly plotting from motile
2 parents f3c353b + d11b564 commit 499b26f

File tree

6 files changed

+439
-0
lines changed

6 files changed

+439
-0
lines changed

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ ignore_missing_imports = True
1919

2020
[mypy-scipy.*]
2121
ignore_missing_imports = True
22+
23+
[mypy-plotly.*]
24+
ignore_missing_imports = True

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dev = [
3535
'pre-commit',
3636
'types-tqdm',
3737
'pytest-unordered',
38+
'plotly',
3839
'ruff',
3940
]
4041

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .napari_utils import to_napari_tracks_layer
2+
from .plot_motile_graphs import draw_solution, draw_track_graph
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Callable, Mapping, overload
4+
5+
import numpy as np
6+
7+
try:
8+
import plotly.graph_objects as go
9+
except ImportError as e: # pragma: no cover
10+
raise ImportError(
11+
"This functionality requires the plotly package. Please install plotly."
12+
) from e
13+
14+
from motile.variables import EdgeSelected, NodeSelected
15+
16+
if TYPE_CHECKING:
17+
from motile import Solver, TrackGraph
18+
from motile._types import EdgeId, NodeId
19+
20+
Color = tuple[int, int, int]
21+
ReturnsFloat = Callable[[Any], float]
22+
ReturnsStr = Callable[[Any], str]
23+
24+
PURPLE = (127, 30, 121)
25+
26+
27+
def _attr_hover_text(attrs: Mapping) -> str:
28+
return "<br>".join([f"{name}: {value}" for name, value in attrs.items()])
29+
30+
31+
def draw_track_graph(
32+
graph: TrackGraph,
33+
position_attribute: str | None = None,
34+
position_func: ReturnsFloat | None = None,
35+
alpha_attribute: str | None = None,
36+
alpha_func: ReturnsFloat | tuple[ReturnsFloat, ReturnsFloat] | None = None,
37+
label_attribute: str | None = None,
38+
label_func: ReturnsStr | tuple[ReturnsStr, ReturnsStr] | None = None,
39+
node_size: float = 30,
40+
node_color: Color = PURPLE,
41+
edge_color: Color = PURPLE,
42+
width: int = 660,
43+
height: int = 400,
44+
) -> go.Figure:
45+
"""Create a plotly figure showing the given graph.
46+
47+
Time is shown on the x-axis and node positions on the y-axis.
48+
49+
Args:
50+
graph:
51+
The :class:`~motile.TrackGraph` to plot.
52+
53+
position_attribute (str):
54+
The name of the node attribute to use to place nodes on the y-axis.
55+
56+
position_func (callable):
57+
A function returning the position of a given node on the y-axis.
58+
59+
alpha_attribute (str):
60+
The name of a node or edge attribute to use for the transparency.
61+
62+
alpha_func (callable):
63+
A function returning the alpha value to use for each node or edge.
64+
Can be a tuple for node and edge functions, respectively.
65+
66+
label_attribute (str):
67+
The name of a node or edge attribute to use for a text label.
68+
69+
label_func (callable):
70+
A function returning the label to use for each node or edge. Can be
71+
a tuple for node and edge functions, respectively.
72+
73+
node_size (float):
74+
The size of nodes.
75+
76+
node_color (tuple[int, ...]):
77+
The RGB color to use for nodes.
78+
79+
edge_color (tuple[int, ...]):
80+
The RGB color to use for edges.
81+
82+
width (int):
83+
The width of the plot, in pixels. Default: 660.
84+
85+
height (int):
86+
The height of the plot, in pixels. Default: 400.
87+
88+
Returns:
89+
:class:`plotly.graph_objects.Figure` showing the graph.
90+
"""
91+
if position_attribute is not None and position_func is not None:
92+
raise RuntimeError(
93+
"Only one of position_attribute and position_func can be given"
94+
)
95+
if alpha_attribute is not None and alpha_func is not None:
96+
raise RuntimeError("Only one of alpha_attribute and alpha_func can be given")
97+
if label_attribute is not None and label_func is not None:
98+
raise RuntimeError("Only one of label_attribute and label_func can be given")
99+
100+
if position_attribute is None:
101+
position_attribute = "x"
102+
103+
if position_func is None:
104+
105+
def position_func(node: NodeId) -> float:
106+
return float(graph.nodes[node][position_attribute])
107+
108+
alpha_node_func: ReturnsFloat
109+
alpha_edge_func: ReturnsFloat
110+
label_node_func: ReturnsStr
111+
label_edge_func: ReturnsStr
112+
113+
if alpha_attribute is not None:
114+
115+
def alpha_node_func(node):
116+
return graph.nodes[node].get(alpha_attribute, 1.0)
117+
118+
def alpha_edge_func(edge):
119+
return graph.edges[edge].get(alpha_attribute, 1.0)
120+
121+
elif alpha_func is None:
122+
123+
def alpha_node_func(_):
124+
return 1.0
125+
126+
def alpha_edge_func(_):
127+
return 1.0
128+
129+
elif isinstance(alpha_func, tuple):
130+
alpha_node_func, alpha_edge_func = alpha_func
131+
else:
132+
alpha_node_func = alpha_func
133+
alpha_edge_func = alpha_func
134+
135+
if label_attribute is not None:
136+
137+
def label_node_func(node):
138+
return graph.nodes[node].get(label_attribute, "")
139+
140+
def label_edge_func(edge):
141+
return graph.edges[edge].get(label_attribute, "")
142+
143+
elif label_func is None:
144+
145+
def label_node_func(node):
146+
return str(node)
147+
148+
def label_edge_func(edge):
149+
return str(edge)
150+
151+
elif isinstance(label_func, tuple):
152+
label_node_func, label_edge_func = label_func
153+
else:
154+
label_node_func = label_func
155+
label_edge_func = label_func
156+
157+
frame_attribute = graph.frame_attribute
158+
# (get_frames() will return a tuple including None if the graph has no nodes)
159+
frames = list(range(*graph.get_frames())) # type: ignore
160+
161+
node_positions = np.asarray(
162+
[
163+
(attrs[frame_attribute], position_func(node))
164+
for node, attrs in sorted(graph.nodes.items())
165+
]
166+
)
167+
node_alphas: list[float] = [alpha_node_func(node) for node in graph.nodes]
168+
edge_alphas: list[float] = [alpha_edge_func(edge) for edge in graph.edges]
169+
# can be a list for different colors per node/edge
170+
node_colors = _to_rgba(node_color, node_alphas)
171+
edge_colors = _to_rgba(edge_color, edge_alphas)
172+
173+
node_labels = [str(label_node_func(node)) for node in graph.nodes]
174+
edge_labels = [str(label_edge_func(edge)) for edge in graph.edges]
175+
176+
fig = go.Figure()
177+
178+
node_trace = go.Scatter(
179+
x=node_positions[:, 0],
180+
y=node_positions[:, 1],
181+
mode="markers+text",
182+
marker={"color": node_colors, "size": node_size},
183+
text=node_labels,
184+
textfont={"color": "white"},
185+
hoverinfo="text",
186+
hovertext=[_attr_hover_text(attrs) for attrs in graph.nodes.values()],
187+
)
188+
189+
fig.add_trace(node_trace)
190+
191+
fig.update_layout(
192+
xaxis={
193+
"tickmode": "linear",
194+
"tick0": min(frames),
195+
"dtick": 1,
196+
"title": "time",
197+
},
198+
yaxis={
199+
"title": "space",
200+
},
201+
showlegend=False,
202+
margin={
203+
"t": 0,
204+
"b": 0,
205+
"l": 0,
206+
"r": 0,
207+
},
208+
modebar={
209+
"remove": [
210+
"lasso",
211+
"pan",
212+
"select",
213+
"autoscale",
214+
"zoomin",
215+
"zoomout",
216+
"resetscale",
217+
]
218+
},
219+
width=width,
220+
height=height,
221+
)
222+
223+
arrows = []
224+
for ((u, v), attrs), label, color in zip(
225+
graph.edges.items(), edge_labels, edge_colors
226+
):
227+
start = node_positions[sorted(graph.nodes).index(u), (0, 1)]
228+
end = node_positions[sorted(graph.nodes).index(v), (0, 1)]
229+
mid = 0.6 * start + 0.4 * end
230+
first_half = go.layout.Annotation(
231+
dict(
232+
ax=start[0],
233+
ay=start[1],
234+
x=mid[0],
235+
y=mid[1],
236+
xref="x",
237+
yref="y",
238+
showarrow=True,
239+
startstandoff=node_size * 0.5,
240+
axref="x",
241+
ayref="y",
242+
arrowhead=0,
243+
arrowwidth=4,
244+
arrowcolor=color,
245+
)
246+
)
247+
second_half = go.layout.Annotation(
248+
dict(
249+
ax=mid[0],
250+
ay=mid[1],
251+
x=end[0],
252+
y=end[1],
253+
xref="x",
254+
yref="y",
255+
text=label,
256+
font={"color": "white"},
257+
hovertext=_attr_hover_text(attrs),
258+
bgcolor=color,
259+
showarrow=True,
260+
standoff=node_size * 0.6,
261+
axref="x",
262+
ayref="y",
263+
arrowhead=2,
264+
arrowwidth=4,
265+
arrowsize=0.6,
266+
arrowcolor=color,
267+
)
268+
)
269+
270+
arrows.append(first_half)
271+
arrows.append(second_half)
272+
273+
fig.update_layout(annotations=arrows)
274+
275+
return fig
276+
277+
278+
def draw_solution(
279+
graph: TrackGraph, solver: Solver, *args: Any, **kwargs: Any
280+
) -> go.Figure:
281+
"""Draw ``graph`` with the current ``solver.solution`` highlighted.
282+
283+
This is a wrapper around :func:`draw_track_graph` highlighting the solution found
284+
by the given solver.
285+
286+
Args:
287+
graph (:class:`TrackGraph`):
288+
The graph to plot.
289+
290+
solver :class:`Solver`):
291+
The solver that was used to find the solution.
292+
293+
*args:
294+
Pass-through arguments to :func:`draw_track_graph`.
295+
296+
**kwargs:
297+
Pass-through keyword arguments to :func:`draw_track_graph`.
298+
299+
Returns:
300+
``plotly`` figure showing the graph.
301+
"""
302+
solution = solver.solution
303+
if solution is None:
304+
raise RuntimeError("Solver has no solution. Call solve() first.")
305+
306+
node_indicators = solver.get_variables(NodeSelected)
307+
edge_indicators = solver.get_variables(EdgeSelected)
308+
309+
def node_alpha_func(node: NodeId) -> float:
310+
return solution[node_indicators[node]] # type: ignore
311+
312+
def edge_alpha_func(edge: EdgeId) -> float:
313+
return solution[edge_indicators[edge]] # type: ignore
314+
315+
kwargs["alpha_func"] = (node_alpha_func, edge_alpha_func)
316+
return draw_track_graph(graph, *args, **kwargs)
317+
318+
319+
@overload
320+
def _to_rgba(color: list[Color], alpha: float | list[float] = 1.0) -> list[str]: ...
321+
322+
323+
@overload
324+
def _to_rgba(color: Color, alpha: float | list[float] = 1.0) -> str: ...
325+
326+
327+
def _to_rgba(
328+
color: Color | list[Color], alpha: float | list[float] = 1.0
329+
) -> str | list[str]:
330+
"""Convert a color to a rgba string."""
331+
if isinstance(color, list):
332+
if isinstance(alpha, list):
333+
return [_to_rgba(c, a) for c, a in zip(color, alpha)]
334+
else: # only color is list
335+
return [_to_rgba(c, alpha) for c in color]
336+
elif isinstance(alpha, list): # only alpha is list
337+
return [_to_rgba(color, a) for a in alpha]
338+
339+
# we fake alpha by mixing with white(ish)
340+
# transparency is tricky...
341+
r, g, b = tuple(int(c * alpha + 220 * (1.0 - alpha)) for c in color)
342+
return f"rgb({r},{g},{b})"

0 commit comments

Comments
 (0)