Skip to content

Commit 312cc01

Browse files
author
Julien Boussard
committed
multi features scatter
1 parent 4aec4bf commit 312cc01

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

src/spike_psvae/cluster_viz.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,139 @@ def cluster_scatter(
107107
ell.set_transform(transform + ax.transData)
108108
ax.add_patch(ell)
109109

110+
def array_scatter_5_features(
111+
labels,
112+
geom,
113+
x,
114+
z,
115+
maxptp,
116+
trough_val,
117+
trough_time,
118+
tip_val,
119+
zlim=(-50, 3900),
120+
xlim=None,
121+
ptplim=None,
122+
maxptp_c=None,
123+
title=None,
124+
axes=None,
125+
do_ellipse=True,
126+
figsize = (15, 15),
127+
s_dot=1,
128+
s_size_geom = 3,
129+
):
130+
fig = None
131+
if axes is None:
132+
fig, axes = plt.subplots(1, 6, sharey=True, figsize=figsize)
133+
134+
if title is not None:
135+
fig.suptitle(title)
136+
137+
excluded_ids = {-1}
138+
if not do_ellipse:
139+
excluded_ids = np.unique(labels)
140+
141+
cluster_scatter(
142+
x,
143+
z,
144+
labels,
145+
ax=axes[0],
146+
s=s_dot,
147+
alpha=0.1,
148+
excluded_ids=excluded_ids,
149+
do_ellipse=do_ellipse,
150+
)
151+
axes[0].scatter(*geom.T, c="orange", marker="s", s=s_size_geom)
152+
axes[0].scatter(geom[0, 0], geom[0, 1], c="orange", marker="s", s=10, label='Channel Locations')
153+
axes[0].set_ylabel("Registered Depth (um)", fontsize=14)
154+
axes[0].set_xlabel("x (um)", fontsize=14)
155+
axes[0].legend(fontsize=14, loc='upper left')
156+
axes[0].tick_params(axis='x', labelsize=14)
157+
axes[0].tick_params(axis='y', labelsize=14)
158+
159+
if maxptp_c is None:
160+
maxptp_c = np.clip(maxptp, 3, 15)
161+
axes[1].scatter(
162+
x,
163+
z,
164+
c=maxptp_c,
165+
s=s_dot,
166+
alpha=0.1,
167+
marker=".",
168+
cmap=plt.cm.jet,
169+
)
170+
axes[1].scatter(*geom.T, c="orange", marker="s", s=s_size_geom)
171+
axes[1].set_title("colored by ptps")
172+
173+
cluster_scatter(
174+
maxptp,
175+
z,
176+
labels,
177+
ax=axes[2],
178+
s=s_dot,
179+
alpha=0.1,
180+
excluded_ids=excluded_ids,
181+
do_ellipse=do_ellipse,
182+
)
183+
axes[2].set_xlabel("Amplitude (s.u.)", fontsize=16)
184+
axes[2].tick_params(axis='x', labelsize=16)
185+
186+
cluster_scatter(
187+
trough_val,
188+
z,
189+
labels,
190+
ax=axes[3],
191+
s=s_dot,
192+
alpha=0.1,
193+
excluded_ids=excluded_ids,
194+
do_ellipse=do_ellipse,
195+
)
196+
axes[3].set_xlabel("Trough Val (s.u.)", fontsize=16)
197+
axes[3].tick_params(axis='x', labelsize=16)
198+
199+
cluster_scatter(
200+
trough_time,
201+
z,
202+
labels,
203+
ax=axes[4],
204+
s=s_dot,
205+
alpha=0.1,
206+
excluded_ids=excluded_ids,
207+
do_ellipse=do_ellipse,
208+
)
209+
axes[4].set_xlabel("Trough Time (1/30ms)", fontsize=16)
210+
axes[4].tick_params(axis='x', labelsize=16)
211+
212+
cluster_scatter(
213+
tip_val,
214+
z,
215+
labels,
216+
ax=axes[5],
217+
s=s_dot,
218+
alpha=0.1,
219+
excluded_ids=excluded_ids,
220+
do_ellipse=do_ellipse,
221+
)
222+
axes[5].set_xlabel("Tip Val (s.u.)", fontsize=16)
223+
axes[5].tick_params(axis='x', labelsize=16)
224+
225+
axes[0].set_ylim(zlim)
226+
axes[1].set_ylim(zlim)
227+
axes[2].set_ylim(zlim)
228+
axes[3].set_ylim(zlim)
229+
axes[4].set_ylim(zlim)
230+
axes[5].set_ylim(zlim)
231+
232+
if xlim is not None:
233+
axes[0].set_xlim(xlim)
234+
axes[1].set_xlim(xlim)
235+
if ptplim is not None:
236+
axes[2].set_xlim(ptplim)
237+
238+
# if fig is not None:
239+
# plt.tight_layout()
240+
241+
return fig, axes
242+
110243

111244
# %%
112245
def array_scatter(

0 commit comments

Comments
 (0)