Skip to content

Commit 6fe2005

Browse files
committed
Work on cov vis
1 parent 07d095b commit 6fe2005

File tree

1 file changed

+58
-13
lines changed

1 file changed

+58
-13
lines changed

src/dartsort/vis/gmm.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..cluster import gaussian_mixture
1010
from ..util.multiprocessing_util import (CloudpicklePoolExecutor,
1111
ThreadPoolExecutor, get_pool, cloudpickle)
12+
from ..util import spiketorch
1213
from . import analysis_plots, gmm_helpers, layout
1314
from .colors import glasbey1024
1415
from .waveforms import geomplot
@@ -97,36 +98,49 @@ def draw(self, panel, gmm, unit_id):
9798

9899
class MStep(GMMPlot):
99100
kind = "waveform"
100-
width = 4
101-
height = 5
101+
width = 5
102+
height = 9
102103
alpha = 0.05
103104
n_show = 64
104105

105106
def draw(self, panel, gmm, unit_id, axes=None):
106-
ax = panel.subplots()
107+
panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[1.5, 1])
108+
ax = panel_top.subplots()
107109
ax.axis("off")
108110

111+
# panel_bottom, panel_cbar = panel_bottom.subfigures(ncols=2, width_ratios=[5, 0.5])
112+
cov_axes = panel_bottom.subplots(
113+
nrows=2, ncols=2, sharey=True, sharex=True
114+
)
115+
# cax = panel_cbar.add_subplot(3, 1, 2)
116+
117+
# get spike data and determine channel set by plotting
109118
sp = gmm.random_spike_data(unit_id, max_size=self.n_show, with_reconstructions=True)
110119
maa = sp.waveforms.abs().nan_to_num().max()
120+
geomplot_kw = dict(
121+
max_abs_amp=maa,
122+
geom=gmm.data.prgeom.numpy(force=True),
123+
show_zero=False,
124+
return_chans=True,
125+
)
111126
lines, chans = geomplot(
112127
sp.waveforms,
113128
channels=sp.channels,
114-
geom=gmm.data.prgeom.numpy(force=True),
115-
max_abs_amp=maa,
116129
color="k",
117130
alpha=self.alpha,
118-
return_chans=True,
119-
show_zero=False,
120131
ax=ax,
132+
**geomplot_kw,
121133
)
122134
chans = torch.tensor(list(chans))
123135
tup = gaussian_mixture.to_full_probe(
124136
sp, weights=None, n_channels=gmm.data.n_channels, storage=None
125137
)
126138
features_full, weights_full, count_data, weights_normalized = tup
127-
emp_mean = torch.nanmean(features_full, dim=0)[:, chans]
139+
print(f"{features_full.shape=}")
140+
feats = features_full[:, :, chans]
141+
n, r, c = feats.shape
142+
emp_mean = torch.nanmean(feats, dim=0)
128143
emp_mean = gmm.data.tpca.force_reconstruct(emp_mean.nan_to_num_())
129-
130144
model_mean = gmm.units[unit_id].mean[:, chans]
131145
model_mean = gmm.data.tpca.force_reconstruct(model_mean)
132146

@@ -142,6 +156,37 @@ def draw(self, panel, gmm, unit_id, axes=None):
142156
ax.axis("off")
143157
ax.set_title("reconstructed mean and example inputs")
144158

159+
# covariance vis
160+
feats = features_full[:, :, gmm.units[unit_id].channels]
161+
model_mean = gmm.units[unit_id].mean[:, gmm.units[unit_id].channels]
162+
n, r, c = feats.shape
163+
emp_cov, nobs = spiketorch.nancov(feats.view(n, r * c), return_nobs=True)
164+
denom = nobs + gmm.units[unit_id].prior_pseudocount
165+
emp_cov = (nobs / denom) * emp_cov
166+
noise_cov = gmm.noise.marginal_covariance(channels=gmm.units[unit_id].channels).to_dense()
167+
m = model_mean.abs().reshape(-1)
168+
mmt = m[:, None] @ m[None, :]
169+
covs = (emp_cov, noise_cov, mmt)
170+
vmax = max(c.abs().max() for c in covs)
171+
names = ("regemp", "noise", "|temptempT|")
172+
print(f"{feats.shape=} {gmm.units[unit_id].channels.shape=}")
173+
print(f"{vmax=}")
174+
print(f"{emp_cov.abs().max()=}")
175+
print(f"{noise_cov.abs().max()=}")
176+
print(f"{mmt.abs().max()=}")
177+
print(f"{emp_cov.shape=}")
178+
print(f"{noise_cov.shape=}")
179+
print(f"{mmt.shape=}")
180+
181+
for ax, cov, name in zip(cov_axes.flat, covs, names):
182+
vmax = cov.abs().triu(diagonal=1)
183+
vmax = vmax[vmax>0].quantile(.95)
184+
im = ax.imshow(cov.numpy(force=True), vmin=-vmax, vmax=vmax, cmap=plt.cm.seismic)
185+
ax.axis("off")
186+
ax.set_title(name, fontsize="small")
187+
plt.colorbar(im, ax=ax, shrink=0.5)
188+
# plt.colorbar(im, cax=cax, shrink=0.5)
189+
145190

146191
class Likelihoods(GMMPlot):
147192
kind = "widescatter"
@@ -324,8 +369,8 @@ def draw(self, panel, gmm, unit_id, split_info=None):
324369

325370
class NeighborMeans(GMMPlot):
326371
kind = "merge"
327-
width = 3
328-
height = 4
372+
width = 4
373+
height = 3
329374

330375
def __init__(self, n_neighbors=5):
331376
self.n_neighbors = n_neighbors
@@ -348,8 +393,8 @@ def draw(self, panel, gmm, unit_id):
348393

349394
class NeighborDistances(GMMPlot):
350395
kind = "merge"
351-
width = 3
352-
height = 3
396+
width = 4
397+
height = 2
353398

354399
def __init__(self, n_neighbors=5, dist_vmax=1.0):
355400
self.n_neighbors = n_neighbors

0 commit comments

Comments
 (0)