2
2
from typing import Optional
3
3
4
4
import linear_operator
5
+ from linear_operator import operators
5
6
from linear_operator .operators import CholLinearOperator
6
7
import torch
7
8
import torch .nn .functional as F
10
11
11
12
from ..util .noise_util import EmbeddedNoise
12
13
from .stable_features import SpikeFeatures , SpikeNeighborhoods
13
- from ..util import spiketorch
14
+ from ..util import spiketorch , more_operators
14
15
15
16
vecdot = torch .linalg .vecdot
16
17
@@ -29,7 +30,7 @@ def ppca_em(
29
30
mean_prior_pseudocount = 0.0 ,
30
31
show_progress = False ,
31
32
W_initialization = "svd" ,
32
- normalize = False ,
33
+ normalize = True ,
33
34
em_converged_atol = 0.1 ,
34
35
prior_var = 1.0 ,
35
36
cache_global_direct = True ,
@@ -145,6 +146,7 @@ def ppca_em(
145
146
break
146
147
if show_progress :
147
148
iters .set_description (f"PPCA[{ dmu = :.2g} , { dW = :.2g} ]" )
149
+ print (i , dmu , dW )
148
150
149
151
if normalize and any_missing and state ["W" ] is not None :
150
152
_ , _ , state ["W" ], state ["mu" ] = embed (
@@ -156,6 +158,7 @@ def ppca_em(
156
158
state ["W" ],
157
159
state ["mu" ],
158
160
active_channels = active_channels ,
161
+ ess = ess ,
159
162
active_cov_chol_factor = active_cov_chol_factor ,
160
163
prior_var = prior_var ,
161
164
normalize = normalize ,
@@ -216,18 +219,18 @@ def ppca_e_step(
216
219
D = rank * nc
217
220
218
221
# get normalized weights
219
- y = sp .features
222
+ new_zeros = sp .features . new_zeros
220
223
221
224
# we will build our outputs by iterating over the unique
222
225
# neighborhoods and adding weighted sums of moments in each
223
- e_y = y . new_zeros ((rank , nc ))
226
+ e_y = new_zeros ((rank , nc ))
224
227
yc = e_u = e_ycu = e_uu = None
225
228
if return_yc :
226
- yc = y . new_zeros ((n , rank , nc ))
229
+ yc = new_zeros ((n , rank , nc ))
227
230
if yes_pca :
228
- e_u = y . new_zeros ((M ,))
229
- e_ycu = y . new_zeros ((rank , nc , M ))
230
- e_uu = y . new_zeros ((M , M ))
231
+ e_u = new_zeros ((M ,))
232
+ e_ycu = new_zeros ((rank , nc , M ))
233
+ e_uu = new_zeros ((M , M ))
231
234
232
235
# helpful tensors to keep around
233
236
if yes_pca :
@@ -257,8 +260,8 @@ def ppca_e_step(
257
260
W_m = active_W [:, nd .missing_subset ].reshape (D - nd .D_neighb , M )
258
261
259
262
if yes_pca :
260
- ubar = full_ubar [nd .neighb_members ]
261
- uubar = full_uubar [nd .neighb_members ]
263
+ ubar = full_ubar [nd .u_slice ]
264
+ uubar = full_uubar [nd .u_slice ]
262
265
263
266
# actual data in neighborhood
264
267
xcc = xc = nd .x - nu
@@ -312,7 +315,7 @@ def ppca_e_step(
312
315
wx = nd .w_norm @ nd .x
313
316
if nd .have_missing :
314
317
wxbar_m = nd .w_norm @ xbar_m
315
- ybar = y . new_zeros ((rank , nc ))
318
+ ybar = new_zeros ((rank , nc ))
316
319
ybar [:, nd .active_subset ] = wx .view (rank , nd .neighb_nc )
317
320
ybar [:, nd .missing_subset ] = wxbar_m .view (rank , nc - nd .neighb_nc )
318
321
else :
@@ -324,7 +327,7 @@ def ppca_e_step(
324
327
if nd .have_missing and yes_pca :
325
328
wmxcu = nd .w_norm @ e_mxcu .reshape (nd .neighb_n_spikes , - 1 )
326
329
wmxcu = wmxcu .view (e_mxcu .shape [1 :])
327
- ycubar = y . new_zeros ((rank , nc , M ))
330
+ ycubar = new_zeros ((rank , nc , M ))
328
331
ycubar [:, nd .active_subset ] = wxcu .view (rank , nd .neighb_nc , M )
329
332
ycubar [:, nd .missing_subset ] = wmxcu .view (rank , nc - nd .neighb_nc , M )
330
333
elif yes_pca :
@@ -334,12 +337,12 @@ def ppca_e_step(
334
337
if return_yc :
335
338
if nd .have_missing :
336
339
xc = xc .view (nd .neighb_n_spikes , rank , nd .neighb_nc ).mT
337
- yc [nd .neighb_members [:, None ] , :, nd .active_subset [None , :]] = xc
340
+ yc [nd .u_slice ][: , :, nd .active_subset [None , :]] = xc
338
341
xbar_m -= tnu
339
342
txc = xbar_m .view (nd .neighb_n_spikes , rank , nc - nd .neighb_nc ).mT
340
- yc [nd .neighb_members [:, None ] , :, nd .missing_subset [None , :]] = txc
343
+ yc [nd .u_slice ][: , :, nd .missing_subset [None , :]] = txc
341
344
else :
342
- yc [nd .neighb_members ] = xc .view (nd .neighb_n_spikes , rank , nd .neighb_nc )
345
+ yc [nd .u_slice ] = xc .view (nd .neighb_n_spikes , rank , nd .neighb_nc )
343
346
344
347
# accumulate results
345
348
e_y += ybar
@@ -367,14 +370,18 @@ def embed(
367
370
scratch = None ,
368
371
):
369
372
N = len (sp )
373
+ new_zeros = sp .features .new_zeros
374
+ device = sp .features .device
375
+ dtype = sp .features .dtype
376
+
370
377
if scratch is not None :
371
378
_ubar , _uubar = scratch
372
379
else :
373
- _ubar = sp . features .new_zeros ((N , M ))
380
+ _ubar = features .new_zeros ((N , M ))
374
381
# if not normalize:
375
- _uubar = sp . features .new_zeros (N , M , M )
376
- # _T = sp.features.new_zeros((N, M, M) )
377
- eye_M = prior_var * torch . eye ( M , device = sp . features . device , dtype = sp . features . dtype )
382
+ _uubar = features .new_zeros (N , M , M )
383
+ eye_M_ = torch . eye ( M , device = device , dtype = dtype )
384
+ eye_M = prior_var * eye_M_
378
385
379
386
for nd in neighb_data :
380
387
nu = active_mean [:, nd .active_subset ].reshape (nd .D_neighb )
@@ -388,14 +395,22 @@ def embed(
388
395
# moments of embeddings
389
396
# T_inv = eye_M + W_o.T @ nd.C_oo_chol.solve(W_o)
390
397
T_inv = eye_M + W_o .T @ nd .C_oo_inv @ W_o
391
- T = torch .linalg .inv (T_inv )
392
- u_proj = nd .C_oo_inv @ W_o @ T
398
+ # root = operators.LowRankRootLinearOperator(W_o.T @ nd.C_oo_cholinv)
399
+ # print(f"{root.shape=} {I_M.shape=}")
400
+ # helper = root + I_M
401
+ # helper = operators.LowRankRootSumLinearOperator(I_M
402
+ # print(f"{T_inv.shape=}")
403
+ # T = helper.solve(eye_M_)
404
+ T , info = torch .linalg .inv_ex (T_inv )
405
+ u_proj = nd .C_oo_inv @ (W_o @ T )
393
406
# ubar = Cooinvxc @ (W_o @ T)
394
- ubar = xc @ u_proj
395
- uubar = torch .baddbmm (T , ubar [:, :, None ], ubar [:, None , :])
407
+ # ubar = xc @ u_proj
408
+ # uubar = torch.baddbmm(T, ubar[:, :, None], ubar[:, None, :])
396
409
397
- _ubar [nd .neighb_members ] = ubar
398
- _uubar [nd .neighb_members ] = uubar
410
+ # _ubar[nd.u_slice] = ubar
411
+ # _uubar[nd.u_slice] = uubar
412
+ torch .mm (xc , u_proj , out = _ubar [nd .u_slice ])
413
+ torch .baddbmm (T , _ubar [nd .u_slice ].unsqueeze (2 ), _ubar [nd .u_slice ].unsqueeze (1 ), out = _uubar [nd .u_slice ])
399
414
400
415
if normalize :
401
416
if active_cov_chol_factor is None :
@@ -412,14 +427,14 @@ def embed(
412
427
# active_mean = active_mean + W @ um
413
428
414
429
# whitening. need to do a GEVP to start...
415
- S = (weights @ _uubar .view (N , M * M )).view (N , M , M )
430
+ S = (weights @ _uubar .view (N , M * M )).view (M , M )
416
431
Dx , U = torch .linalg .eigh (S )
417
432
Dx = Dx .flip (dims = (0 ,))
418
433
U = U .flip (dims = (1 ,))
419
434
U .mul_ (sgn (U [0 ]))
420
435
UDxrt = U * Dx .sqrt ()
421
436
rhs = Wflat @ UDxrt .T
422
- gevp_W_right = torch .linalg .solve_triangular (active_cov_chol_factor , rhs )
437
+ gevp_W_right = torch .linalg .solve_triangular (active_cov_chol_factor , rhs , upper = False )
423
438
gevp_W = gevp_W_right .T @ gevp_W_right
424
439
# gevp_W = linear_operator.solve(lhs=rhs.T, input=active_cov, rhs=rhs)
425
440
Dw , V = torch .linalg .eigh (gevp_W )
@@ -435,7 +450,8 @@ def embed(
435
450
W @= W_tf
436
451
_ubar @= u_tf
437
452
_uubar = torch .einsum ("nij,ip,jq->npq" , _uubar , u_tf , u_tf )
438
- active_mean .addmm_ (W , um )
453
+ active_mean += W @ um
454
+ # .addmm_(W.view(-1, M), um.unsqueeze(1))
439
455
440
456
return _ubar , _uubar , W , active_mean
441
457
@@ -449,11 +465,13 @@ class NeighborhoodPPCAData:
449
465
450
466
C_oo : linear_operator .LinearOperator
451
467
C_oo_chol : CholLinearOperator
468
+ C_oo_cholinv : torch .Tensor
452
469
C_oo_inv : CholLinearOperator
453
470
w : torch .Tensor
454
471
w_norm : torch .Tensor
455
472
x : torch .Tensor
456
473
neighb_members : torch .Tensor
474
+ u_slice : torch .Tensor
457
475
458
476
C_mo : Optional [torch .Tensor ]
459
477
active_subset : Optional [torch .Tensor ]
@@ -516,14 +534,15 @@ def get_neighborhood_data(
516
534
517
535
neighborhood_data = []
518
536
ess = weights .sum ()
537
+ n_start = 0
519
538
for chans_tuple , chans_data in dedup_data .items ():
520
539
* info , xs , mems = chans_data
521
540
nid , neighb_chans , active_subset , can_cache_by_neighborhood , have_missing = info
522
541
if len (mems ) > 1 :
523
542
x = torch .concatenate (xs )
524
543
neighb_members = torch .concatenate (mems )
525
- neighb_members , order = neighb_members .sort ()
526
- x = x [order ]
544
+ # neighb_members, order = neighb_members.sort()
545
+ # x = x[order]
527
546
nid = None
528
547
else :
529
548
x = xs [0 ]
@@ -578,16 +597,19 @@ def get_neighborhood_data(
578
597
have_missing = have_missing ,
579
598
C_oo = C_oo ,
580
599
C_oo_chol = C_oo_chol ,
600
+ C_oo_cholinv = Linv ,
581
601
C_oo_inv = C_oo_inv ,
582
602
w = w ,
583
603
w_norm = w / ess ,
584
604
x = x ,
585
605
neighb_members = neighb_members ,
606
+ u_slice = slice (n_start , n_start + n_neighb ),
586
607
C_mo = C_mo ,
587
608
active_subset = active_subset ,
588
609
missing_subset = missing_subset ,
589
610
)
590
611
neighborhood_data .append (nd )
612
+ n_start += n_neighb
591
613
592
614
return neighborhood_data
593
615
0 commit comments