Skip to content

Commit 02605ec

Browse files
committed
draft for plotter class
1 parent 66eb54e commit 02605ec

File tree

1 file changed

+89
-1
lines changed

1 file changed

+89
-1
lines changed

navis/plotting/pygfx/objects.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""Experimental plotting module using pygfx."""
1818

1919
import uuid
20+
import warnings
2021

2122
import pandas as pd
2223
import numpy as np
@@ -37,6 +38,91 @@
3738
logger = config.get_logger(__name__)
3839

3940

41+
class Plotter:
42+
allowed_kwargs = []
43+
44+
def __init__(self, **kwargs):
45+
self.parse_kwargs(**kwargs)
46+
47+
def __call__(self):
48+
"""Plot objects."""
49+
return self.plot()
50+
51+
def parse_kwargs(self, **kwargs):
52+
"""Parse kwargs."""
53+
# Check for invalid kwargs
54+
invalid_kwargs = list(kwargs)
55+
for k in self.allowed_kwargs:
56+
# If this is a tuple of possible kwargs (e.g. "lw" or "linewidth")
57+
if isinstance(k, (tuple, list, set)):
58+
# Check if we have multiple kwargs for the same thing
59+
if len([kk for kk in k if kk in invalid_kwargs]) > 1:
60+
raise ValueError(f'Please use only one of "{k}"')
61+
62+
for kk in k:
63+
if kk in invalid_kwargs:
64+
# Make sure we always use the first kwarg
65+
kwargs[k[0]] = kwargs.pop(kk)
66+
invalid_kwargs.pop(kk)
67+
else:
68+
if k in invalid_kwargs:
69+
invalid_kwargs.pop(k)
70+
71+
if len(invalid_kwargs):
72+
warnings.warn(
73+
f"Unknown kwargs for {self.backend} backend: {', '.join([f'{k}' for k in invalid_kwargs])}"
74+
)
75+
76+
self.kwargs = kwargs
77+
78+
def add_objects(self, x):
79+
"""Add objects to the plot."""
80+
(neurons, volumes, points, visual) = utils.parse_objects(x)
81+
82+
def plot(self):
83+
"""Plot objects."""
84+
# Generate
85+
colors = self.kwargs.get('color', None)
86+
palette = self.kwargs.get("palette", None)
87+
88+
neuron_cmap, volumes_cmap = prepare_colormap(
89+
colors,
90+
neurons=self.neurons,
91+
volumes=self.volumes,
92+
palette=palette,
93+
clusters=self.kwargs.get("clusters", None),
94+
alpha=self.kwargs.get("alpha", None),
95+
color_range=255,
96+
)
97+
98+
99+
class GfxPlotter(Plotter):
100+
allowed_kwargs = {
101+
("color", "c", "colors"),
102+
"cn_colors",
103+
("linewidth", "lw"),
104+
"scatter_kws",
105+
"synapse_layout",
106+
"dps_scale_vec",
107+
"width",
108+
"height",
109+
"alpha",
110+
"radius",
111+
"soma",
112+
"connectors",
113+
"connectors_only",
114+
"palette",
115+
"color_by",
116+
"shade_by",
117+
"vmin",
118+
"vmax",
119+
"smin",
120+
"smax",
121+
"volume_legend",
122+
}
123+
backend = "pygfx"
124+
125+
40126
def volume2gfx(x, **kwargs):
41127
"""Convert Volume(s) to pygfx visuals."""
42128
# Must not use make_iterable here as this will turn into list of keys!
@@ -283,7 +369,9 @@ def connectors2gfx(neuron, neuron_color, object_id, **kwargs):
283369

284370
# Zip coordinates and add a row of NaNs to indicate breaks in the
285371
# segments
286-
coords = np.hstack((pos, tn_coords, np.full(pos.shape, fill_value=np.nan))).reshape((len(pos) * 3, 3))
372+
coords = np.hstack(
373+
(pos, tn_coords, np.full(pos.shape, fill_value=np.nan))
374+
).reshape((len(pos) * 3, 3))
287375
coords = coords.astype(np.float32, copy=False)
288376

289377
# Create line plot from segments

0 commit comments

Comments
 (0)