Skip to content

Commit dda7911

Browse files
committed
Analysis and unit plot improvements; ccg with neighbors
1 parent fa88bf9 commit dda7911

File tree

2 files changed

+128
-38
lines changed

2 files changed

+128
-38
lines changed

src/dartsort/util/analysis.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626
from ..templates import TemplateData
2727
from ..transform import WaveformPipeline
2828
from .data_util import DARTsortSorting
29-
from .drift_util import (get_spike_pitch_shifts,
30-
get_waveforms_on_static_channels, registered_average)
29+
from .drift_util import (
30+
get_spike_pitch_shifts,
31+
get_waveforms_on_static_channels,
32+
registered_average,
33+
)
3134
from .spikeio import read_waveforms_channel_index
3235
from .waveform_util import make_channel_index
3336

@@ -307,8 +310,10 @@ def show_geom(self):
307310
show_geom = self.recording.get_channel_locations()
308311
return show_geom
309312

310-
def show_channel_index(self, channel_channel_show_radius_um=50):
311-
return make_channel_index(self.show_geom, channel_channel_show_radius_um)
313+
def show_channel_index(self, channel_show_radius_um=50, channel_dist_p=np.inf):
314+
return make_channel_index(
315+
self.show_geom, channel_show_radius_um, p=channel_dist_p
316+
)
312317

313318
# spike feature loading methods
314319

