9
9
from ..cluster import gaussian_mixture
10
10
from ..util .multiprocessing_util import (CloudpicklePoolExecutor ,
11
11
ThreadPoolExecutor , get_pool , cloudpickle )
12
+ from ..util import spiketorch
12
13
from . import analysis_plots , gmm_helpers , layout
13
14
from .colors import glasbey1024
14
15
from .waveforms import geomplot
@@ -97,36 +98,49 @@ def draw(self, panel, gmm, unit_id):
97
98
98
99
class MStep (GMMPlot ):
99
100
kind = "waveform"
100
- width = 4
101
- height = 5
101
+ width = 5
102
+ height = 9
102
103
alpha = 0.05
103
104
n_show = 64
104
105
105
106
def draw (self , panel , gmm , unit_id , axes = None ):
106
- ax = panel .subplots ()
107
+ panel_top , panel_bottom = panel .subfigures (nrows = 2 , height_ratios = [1.5 , 1 ])
108
+ ax = panel_top .subplots ()
107
109
ax .axis ("off" )
108
110
111
+ # panel_bottom, panel_cbar = panel_bottom.subfigures(ncols=2, width_ratios=[5, 0.5])
112
+ cov_axes = panel_bottom .subplots (
113
+ nrows = 2 , ncols = 2 , sharey = True , sharex = True
114
+ )
115
+ # cax = panel_cbar.add_subplot(3, 1, 2)
116
+
117
+ # get spike data and determine channel set by plotting
109
118
sp = gmm .random_spike_data (unit_id , max_size = self .n_show , with_reconstructions = True )
110
119
maa = sp .waveforms .abs ().nan_to_num ().max ()
120
+ geomplot_kw = dict (
121
+ max_abs_amp = maa ,
122
+ geom = gmm .data .prgeom .numpy (force = True ),
123
+ show_zero = False ,
124
+ return_chans = True ,
125
+ )
111
126
lines , chans = geomplot (
112
127
sp .waveforms ,
113
128
channels = sp .channels ,
114
- geom = gmm .data .prgeom .numpy (force = True ),
115
- max_abs_amp = maa ,
116
129
color = "k" ,
117
130
alpha = self .alpha ,
118
- return_chans = True ,
119
- show_zero = False ,
120
131
ax = ax ,
132
+ ** geomplot_kw ,
121
133
)
122
134
chans = torch .tensor (list (chans ))
123
135
tup = gaussian_mixture .to_full_probe (
124
136
sp , weights = None , n_channels = gmm .data .n_channels , storage = None
125
137
)
126
138
features_full , weights_full , count_data , weights_normalized = tup
127
- emp_mean = torch .nanmean (features_full , dim = 0 )[:, chans ]
139
+ print (f"{ features_full .shape = } " )
140
+ feats = features_full [:, :, chans ]
141
+ n , r , c = feats .shape
142
+ emp_mean = torch .nanmean (feats , dim = 0 )
128
143
emp_mean = gmm .data .tpca .force_reconstruct (emp_mean .nan_to_num_ ())
129
-
130
144
model_mean = gmm .units [unit_id ].mean [:, chans ]
131
145
model_mean = gmm .data .tpca .force_reconstruct (model_mean )
132
146
@@ -142,6 +156,37 @@ def draw(self, panel, gmm, unit_id, axes=None):
142
156
ax .axis ("off" )
143
157
ax .set_title ("reconstructed mean and example inputs" )
144
158
159
+ # covariance vis
160
+ feats = features_full [:, :, gmm .units [unit_id ].channels ]
161
+ model_mean = gmm .units [unit_id ].mean [:, gmm .units [unit_id ].channels ]
162
+ n , r , c = feats .shape
163
+ emp_cov , nobs = spiketorch .nancov (feats .view (n , r * c ), return_nobs = True )
164
+ denom = nobs + gmm .units [unit_id ].prior_pseudocount
165
+ emp_cov = (nobs / denom ) * emp_cov
166
+ noise_cov = gmm .noise .marginal_covariance (channels = gmm .units [unit_id ].channels ).to_dense ()
167
+ m = model_mean .abs ().reshape (- 1 )
168
+ mmt = m [:, None ] @ m [None , :]
169
+ covs = (emp_cov , noise_cov , mmt )
170
+ vmax = max (c .abs ().max () for c in covs )
171
+ names = ("regemp" , "noise" , "|temptempT|" )
172
+ print (f"{ feats .shape = } { gmm .units [unit_id ].channels .shape = } " )
173
+ print (f"{ vmax = } " )
174
+ print (f"{ emp_cov .abs ().max ()= } " )
175
+ print (f"{ noise_cov .abs ().max ()= } " )
176
+ print (f"{ mmt .abs ().max ()= } " )
177
+ print (f"{ emp_cov .shape = } " )
178
+ print (f"{ noise_cov .shape = } " )
179
+ print (f"{ mmt .shape = } " )
180
+
181
+ for ax , cov , name in zip (cov_axes .flat , covs , names ):
182
+ vmax = cov .abs ().triu (diagonal = 1 )
183
+ vmax = vmax [vmax > 0 ].quantile (.95 )
184
+ im = ax .imshow (cov .numpy (force = True ), vmin = - vmax , vmax = vmax , cmap = plt .cm .seismic )
185
+ ax .axis ("off" )
186
+ ax .set_title (name , fontsize = "small" )
187
+ plt .colorbar (im , ax = ax , shrink = 0.5 )
188
+ # plt.colorbar(im, cax=cax, shrink=0.5)
189
+
145
190
146
191
class Likelihoods (GMMPlot ):
147
192
kind = "widescatter"
@@ -324,8 +369,8 @@ def draw(self, panel, gmm, unit_id, split_info=None):
324
369
325
370
class NeighborMeans (GMMPlot ):
326
371
kind = "merge"
327
- width = 3
328
- height = 4
372
+ width = 4
373
+ height = 3
329
374
330
375
def __init__ (self , n_neighbors = 5 ):
331
376
self .n_neighbors = n_neighbors
@@ -348,8 +393,8 @@ def draw(self, panel, gmm, unit_id):
348
393
349
394
class NeighborDistances (GMMPlot ):
350
395
kind = "merge"
351
- width = 3
352
- height = 3
396
+ width = 4
397
+ height = 2
353
398
354
399
def __init__ (self , n_neighbors = 5 , dist_vmax = 1.0 ):
355
400
self .n_neighbors = n_neighbors
0 commit comments