@@ -369,6 +374,7 @@ def unit_raw_waveforms(
369374
channel_show_radius_um=75,
370375
trough_offset_samples=42,
371376
spike_length_samples=121,
377+
channel_dist_p=np.inf,
372378
relocated=False,
373379
):
374380
if which is None:
@@ -389,8 +395,9 @@ def unit_raw_waveforms(
389395
if self.shifting:
390396
load_ci = self.channel_index
391397
else:
392-
load_ci = make_channel_index(
393-
self.recording.get_channel_locations(), channel_show_radius_um
398+
load_ci = self.show_channel_index(
399+
channel_show_radius_um=channel_show_radius_um,
400+
channel_dist_p=channel_dist_p,
394401
)
395402
waveforms = read_waveforms_channel_index(
396403
self.recording,
@@ -415,6 +422,7 @@ def unit_raw_waveforms(
415422
waveforms,
416423
load_ci,
417424
channel_show_radius_um=channel_show_radius_um,
425+
channel_dist_p=channel_dist_p,
418426
relocated=relocated,
419427
)
420428
return which, waveforms, max_chan, show_geom, show_channel_index
@@ -441,8 +449,14 @@ def unit_tpca_waveforms(
441449

442450
tpca_embeds = self.tpca_features(which=which)
443451
n, rank, c = tpca_embeds.shape
444-
waveforms = tpca_embeds.transpose(0, 2, 1).reshape(n * c, rank)
445-
waveforms = self.sklearn_tpca.inverse_transform(waveforms)
452+
tpca_embeds = tpca_embeds.transpose(0, 2, 1).reshape(n * c, rank)
453+
waveforms = np.full(
454+
(n * c, self.sklearn_tpca.components_.shape[1]),
455+
np.nan,
456+
dtype=tpca_embeds.dtype,
457+
)
458+
valid = np.flatnonzero(np.isfinite(tpca_embeds[:, 0]))
459+
waveforms[valid] = self.sklearn_tpca.inverse_transform(tpca_embeds[valid])
446460
t = waveforms.shape[1]
447461
waveforms = waveforms.reshape(n, c, t).transpose(0, 2, 1)
448462

@@ -501,12 +515,15 @@ def unit_shift_or_relocate_channels(
501515
waveforms,
502516
load_channel_index,
503517
channel_show_radius_um=75,
518+
channel_dist_p=np.inf,
504519
relocated=False,
505520
):
506521
geom = self.recording.get_channel_locations()
507-
show_geom = self.template_data.registered_geom
508-
if show_geom is None:
509-
show_geom = geom
522+
show_geom = self.show_geom
523+
show_channel_index = self.show_channel_index(
524+
channel_show_radius_um=channel_show_radius_um, channel_dist_p=channel_dist_p
525+
)
526+
510527
temp = self.coarse_template_data.unit_templates(unit_id)
511528
n_pitches_shift = None
512529
if temp.shape[0]:
@@ -532,7 +549,7 @@ def unit_shift_or_relocate_channels(
532549
else:
533550
amp_template = np.nanmean(amps, axis=0)
534551
max_chan = np.nanargmax(amp_template)
535-
show_channel_index = make_channel_index(show_geom, channel_show_radius_um)
552+
536553
show_chans = show_channel_index[max_chan]
537554
show_chans = show_chans[show_chans < len(show_geom)]
538555
show_channel_index = np.broadcast_to(
@@ -580,7 +597,9 @@ def nearby_coarse_templates(self, unit_id, n_neighbors=5):
580597
unit_ix = np.searchsorted(self.unit_ids, unit_id)
581598
unit_dists = self.merge_dist[unit_ix]
582599
distance_order = np.argsort(unit_dists)
583-
distance_order = np.concatenate(([unit_ix], distance_order[distance_order != unit_ix]))
600+
distance_order = np.concatenate(
601+
([unit_ix], distance_order[distance_order != unit_ix])
602+
)
584603
# assert distance_order[0] == unit_ix
585604
neighbor_ixs = distance_order[:n_neighbors]
586605
neighbor_ids = self.unit_ids[neighbor_ixs]

src/dartsort/vis/unit.py

Lines changed: 96 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class UnitPlot:
2828
width = 1
2929
height = 1
3030

31-
def draw(self, axis, sorting_analysis, unit_id):
31+
def draw(self, panel, sorting_analysis, unit_id):
3232
raise NotImplementedError
3333

3434
def notify_global_params(self, **params):
@@ -47,7 +47,8 @@ class TextInfo(UnitPlot):
4747
kind = "text"
4848
height = 0.5
4949

50-
def draw(self, axis, sorting_analysis, unit_id):
50+
def draw(self, panel, sorting_analysis, unit_id):
51+
axis = panel.subplots()
5152
axis.axis("off")
5253
msg = f"unit {unit_id}\n"
5354

@@ -77,12 +78,13 @@ class ACG(UnitPlot):
7778
def __init__(self, max_lag=50):
7879
self.max_lag = max_lag
7980

80-
def draw(self, axis, sorting_analysis, unit_id):
81+
def draw(self, panel, sorting_analysis, unit_id):
82+
axis = panel.subplots()
8183
times_samples = sorting_analysis.times_samples(
8284
which=sorting_analysis.in_unit(unit_id)
8385
)
8486
lags, acg = correlogram(times_samples, max_lag=self.max_lag)
85-
axis.bar(lags, acg)
87+
bar(axis, lags, acg, fill=True, color="k")
8688
axis.set_xlabel("lag (samples)")
8789
axis.set_ylabel("acg")
8890

@@ -95,7 +97,8 @@ def __init__(self, bin_ms=0.1, max_ms=5):
9597
self.bin_ms = bin_ms
9698
self.max_ms = max_ms
9799

98-
def draw(self, axis, sorting_analysis, unit_id):
100+
def draw(self, panel, sorting_analysis, unit_id):
101+
axis = panel.subplots()
99102
times_s = sorting_analysis.times_seconds(
100103
which=sorting_analysis.in_unit(unit_id)
101104
)
@@ -108,7 +111,7 @@ def draw(self, axis, sorting_analysis, unit_id):
108111
# counts, _ = np.histogram(dt_ms, bin_edges)
109112
# bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
110113
# axis.bar(bin_centers, counts)
111-
plt.hist(dt_ms, bin_edges)
114+
plt.hist(dt_ms, bin_edges, color="k")
112115
axis.set_xlabel("isi (ms)")
113116
axis.set_ylabel(f"count (out of {dt_ms.size} total isis)")
114117

@@ -130,7 +133,8 @@ def __init__(
130133
self.probe_margin_um = probe_margin_um
131134
self.colorbar = colorbar
132135

133-
def draw(self, axis, sorting_analysis, unit_id):
136+
def draw(self, panel, sorting_analysis, unit_id):
137+
axis = panel.subplots()
134138
in_unit = sorting_analysis.in_unit(unit_id)
135139
x = sorting_analysis.x(which=in_unit)
136140
z = sorting_analysis.z(which=in_unit, registered=self.registered)
@@ -160,15 +164,19 @@ class PCAScatter(UnitPlot):
160164
kind = "scatter"
161165

162166
def __init__(
163-
self, relocate_amplitudes=False, relocated=True, amplitude_color_cutoff=15,
164-
colorbar=False
167+
self,
168+
relocate_amplitudes=False,
169+
relocated=True,
170+
amplitude_color_cutoff=15,
171+
colorbar=False,
165172
):
166173
self.relocated = relocated
167174
self.relocate_amplitudes = relocate_amplitudes
168175
self.amplitude_color_cutoff = amplitude_color_cutoff
169176
self.colorbar = colorbar
170177

171-
def draw(self, axis, sorting_analysis, unit_id):
178+
def draw(self, panel, sorting_analysis, unit_id):
179+
axis = panel.subplots()
172180
which, loadings = sorting_analysis.unit_pca_features(
173181
unit_id=unit_id, relocated=self.relocated
174182
)
@@ -205,7 +213,8 @@ def __init__(
205213
self.amplitude_color_cutoff = amplitude_color_cutoff
206214
self.probe_margin_um = probe_margin_um
207215

208-
def draw(self, axis, sorting_analysis, unit_id):
216+
def draw(self, panel, sorting_analysis, unit_id):
217+
axis = panel.subplots()
209218
in_unit = sorting_analysis.in_unit(unit_id)
210219
t = sorting_analysis.times_seconds(which=in_unit)
211220
z = sorting_analysis.z(which=in_unit, registered=self.registered)
@@ -245,7 +254,8 @@ def __init__(
245254
self.amplitude_color_cutoff = amplitude_color_cutoff
246255
self.color_by_amplitude = color_by_amplitude
247256

248-
def draw(self, axis, sorting_analysis, unit_id):
257+
def draw(self, panel, sorting_analysis, unit_id):
258+
axis = panel.subplots()
249259
in_unit = sorting_analysis.in_unit(unit_id)
250260
t = sorting_analysis.times_seconds(which=in_unit)
251261
feat = sorting_analysis.named_feature(self.feat_name, which=in_unit)
@@ -271,7 +281,8 @@ def __init__(self, relocate_amplitudes=False, amplitude_color_cutoff=15):
271281
self.relocate_amplitudes = relocate_amplitudes
272282
self.amplitude_color_cutoff = amplitude_color_cutoff
273283

274-
def draw(self, axis, sorting_analysis, unit_id):
284+
def draw(self, panel, sorting_analysis, unit_id):
285+
axis = panel.subplots()
275286
in_unit = sorting_analysis.in_unit(unit_id)
276287
t = sorting_analysis.times_seconds(which=in_unit)
277288
amps = sorting_analysis.amplitudes(
@@ -327,7 +338,8 @@ def __init__(
327338
def get_waveforms(self, sorting_analysis, unit_id):
328339
raise NotImplementedError
329340

330-
def draw(self, axis, sorting_analysis, unit_id):
341+
def draw(self, panel, sorting_analysis, unit_id):
342+
axis = panel.subplots()
331343
which, waveforms, max_chan, geom, ci = self.get_waveforms(
332344
sorting_analysis, unit_id
333345
)
@@ -353,7 +365,7 @@ def draw(self, axis, sorting_analysis, unit_id):
353365
new_offset=self.trough_offset_samples,
354366
new_length=self.spike_length_samples,
355367
)
356-
max_abs_amp = self.max_abs_template_scale * np.abs(templates).max()
368+
max_abs_amp = self.max_abs_template_scale * np.nanmax(np.abs(templates))
357369
show_superres_templates = (
358370
self.show_superres_templates and self.template_index is None
359371
)
@@ -368,7 +380,7 @@ def draw(self, axis, sorting_analysis, unit_id):
368380
new_length=self.spike_length_samples,
369381
)
370382
show_superres_templates = suptemplates.shape[0] > 1
371-
max_abs_amp = self.max_abs_template_scale * np.abs(suptemplates).max()
383+
max_abs_amp = self.max_abs_template_scale * np.nanmax(np.abs(suptemplates))
372384

373385
ls = geomplot(
374386
waveforms,
@@ -434,6 +446,7 @@ def draw(self, axis, sorting_analysis, unit_id):
434446
reg_str = "registered " * sorting_analysis.shifting
435447
axis.set_ylabel(reg_str + "depth (um)")
436448
axis.set_xticks([])
449+
axis.set_yticks([])
437450

438451
if self.legend:
439452
axis.legend(
@@ -487,7 +500,8 @@ def __init__(self, channel_show_radius_um=50, n_neighbors=5, legend=True):
487500
self.n_neighbors = n_neighbors
488501
self.legend = legend
489502

490-
def draw(self, axis, sorting_analysis, unit_id):
503+
def draw(self, panel, sorting_analysis, unit_id):
504+
axis = panel.subplots()
491505
(
492506
neighbor_ids,
493507
neighbor_dists,
@@ -529,7 +543,7 @@ def draw(self, axis, sorting_analysis, unit_id):
529543
)
530544
labels.append(str(uid))
531545
handles.append(lines[0])
532-
axis.legend(handles=handles, labels=labels, fancybox=False)
546+
axis.legend(handles=handles, labels=labels, fancybox=False, loc="upper left")
533547
axis.set_xticks([])
534548
axis.set_yticks([])
535549
axis.set_title(self.title)
@@ -539,15 +553,18 @@ class CoarseTemplateDistancePlot(UnitPlot):
539553
title = "coarse template distance"
540554
kind = "neighbors"
541555
width = 2
542-
height = 2
556+
height = 1.25
543557

544-
def __init__(self, channel_show_radius_um=50, n_neighbors=5, dist_vmax=1.0, show_values=True):
558+
def __init__(
559+
self, channel_show_radius_um=50, n_neighbors=5, dist_vmax=1.0, show_values=True
560+
):
545561
self.channel_show_radius_um = channel_show_radius_um
546562
self.n_neighbors = n_neighbors
547563
self.dist_vmax = dist_vmax
548564
self.show_values = show_values
549565

550-
def draw(self, axis, sorting_analysis, unit_id):
566+
def draw(self, panel, sorting_analysis, unit_id):
567+
axis = panel.subplots()
551568
(
552569
neighbor_ids,
553570
neighbor_dists,
@@ -580,6 +597,52 @@ def draw(self, axis, sorting_analysis, unit_id):
580597
axis.set_title(self.title)
581598

582599

600+
class NeighborCCGPlot(UnitPlot):
601+
kind = "neighbors"
602+
width = 2
603+
height = 0.75
604+
605+
def __init__(self, n_neighbors=3, max_lag=50):
606+
self.n_neighbors = n_neighbors
607+
self.max_lag = max_lag
608+
609+
def draw(self, panel, sorting_analysis, unit_id):
610+
(
611+
neighbor_ids,
612+
neighbor_dists,
613+
neighbor_coarse_templates,
614+
) = sorting_analysis.nearby_coarse_templates(
615+
unit_id, n_neighbors=self.n_neighbors + 1
616+
)
617+
colors = np.array(cc.glasbey_light)[neighbor_ids % len(cc.glasbey_light)]
618+
# assert neighbor_ids[0] == unit_id
619+
neighbor_ids = neighbor_ids[1:]
620+
621+
my_st = sorting_analysis.times_samples(which=sorting_analysis.in_unit(unit_id))
622+
neighb_sts = [
623+
sorting_analysis.times_samples(which=sorting_analysis.in_unit(nid))
624+
for nid in neighbor_ids
625+
]
626+
ccgs = [correlogram(my_st, nst, max_lag=self.max_lag) for nst in neighb_sts]
627+
acgs = [correlogram(my_st, nst, max_lag=self.max_lag) for nst in neighb_sts]
628+
629+
axes = panel.subplots(
630+
nrows=2, sharey="row", sharex=True, ncols=self.n_neighbors
631+
)
632+
for j in range(self.n_neighbors):
633+
clags, ccg = correlogram(my_st, neighb_sts[j], max_lag=self.max_lag)
634+
merged_st = np.concatenate((my_st, neighb_sts[j]))
635+
merged_st.sort()
636+
alags, acg = correlogram(merged_st, max_lag=self.max_lag)
637+
638+
bar(axes[0, j], clags, ccg, fill=True, fc=colors[j]) # , ec="k", lw=1)
639+
bar(axes[1, j], alags, acg, fill=True, fc=colors[j]) # , ec="k", lw=1)
640+
axes[1, j].set_xlabel("lag (samples)")
641+
axes[0, j].set_title(f"unit {neighbor_ids[j]}")
642+
axes[0, 0].set_ylabel("ccg")
643+
axes[1, 0].set_ylabel("merged acg")
644+
645+
583646
# -- multi plots
584647
# these have multiple plots per unit, and we don't know in advance how many
585648
# for instance, making separate plots of spikes belonging to each superres template
@@ -662,6 +725,7 @@ def unit_plots(self, sorting_analysis, unit_id):
662725
TPCAWaveformPlot(relocated=True),
663726
NearbyCoarseTemplatesPlot(),
664727
CoarseTemplateDistancePlot(),
728+
NeighborCCGPlot(),
665729
)
666730

667731

@@ -716,10 +780,11 @@ def make_unit_summary(
716780
all_panels.extend(cardfigs)
717781

718782
for cardfig, card in zip(cardfigs, column):
719-
axes = cardfig.subplots(nrows=len(card.plots), ncols=1)
720-
axes = np.atleast_1d(axes)
721-
for plot, axis in zip(card.plots, axes):
722-
plot.draw(axis, sorting_analysis, unit_id)
783+
panels = cardfig.subfigures(nrows=len(card.plots), ncols=1)
784+
panels = np.atleast_1d(panels)
785+
for plot, panel in zip(card.plots, panels):
786+
plot.draw(panel, sorting_analysis, unit_id)
787+
all_panels.extend(panels)
723788

724789
# clean up the panels, or else things get clipped
725790
for panel in all_panels:
@@ -953,3 +1018,9 @@ def _summary_job(unit_id):
9531018
fig.savefig(tmp_out, dpi=_summary_job_context.dpi)
9541019
tmp_out.rename(final_out)
9551020
plt.close(fig)
1021+
1022+
1023+
def bar(ax, x, y, **kwargs):
1024+
dx = np.diff(x).min()
1025+
x0 = np.concatenate((x - dx, x[-1:] + dx))
1026+
ax.stairs(y, x0, **kwargs)

0 commit comments

Comments
 (0)