diff --git a/doc/references.bib b/doc/references.bib
index 24b48327..aa79d595 100644
--- a/doc/references.bib
+++ b/doc/references.bib
@@ -238,4 +238,15 @@ @article{StamEtAl2012
   year={2012},
   month={Sep},
   pages={1415–1428}
+}
+
+@inproceedings{yang_state-space_2016,
+  title = {A state-space model of cross-region dynamic connectivity in {MEG}/{EEG}},
+  volume = {29},
+  url = {https://proceedings.neurips.cc/paper/2016/hash/9f396fe44e7c05c16873b05ec425cbad-Abstract.html},
+  urldate = {2021-11-21},
+  booktitle = {Advances in {Neural} {Information} {Processing} {Systems}},
+  publisher = {Curran Associates, Inc.},
+  author = {Yang, Ying and Aminoff, Elissa and Tarr, Michael and Kass, Robert E.},
+  year = {2016}
+}
\ No newline at end of file
diff --git a/state_space/megssm/__init__.py b/state_space/megssm/__init__.py
new file mode 100755
index 00000000..e69de29b
diff --git a/state_space/megssm/label_util.py b/state_space/megssm/label_util.py
new file mode 100644
index 00000000..c56484ab
--- /dev/null
+++ b/state_space/megssm/label_util.py
@@ -0,0 +1,449 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import mne
+import numpy as np
+import os
+
+from megssm.mne_util import combine_medial_labels
+
+subjects_dir = mne.utils.get_subjects_dir()
+rtpj_modes = ('hcp', 'labsn', 'intersect')
+
+label_shortnames = {'Early Auditory Cortex-lh': 'AUD-lh',
+                    'Early Auditory Cortex-rh': 'AUD-rh',
+                    'Premotor Cortex-lh': 'FEF-lh',
+                    'Premotor Cortex-rh': 'FEF-rh',
+                    'lh.IPS-labsn-lh': 'IPS-lh',
+                    'rh.IPS-labsn-rh': 'IPS-rh',
+                    'lh.LIPSP-lh': 'LIPSP',
+                    'rh.RTPJ-rh': 'RTPJ',
+                    'rh.RTPJIntersect-rh-rh': 'RTPJ-intersect',
+                    'Primary Visual Cortex (V1)-lh + Primary Visual Cortex (V1)-rh + Early Visual Cortex-lh + Early Visual Cortex-rh': 'Vis',
+                    'Anterior Cingulate and Medial Prefrontal Cortex-lh + Anterior Cingulate and Medial Prefrontal Cortex-rh': 'ACC',
+                    'DorsoLateral Prefrontal Cortex-lh': 'DLPFC-lh',
+                    'DorsoLateral Prefrontal Cortex-rh': 'DLPFC-rh',
+                    'Temporo-Parieto-Occipital Junction-lh': 'TPOJ-lh',
+                    'Temporo-Parieto-Occipital Junction-rh': 'TPOJ-rh'
+                    }
+
+
+def _sps_meglds_base():
+    hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage',
+                                                 parc='HCPMMP1_combined')
+    hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels)
+    label_names = [l.name for l in hcp_mmp1_labels]
+
+    ips_str = os.path.join(subjects_dir, "fsaverage/label/*.IPS-labsn.label")
+    ips_fnames = glob.glob(ips_str)
+    assert len(ips_fnames) == 2, ips_fnames
+    ips_labels = [mne.read_label(fn, subject='fsaverage') for fn in ips_fnames]
+
+    pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name]
+    eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name]
+
+    labels = list()
+    labels.extend(pmc_labs)
+    labels.extend(eac_labs)
+    labels.extend(ips_labels)
+
+    rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJ.label')
+    rtpj = mne.read_label(rtpj_str, subject='fsaverage')
+    labels.append(rtpj)
+
+    lipsp_str = os.path.join(subjects_dir, 'fsaverage/label/lh.LIPSP.label')
+    lipsp = mne.read_label(lipsp_str, subject='fsaverage')
+    labels.append(lipsp)
+
+    return sorted(labels, key=lambda x: x.name), hcp_mmp1_labels
+
+
+def sps_meglds_base():
+    return _sps_meglds_base()[0]
+
+
+def _sps_meglds_base_vision():
+
+    labels, hcp_mmp1_labels = _sps_meglds_base()
+
+    prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name]
+
+    # there should be only one b/c of medial merge
+    prim_visual = prim_visual[0]
+
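+    # mne Label objects support "+", which concatenates their vertices, so
+    # the sums below produce one combined bilateral visual label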
+ label_names = [l.name for l in hcp_mmp1_labels] + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + visual = prim_visual + early_visual_lh + early_visual_rh + + labels.append(visual) + + return sorted(labels, key=lambda x: x.name), hcp_mmp1_labels + + +def sps_meglds_base_vision(): + return _sps_meglds_base_vision()[0] + + +def sps_meglds_base_vision_extra(): + + labels, hcp_mmp1_labels = _sps_meglds_base_vision() + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + # glasser 22 + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + return sorted(labels, key=lambda x: x.name) + + +def sps_meglds_base_extra(): + labels, hcp_mmp1_labels = _sps_meglds_base() + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + # glasser 22 + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + + return sorted(labels, key=lambda x: x.name) + + + +def load_labsn_7_labels(): + label_str = os.path.join(subjects_dir, "fsaverage/label/*labsn*") + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJ.label') + label_fnames = glob.glob(label_str) + assert len(label_fnames) == 6 + label_fnames.insert(0, rtpj_str) + labels = [mne.read_label(fn, subject='fsaverage') for fn in label_fnames] + labels = sorted(labels, key=lambda x: x.name) + + return labels + + +def load_hcpmmp1_combined(): + + labels = mne.read_labels_from_annot('fsaverage', parc='HCPMMP1_combined') + labels = sorted(labels, key=lambda x: x.name) + labels = combine_medial_labels(labels) + + return labels + + +def load_labsn_hcpmmp1_7_labels(include_visual=False, rtpj_mode='intersect'): + + if rtpj_mode not in rtpj_modes: + raise ValueError("rtpj must be one of", rtpj_modes) + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + ips_str = os.path.join(subjects_dir, "fsaverage/label/*.IPS-labsn.label") + ips_fnames = glob.glob(ips_str) + assert len(ips_fnames) == 2 + ips_labels = [mne.read_label(fn, subject='fsaverage') for fn in ips_fnames] + + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + + labels = list() + labels.extend(pmc_labs) + labels.extend(eac_labs) + labels.extend(ips_labels) + + # this is in place of original rtpj + #ipc_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex' in l.name] + if rtpj_mode == 'hcp': + rtpj = [l for l in hcp_mmp1_labels + if 'Inferior Parietal Cortex' in l.name and l.hemi == 'rh'] + rtpj = rtpj[0] + elif rtpj_mode == 'labsn': + #rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJAnatomical-rh.label') + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJ.label') + rtpj = mne.read_label(rtpj_str, subject='fsaverage') + #tmp = [l for l in ipc_labs if l.hemi == 'lh'] + [rtpj] + #ipc_labs = tmp + elif rtpj_mode == 'intersect': + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJIntersect-rh.label') + rtpj = 
mne.read_label(rtpj_str, subject='fsaverage') + + #tmp = [l for l in ipc_labs if l.hemi == 'lh'] + [rtpj_hcp] + #ipc_labs = tmp + + labels.append(rtpj) + + #labels.extend(ipc_labs) + + # optionally include early visual regions as controls + if include_visual: + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + visual = prim_visual + early_visual_lh + early_visual_rh + + labels.append(visual) + + return labels + + +def load_labsn_hcpmmp1_7_rtpj_hcp_plus_vision_labels(): + return load_labsn_hcpmmp1_7_labels(include_visual=True, rtpj_mode='hcp') + + +def load_labsn_hcpmmp1_7_rtpj_intersect_plus_vision_labels(): + return load_labsn_hcpmmp1_7_labels(include_visual=True, rtpj_mode='intersect') + + +def load_labsn_hcpmmp1_7_rtpj_sphere_plus_vision_labels(): + return load_labsn_hcpmmp1_7_labels(include_visual=True, rtpj_mode='labsn') + + +def load_labsn_hcpmmp1_av_rois_small(): + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + #prim_visual_lh = label_names.index("Primary Visual Cortex (V1)-lh") + #prim_visual_rh = label_names.index("Primary Visual Cortex (V1)-rh") + #prim_visual_lh = hcp_mmp1_labels[prim_visual_lh] + #prim_visual_rh = hcp_mmp1_labels[prim_visual_rh] + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + + #visual_lh = prim_visual_lh + early_visual_lh + #visual_rh = prim_visual_rh + early_visual_rh + + visual = prim_visual + early_visual_lh + early_visual_rh + labels = [visual] + + #labels = [visual_lh, visual_rh] + + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + labels.extend(eac_labs) + + tpo_labs = [l for l in hcp_mmp1_labels if 'Temporo-Parieto-Occipital Junction' in l.name] + labels.extend(tpo_labs) + + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + ## extra labels KC wanted + #pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + #labels.extend(pmc_labs) + + #ips_str = glob.glob(os.path.join(subjects_dir, "fsaverage/label/*IPS*labsn*")) + #ips_labs = [mne.read_label(fn, subject='fsaverage') for fn in ips_str] + #labels.extend(ips_labs) + + #rtpj_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex-rh' in l.name] + #labels.extend(rtpj_labs) + + return labels + + +def load_labsn_hcpmmp1_av_rois_large(): + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + #prim_visual_lh = label_names.index("Primary Visual Cortex (V1)-lh") + #prim_visual_rh = label_names.index("Primary Visual Cortex (V1)-rh") + #prim_visual_lh = hcp_mmp1_labels[prim_visual_lh] + #prim_visual_rh = 
hcp_mmp1_labels[prim_visual_rh] + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + + #visual_lh = prim_visual_lh + early_visual_lh + #visual_rh = prim_visual_rh + early_visual_rh + + visual = prim_visual + early_visual_lh + early_visual_rh + labels = [visual] + + #labels = [visual_lh, visual_rh] + + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + labels.extend(eac_labs) + + tpo_labs = [l for l in hcp_mmp1_labels if 'Temporo-Parieto-Occipital Junction' in l.name] + labels.extend(tpo_labs) + + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + # extra labels KC wanted + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + labels.extend(pmc_labs) + + #ips_str = glob.glob(os.path.join(subjects_dir, "fsaverage/label/*IPS*labsn*")) + #ips_labs = [mne.read_label(fn, subject='fsaverage') for fn in ips_str] + #labels.extend(ips_labs) + + #rtpj_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex-rh' in l.name] + #labels.extend(rtpj_labs) + + # glasser 19 + ac_mpc_labs = [l for l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + return labels + + +def load_labsn_hcpmmp1_av_rois_large_plus_IPS(): + + hcp_mmp1_labels = mne.read_labels_from_annot('fsaverage', + parc='HCPMMP1_combined') + hcp_mmp1_labels = combine_medial_labels(hcp_mmp1_labels) + label_names = [l.name for l in hcp_mmp1_labels] + + #prim_visual_lh = label_names.index("Primary Visual Cortex (V1)-lh") + #prim_visual_rh = label_names.index("Primary Visual Cortex (V1)-rh") + #prim_visual_lh = hcp_mmp1_labels[prim_visual_lh] + #prim_visual_rh = hcp_mmp1_labels[prim_visual_rh] + prim_visual = [l for l in hcp_mmp1_labels if 'Primary Visual Cortex' in l.name] + + # there should be only one b/c of medial merge + prim_visual = prim_visual[0] + + early_visual_lh = label_names.index("Early Visual Cortex-lh") + early_visual_rh = label_names.index("Early Visual Cortex-rh") + early_visual_lh = hcp_mmp1_labels[early_visual_lh] + early_visual_rh = hcp_mmp1_labels[early_visual_rh] + + #visual_lh = prim_visual_lh + early_visual_lh + #visual_rh = prim_visual_rh + early_visual_rh + + visual = prim_visual + early_visual_lh + early_visual_rh + labels = [visual] + + #labels = [visual_lh, visual_rh] + + eac_labs = [l for l in hcp_mmp1_labels if 'Early Auditory Cortex' in l.name] + labels.extend(eac_labs) + + tpo_labs = [l for l in hcp_mmp1_labels if 'Temporo-Parieto-Occipital Junction' in l.name] + labels.extend(tpo_labs) + + dpc_labs = [l for l in hcp_mmp1_labels if 'DorsoLateral Prefrontal Cortex' in l.name] + labels.extend(dpc_labs) + + # extra labels KC wanted + pmc_labs = [l for l in hcp_mmp1_labels if 'Premotor Cortex' in l.name] + labels.extend(pmc_labs) + + ips_str = os.path.join(subjects_dir, "fsaverage/label/*.IPS-labsn.label") + ips_fnames = glob.glob(ips_str) + ips_labels = [mne.read_label(fn, subject='fsaverage') for fn in ips_fnames] + labels.extend(ips_labels) + + #rtpj_labs = [l for l in hcp_mmp1_labels if 'Inferior Parietal Cortex-rh' in l.name] + #labels.extend(rtpj_labs) + + # glasser 19 + ac_mpc_labs = [l for 
l in hcp_mmp1_labels if 'Anterior Cingulate and Medial Prefrontal Cortex' in l.name] + labels.extend(ac_mpc_labs) + + return labels + + +def make_rtpj_intersect(): + labels = mne.read_labels_from_annot('fsaverage', 'HCPMMP1', 'rh', + subjects_dir=subjects_dir) + + rtpj_str = os.path.join(subjects_dir, 'fsaverage/label/rh.RTPJAnatomical-rh.label') + rtpj = mne.read_label(rtpj_str, subject='fsaverage') + src = mne.read_source_spaces(subjects_dir + '/fsaverage/bem/fsaverage-5-src.fif') + rtpj = rtpj.fill(src) + + mne.write_label(os.path.join(subjects_dir, + 'fsaverage/label/rh.RTPJ.label'), + rtpj) + + props = np.zeros((len(labels), 2)) + for li, label in enumerate(labels): + props[li] = [np.in1d(rtpj.vertices, label.vertices).mean(), + np.in1d(label.vertices, rtpj.vertices).mean()] + order = np.argsort(props[:, 0])[::-1] + for oi in order: + if props[oi, 0] > 0: + name = labels[oi].name.rstrip('-rh').lstrip('R_') + print('%4.1f%% RTPJ vertices cover %4.1f%% of %s' + % (100*props[oi,0], 100*props[oi,1], name)) + + for ii, oi in enumerate(order[:4]): + if ii == 0: + rtpj = labels[oi].copy() + else: + rtpj += labels[oi] + + mne.write_label(os.path.join(subjects_dir, + 'fsaverage/label/rh.RTPJIntersect-rh.label'), + rtpj) + + +def fixup_lipsp(): + labels = mne.read_labels_from_annot('fsaverage', 'HCPMMP1', 'rh', + subjects_dir=subjects_dir) + + lipsp_str = os.path.join(subjects_dir, 'fsaverage/label/lh.LIPSP_tf.label') + lipsp = mne.read_label(lipsp_str, subject='fsaverage') + lipsp.vertices = lipsp.vertices[lipsp.vertices < 10242] + + src = mne.read_source_spaces(subjects_dir + '/fsaverage/bem/fsaverage-5-src.fif') + lipsp = lipsp.fill(src) + + + mne.write_label(os.path.join(subjects_dir, 'fsaverage/label/lh.LIPSP.label'), + lipsp) + + return lipsp + + +#if __name__ == "__main__": +# +# from surfer import Brain +# labels = sps_meglds_base() +# +# subject_id = 'fsaverage' +# hemi = 'both' +# surf = 'inflated' +# +# brain = Brain(subject_id, hemi, surf) +# for l in labels: +# brain.add_label(l) diff --git a/state_space/megssm/message_passing.py b/state_space/megssm/message_passing.py new file mode 100755 index 00000000..6295059a --- /dev/null +++ b/state_space/megssm/message_passing.py @@ -0,0 +1,295 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import autograd.numpy as np +from autograd.scipy.linalg import block_diag + +from .util import sym, component_matrix, hs + +try: + from autograd_linalg import solve_triangular +except ImportError: + raise RuntimeError("must install `autograd_linalg` package") + +from numpy import einsum + +def rts_smooth(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False, + store_St=True): + """ RTS smoother that broadcasts over the first dimension. + Handles multiple lag dependence using component form. + + Note: This function doesn't handle control inputs (yet). 
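+
+    The VAR(nlags) dynamics are handled internally by rewriting them in
+    companion ("component") form via `component_matrix`, so filtering and
+    smoothing run as a VAR(1) model on the stacked state of the last
+    nlags time points. A minimal call sketch (hypothetical shapes: D=2
+    latent ROIs, nlags=2, p=5 sensors, so A has shape (T, 2, 4)):
+
+        ll, mus, sigmas, sig_tnt, St = rts_smooth(Y, A, C, Q, R, mu0, Q0,
+                                                  compute_lag1_cov=True)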
+
+    Y : ndarray, shape=(N, T, p)
+      Observations
+
+    A : ndarray, shape=(T, D, D*nlags)
+      Time-varying dynamics matrices (top block row of the companion form)
+
+    C : ndarray, shape=(p, D)
+      Observation matrix
+
+    mu0: ndarray, shape=(D,)
+      mean of initial state variable
+
+    Q0 : ndarray, shape=(D, D)
+      Covariance of initial state variable
+
+    Q : ndarray, shape=(D, D)
+      Covariance of latent states
+
+    R : ndarray, shape=(p, p)
+      Covariance of observations
+    """
+
+    N, T, _ = Y.shape
+    _, D, Dnlags = A.shape
+    nlags = Dnlags // D
+    AA = np.stack([component_matrix(At, nlags) for At in A], axis=0)
+
+    p = C.shape[0]
+    CC = hs([C, np.zeros((p, D*(nlags-1)))])
+
+    QQ = np.zeros((T, Dnlags, Dnlags))
+    QQ[:,:D,:D] = Q
+
+    QQ0 = block_diag(*[Q0 for _ in range(nlags)])
+
+    mu_predict = np.empty((N, T+1, Dnlags))
+    sigma_predict = np.empty((N, T+1, Dnlags, Dnlags))
+
+    mus_smooth = np.empty((N, T, Dnlags))
+    sigmas_smooth = np.empty((N, T, Dnlags, Dnlags))
+
+    St = np.empty((N, T, p, p)) if store_St else None
+
+    if compute_lag1_cov:
+        sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags))
+    else:
+        sigmas_smooth_tnt = None
+
+    ll = 0.
+    mu_predict[:,0,:] = np.tile(mu0, nlags)
+    sigma_predict[:,0,:,:] = QQ0.copy()
+
+    for t in range(T):
+
+        # condition
+        # sigma_x = dot3(C, sigma_predict, C.T) + R
+        tmp1 = einsum('ik,nkj->nij', CC, sigma_predict[:,t,:,:])
+        sigma_x = einsum('nik,jk->nij', tmp1, CC) + R
+        sigma_x = sym(sigma_x)
+
+        if St is not None:
+            St[...,t,:,:] = sigma_x
+
+        L = np.linalg.cholesky(sigma_x)
+        # res[n] = Y[n,t,:] - np.dot(C, mu_predict[n,t,:])
+        res = Y[...,t,:] - einsum('ik,nk->ni', CC, mu_predict[...,t,:])
+        v = solve_triangular(L, res, lower=True)
+
+        # log-likelihood over all trials
+        ll += -0.5*(2.*np.sum(np.log(np.diagonal(L, axis1=1, axis2=2)))
+                    + np.sum(v*v)
+                    + N*p*np.log(2.*np.pi))
+
+        mus_smooth[:,t,:] = mu_predict[:,t,:] + \
+            einsum('nki,nk->ni', tmp1,
+                   solve_triangular(L, v, trans='T', lower=True))
+
+        # tmp2 = L^{-1}*C*sigma_predict
+        tmp2 = solve_triangular(L, tmp1, lower=True)
+        sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] -
+                                     einsum('nki,nkj->nij', tmp2, tmp2))
+
+        # prediction
+        #mu_predict = np.dot(A[t], mus_smooth[t])
+        mu_predict[:,t+1,:] = einsum('ik,nk->ni', AA[t], mus_smooth[:,t,:])
+
+        #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t]
+        tmp = einsum('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:])
+        sigma_predict[:,t+1,:,:] = sym(einsum('nil,jl->nij', tmp, AA[t]) +
+                                       QQ[t])
+
+    for t in range(T-2, -1, -1):
+
+        # these names are stolen from mattjj and slinderman
+        #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:])
+        temp_nn = einsum('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:])
+
+        L = np.linalg.cholesky(sigma_predict[:,t+1,:,:])
+        v = solve_triangular(L, temp_nn, lower=True)
+        # Look in Särkkä for dfn of Gt_T
+        Gt_T = solve_triangular(L, v, trans='T', lower=True)
+
+        # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're
+        # overwriting them on purpose
+        mus_smooth[:,t,:] = mus_smooth[:,t,:] + \
+            einsum('nki,nk->ni', Gt_T,
+                   mus_smooth[:,t+1,:] - mu_predict[:,t+1,:])
+
+        tmp = einsum('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] -
+                     sigma_predict[:,t+1,:,:])
+        tmp = einsum('nik,nkj->nij', tmp, Gt_T)
+        sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp)
+
+        if compute_lag1_cov:
+            # This matrix is NOT symmetric, so don't symmetrize!
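+            # the lag-one smoothed cross-covariance is
+            # Cov(x_{t+1}, x_t | Y) = Sigma_{t+1|T} @ Gt_T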
+            sigmas_smooth_tnt[:,t,:,:] = einsum('nik,nkj->nij',
+                                                sigmas_smooth[:,t+1,:,:],
+                                                Gt_T)
+
+    return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt, St
+
+
+def rts_smooth_fast(Y, A, C, Q, R, mu0, Q0, compute_lag1_cov=False):
+    """ RTS smoother that broadcasts over the first dimension.
+        Handles multiple lag dependence using component form.
+
+        Note: This function doesn't handle control inputs (yet).
+
+    Y : ndarray, shape=(N, T, p)
+      Observations
+
+    A : ndarray, shape=(T, D, D*nlags)
+      Time-varying dynamics matrices (top block row of the companion form)
+
+    C : ndarray, shape=(p, D)
+      Observation matrix
+
+    mu0: ndarray, shape=(D,)
+      mean of initial state variable
+
+    Q0 : ndarray, shape=(D, D)
+      Covariance of initial state variable
+
+    Q : ndarray, shape=(D, D)
+      Covariance of latent states
+
+    R : ndarray, shape=(p, p)
+      Covariance of observations
+    """
+
+    N, T, _ = Y.shape
+    _, D, Dnlags = A.shape
+    nlags = Dnlags // D
+    AA = np.stack([component_matrix(At, nlags) for At in A], axis=0)
+
+    L_R = np.linalg.cholesky(R)
+
+    p = C.shape[0]
+    CC = hs([C, np.zeros((p, D*(nlags-1)))])
+    tmp = solve_triangular(L_R, CC, lower=True)
+    Rinv_CC = solve_triangular(L_R, tmp, trans='T', lower=True)
+    CCT_Rinv_CC = einsum('ki,kj->ij', CC, Rinv_CC)
+
+    # tile L_R across number of trials so solve_triangular
+    # can broadcast over trials properly
+    L_R = np.tile(L_R, (N, 1, 1))
+
+    QQ = np.zeros((T, Dnlags, Dnlags))
+    QQ[:,:D,:D] = Q
+
+    QQ0 = block_diag(*[Q0 for _ in range(nlags)])
+
+    mu_predict = np.empty((N, T+1, Dnlags))
+    sigma_predict = np.empty((N, T+1, Dnlags, Dnlags))
+
+    mus_smooth = np.empty((N, T, Dnlags))
+    sigmas_smooth = np.empty((N, T, Dnlags, Dnlags))
+
+    if compute_lag1_cov:
+        sigmas_smooth_tnt = np.empty((N, T-1, Dnlags, Dnlags))
+    else:
+        sigmas_smooth_tnt = None
+
+    ll = 0.
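+    # unlike `rts_smooth`, the conditioning step below applies the matrix
+    # inversion (Woodbury) lemma so only state-sized systems and the
+    # pre-factored R are solved at each step, instead of factorizing a
+    # fresh p x p innovation covariance per time point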
+    mu_predict[:,0,:] = np.tile(mu0, nlags)
+    sigma_predict[:,0,:,:] = QQ0.copy()
+
+    I_tiled = np.tile(np.eye(Dnlags), (N, 1, 1))
+
+    for t in range(T):
+
+        # condition
+        tmp1 = einsum('ik,nkj->nij', CC, sigma_predict[:,t,:,:])
+
+        res = Y[...,t,:] - einsum('ik,nk->ni', CC, mu_predict[...,t,:])
+
+        # Rinv * res
+        tmp2 = solve_triangular(L_R, res, lower=True)
+        tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True)
+
+        # C^T Rinv * res
+        tmp3 = einsum('ki,nk->ni', Rinv_CC, res)
+
+        # (Pinv + C^T Rinv C)_inv * tmp3
+        L_P = np.linalg.cholesky(sigma_predict[:,t,:,:])
+        tmp = solve_triangular(L_P, I_tiled, lower=True)
+        Pinv = solve_triangular(L_P, tmp, trans='T', lower=True)
+        tmp4 = sym(Pinv + CCT_Rinv_CC)
+        L_tmp4 = np.linalg.cholesky(tmp4)
+        tmp3 = solve_triangular(L_tmp4, tmp3, lower=True)
+        tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True)
+
+        # Rinv C * tmp3
+        tmp3 = einsum('ik,nk->ni', Rinv_CC, tmp3)
+
+        # add the two Woodbury * res terms together
+        tmp = tmp2 - tmp3
+
+        mus_smooth[:,t,:] = mu_predict[:,t,:] + einsum('nki,nk->ni', tmp1, tmp)
+
+        # Rinv * tmp1
+        tmp2 = solve_triangular(L_R, tmp1, lower=True)
+        tmp2 = solve_triangular(L_R, tmp2, trans='T', lower=True)
+
+        # C^T Rinv * tmp1
+        tmp3 = einsum('ki,nkj->nij', Rinv_CC, tmp1)
+
+        # (Pinv + C^T Rinv C)_inv * tmp3
+        tmp3 = solve_triangular(L_tmp4, tmp3, lower=True)
+        tmp3 = solve_triangular(L_tmp4, tmp3, trans='T', lower=True)
+
+        # Rinv C * tmp3
+        tmp3 = einsum('ik,nkj->nij', Rinv_CC, tmp3)
+
+        # add the two Woodbury * tmp1 terms together, left-multiply by tmp1
+        tmp = einsum('nki,nkj->nij', tmp1, tmp2 - tmp3)
+
+        sigmas_smooth[:,t,:,:] = sym(sigma_predict[:,t,:,:] - tmp)
+
+        # prediction
+        mu_predict[:,t+1,:] = einsum('ik,nk->ni', AA[t], mus_smooth[:,t,:])
+
+        #sigma_predict = dot3(A[t], sigmas_smooth[t], A[t].T) + Q[t]
+        tmp = einsum('ik,nkl->nil', AA[t], sigmas_smooth[:,t,:,:])
+        sigma_predict[:,t+1,:,:] = sym(einsum('nil,jl->nij', tmp, AA[t]) +
+                                       QQ[t])
+
+    for t in range(T-2, -1, -1):
+
+        # these names are stolen from mattjj and slinderman
+        #temp_nn = np.dot(A[t], sigmas_smooth[n,t,:,:])
+        temp_nn = einsum('ik,nkj->nij', AA[t], sigmas_smooth[:,t,:,:])
+
+        L = np.linalg.cholesky(sigma_predict[:,t+1,:,:])
+        v = solve_triangular(L, temp_nn, lower=True)
+        # Look in Särkkä for dfn of Gt_T
+        Gt_T = solve_triangular(L, v, trans='T', lower=True)
+
+        # {mus,sigmas}_smooth[n,t] contain the filtered estimates so we're
+        # overwriting them on purpose
+        mus_smooth[:,t,:] = mus_smooth[:,t,:] + \
+            einsum('nki,nk->ni', Gt_T,
+                   mus_smooth[:,t+1,:] - mu_predict[:,t+1,:])
+
+        tmp = einsum('nki,nkj->nij', Gt_T, sigmas_smooth[:,t+1,:,:] -
+                     sigma_predict[:,t+1,:,:])
+        tmp = einsum('nik,nkj->nij', tmp, Gt_T)
+        sigmas_smooth[:,t,:,:] = sym(sigmas_smooth[:,t,:,:] + tmp)
+
+        if compute_lag1_cov:
+            # This matrix is NOT symmetric, so don't symmetrize!
+ sigmas_smooth_tnt[:,t,:,:] = einsum('nik,nkj->nij', \ + sigmas_smooth[:,t+1,:,:], Gt_T) + + return ll, mus_smooth, sigmas_smooth, sigmas_smooth_tnt diff --git a/state_space/megssm/mne_util.py b/state_space/megssm/mne_util.py new file mode 100644 index 00000000..61b09165 --- /dev/null +++ b/state_space/megssm/mne_util.py @@ -0,0 +1,213 @@ +""" MNE-Python utility functions for preprocessing data and constructing + matrices necessary for MEGLDS analysis """ + +import mne +import numpy as np +from mne import label_sign_flip +from scipy.sparse import csc_matrix, csr_matrix +from sklearn.decomposition import PCA + +Carray = lambda X: np.require(X, dtype=np.float64, requirements='C') + +class ROIToSourceMap(object): + """ class for computing ROI-to-source space mapping matrix + + Notes + ----- + The following variables defined here correspond to various matrices + defined in :footcite:`yang_state-space_2016`: + - fwd_src_snsr : G + - fwd_roi_snsr : C + - fwd_src_roi : L + - snsr_cov : Q_e + - roi_cov : Q + - roi_cov_0 : Q0 """ + + def __init__(self, fwd, labels, label_flip=False): + + src = fwd['src'] + + roiidx = list() + vertidx = list() + + n_lhverts = len(src[0]['vertno']) + n_rhverts = len(src[1]['vertno']) + n_verts = n_lhverts + n_rhverts + offsets = {'lh': 0, 'rh': n_lhverts} + + # index vector of which ROI a source point belongs to + which_roi = np.zeros(n_verts, dtype=np.int64) + + data = [] + for li, lab in enumerate(labels): + + this_data = np.round(label_sign_flip(lab, src)) + if not label_flip: + this_data.fill(1.) + data.append(this_data) + if isinstance(lab, mne.Label): + comp_labs = [lab] + elif isinstance(lab, mne.BiHemiLabel): + comp_labs = [lab.lh, lab.rh] + + for clab in comp_labs: + hemi = clab.hemi + hi = 0 if hemi == 'lh' else 1 + + lverts = clab.get_vertices_used(vertices=src[hi]['vertno']) + + # gets the indices in the source space vertex array, not the + # huge array. + # use `src[hi]['vertno'][lverts]` to get surface vertex indices + # to plot. + lverts = np.searchsorted(src[hi]['vertno'], lverts) + lverts += offsets[hemi] + vertidx.extend(lverts) + roiidx.extend(np.full(lverts.size, li, dtype=np.int64)) + + # add 1 b/c 0 corresponds to unassigned variance + which_roi[lverts] = li + 1 + + N = len(labels) + M = n_verts + + # construct sparse fwd_src_roi matrix + data = np.concatenate(data) + vertidx = np.array(vertidx, int) + roiidx = np.array(roiidx, int) + assert data.shape == vertidx.shape == roiidx.shape + fwd_src_roi = csc_matrix((data, (vertidx, roiidx)), shape=(M, N)) + + self.fwd = fwd + self.fwd_src_roi = fwd_src_roi + self.which_roi = which_roi + self.offsets = offsets + self.n_lhverts = n_lhverts + self.n_rhverts = n_rhverts + self.labels = labels + + return + +def apply_projs(epochs, fwd, cov): + """ apply projection operators to fwd and cov """ + proj, _ = mne.io.proj.setup_proj(epochs.info, activate=False) + fwd_src_sn = fwd['sol']['data'] + fwd['sol']['data'] = np.dot(proj, fwd_src_sn) + + roi_cov = cov.data + if not np.allclose(np.dot(proj, roi_cov), roi_cov): + roi_cov = np.dot(proj, np.dot(roi_cov, proj.T)) + cov.data = roi_cov + + return fwd, cov + + +def _scale_sensor_data(epochs, fwd, cov, roi_to_src, eeg_scale=1., + mag_scale=1., grad_scale=1.): + """ apply per-channel-type scaling to epochs, forward, and covariance """ + + # get indices for each channel type + ch_names = cov['names'] + + # build scaler + info = epochs.info.copy() + std = dict(grad=1. / grad_scale, mag=1. / mag_scale, eeg=1. 
/ eeg_scale)
+    noproj_info = info.copy()
+    with noproj_info._unlock():
+        noproj_info['projs'] = []
+    rescale_cov = mne.make_ad_hoc_cov(noproj_info, std=std)
+    scaler, ch_names = mne.cov.compute_whitener(rescale_cov, noproj_info)
+    np.testing.assert_array_equal(np.diag(np.diag(scaler)), scaler)
+    assert ch_names == info['ch_names']
+
+    # retrieve forward and sensor covariance
+    fwd_src_snsr = fwd['sol']['data'].copy()
+    roi_cov = cov.data.copy()
+
+    # scale forward matrix
+    fwd_src_snsr = scaler @ fwd_src_snsr
+
+    # construct fwd_roi_snsr matrix
+    fwd_roi_snsr = Carray(csr_matrix.dot(roi_to_src.fwd_src_roi.T,
+                                         fwd_src_snsr.T).T)
+
+    # scale sensor covariance
+    roi_cov = scaler.T @ roi_cov @ scaler
+
+    # scale epochs
+    data = epochs.get_data().copy()
+    data = scaler.T @ data
+    epochs = mne.EpochsArray(data, info)
+
+    return fwd_src_snsr, fwd_roi_snsr, roi_cov, epochs
+
+
+def run_pca_on_subject(subject_name, epochs, fwd, cov, labels, dim_mode='rank',
+                       pctvar=0.99, mean_center=False, label_flip=False):
+    """ apply sensor scaling, PCA dimensionality reduction with/without
+        whitening, and mean-centering to subject data """
+
+    if dim_mode not in ['rank', 'pctvar', 'whiten']:
+        raise ValueError("dim_mode must be in {'rank', 'pctvar', 'whiten'}")
+
+    print("running pca for subject %s" % subject_name)
+
+    scales = {'eeg_scale': 1e8, 'mag_scale': 1e16, 'grad_scale': 1e14}
+
+    # compute ROI-to-source map
+    roi_to_src = ROIToSourceMap(fwd, labels, label_flip)
+
+    if dim_mode == 'whiten':
+
+        fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = \
+            _scale_sensor_data(epochs, fwd, cov, roi_to_src)
+        dat = epochs.get_data()
+        dat = Carray(np.swapaxes(dat, -1, -2))
+
+        if mean_center:
+            dat -= np.mean(dat, axis=1, keepdims=True)
+
+        dat_stacked = np.reshape(dat, (-1, dat.shape[-1]))
+
+        # NOTE: assumes whitening with this subject's sensor covariance and
+        # the current epochs' measurement info
+        W, _ = mne.cov.compute_whitener(cov, info=epochs.info, pca=True)
+        print("whitener for subject %s using %d principal components" %
+              (subject_name, W.shape[0]))
+
+    else:
+
+        fwd_src_snsr, fwd_roi_snsr, cov_snsr, epochs = _scale_sensor_data(
+            epochs, fwd, cov, roi_to_src, **scales)
+
+        dat = epochs.get_data().copy()
+        dat = Carray(np.swapaxes(dat, -1, -2))
+
+        if mean_center:
+            dat -= np.mean(dat, axis=1, keepdims=True)
+
+        dat_stacked = np.reshape(dat, (-1, dat.shape[-1]))
+        pca = PCA()
+        pca.fit(dat_stacked)
+
+        if dim_mode == 'rank':
+            idx = np.linalg.matrix_rank(np.cov(dat_stacked, rowvar=False))
+        else:
+            idx = np.where(np.cumsum(pca.explained_variance_ratio_) >
+                           pctvar)[0][0]
+
+        idx = np.maximum(idx, len(labels))
+        W = pca.components_[:idx]
+        print("subject %s using %d principal components" % (subject_name, idx))
+
+    ntrials, T, _ = dat.shape
+    dat_pca = np.dot(dat_stacked, W.T)
+    dat_pca = np.reshape(dat_pca, (ntrials, T, -1))
+
+    fwd_src_snsr_pca = np.dot(W, fwd_src_snsr)
+    fwd_roi_snsr_pca = np.dot(W, fwd_roi_snsr)
+    cov_snsr_pca = np.dot(W, np.dot(cov_snsr, W.T))
+
+    data = dat_pca
+
+    return (data, fwd_roi_snsr_pca, fwd_src_snsr_pca, cov_snsr_pca,
+            roi_to_src.which_roi)
diff --git a/state_space/megssm/models.py b/state_space/megssm/models.py
new file mode 100755
index 00000000..f6ab8038
--- /dev/null
+++ b/state_space/megssm/models.py
@@ -0,0 +1,888 @@
+import sys
+import mne
+
+import autograd.numpy as np
+import scipy.optimize as spopt
+
+from autograd import grad
+from autograd import value_and_grad as vgrad
+from scipy.linalg import LinAlgError
+
+from .util import _ensure_ndim, rand_stable, rand_psd
+from .util import linesearch, soft_thresh_At, block_thresh_At
+from .util import relnormdiff
+from .message_passing import rts_smooth, rts_smooth_fast
+from .numpy_numthreads import numpy_num_threads
+
+from .mne_util import run_pca_on_subject, apply_projs
+
+try:
+    from autograd_linalg import solve_triangular
+except ImportError:
+    raise RuntimeError("must install `autograd_linalg` package")
+
+from autograd.numpy import einsum
+
+from datetime import datetime
+
+
+class _Model(object):
+    """ Base class for any model applied to MEG data that handles storing and
+        unpacking data from tuples. """
+
+    def __init__(self):
+        self._subjectdata = None
+        self._n_timepts = 0
+        self._ntrials_all = 0
+        self._nsubjects = 0
+
+    def set_data(self, subjectdata):
+        n_timepts_lst = [self.unpack_subject_data(e)[0].shape[1] for e in
+                         subjectdata]
+        assert len(list(set(n_timepts_lst))) == 1
+        self._n_timepts = n_timepts_lst[0]
+        ntrials_lst = [self.unpack_subject_data(e)[0].shape[0] for e in
+                       subjectdata]
+        self._ntrials_all = np.sum(ntrials_lst)
+        self._nsubjects = len(subjectdata)
+        self._subjectdata = subjectdata
+
+    def unpack_all_subject_data(self):
+        if self._subjectdata is None:
+            raise ValueError("use set_data to add subject data")
+        return map(self.unpack_subject_data, self._subjectdata)
+
+    @classmethod
+    def unpack_subject_data(cls, sdata):
+        obs, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata
+        if isinstance(obs, tuple):
+            if len(obs) == 2:
+                Y, w_s = obs
+            else:
+                raise ValueError("invalid format for subject data")
+        else:
+            Y = obs
+            w_s = 1.
+
+        return Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi
+
+
+class LDS(_Model):
+    """ State-space model for MEG data, as described in "A state-space model
+    of cross-region dynamic connectivity in MEG/EEG", Yang et al., NIPS 2016.
+    """
+
+    def __init__(self, lam0=0., lam1=0., penalty='ridge', store_St=True):
+
+        super().__init__()
+        self._model_initialized = False
+        self.lam0 = lam0
+        self.lam1 = lam1
+
+        if penalty not in ('ridge', 'lasso', 'group-lasso'):
+            raise ValueError('penalty must be one of: ridge, lasso,'
+                             ' group-lasso')
+        self._penalty = penalty
+
+        # initialize lists of smoothed estimates
+        self._mus_smooth_lst = None
+        self._sigmas_smooth_lst = None
+        self._sigmas_tnt_smooth_lst = None
+        self._loglik = None
+        self._store_St = bool(store_St)
+
+        self._all_subject_data = list()
+
+    # SNR-boost epochs via bootstrapping (default: 3 bootstraps)
+    def bootstrap_subject(self, epochs, subject_name, seed=8675309, sfreq=100,
+                          lower=None, upper=None, nbootstrap=3, g_nsamples=-5,
+                          overwrite=False, validation_set=True):
+
+        datasets = ['train', 'validation']
+        eq = False
+        independent = False
+        if g_nsamples == 0:
+            print('nsamples == 0, ensuring independence of samples')
+            independent = True
+        elif g_nsamples == -1:
+            print("using half of trials per sample")
+        elif g_nsamples == -2:
+            print("using empty room noise at half of trials per sample")
+            # use_erm = True
+        elif g_nsamples == -3:
+            print("using independent and trial-count equalized samples")
+            eq = True
+            independent = True
+        elif g_nsamples == -4:
+            print("using independent, trial-count equalized, non-boosted "
+                  "samples")
+            assert nbootstrap == 0  # sanity check
+            eq = True
+            independent = True
+            datasets = ['train']
+        elif g_nsamples == -5:
+            print("using independent, trial-count equalized, integer boosted "
+                  "samples")
+            eq = True
+            independent = True
+            datasets = ['train']
+
+        if lower is not None or upper is not None:
+            if upper is None:
+                print('high-pass filtering at %.2f Hz' % lower)
+            elif lower is None:
+                print('low-pass filtering at %.2f Hz' % upper)
+            else:
+                print('band-pass filtering from %.2f-%.2f Hz' % (lower,
+                                                                 upper))
+
+        if sfreq is not None:
+            print('resampling to %.2f Hz' % sfreq)
+
+        print(":: processing subject %s" % subject_name)
+        np.random.seed(seed)
+
+        for dataset in datasets:
+
+            print(' generating ', dataset, ' set')
+            # datadir = './data'
+
+            condition_map = {'auditory_left': ['auditory_left'],
+                             'auditory_right': ['auditory_right'],
+                             'visual_left': ['visual_left'],
+                             'visual_right': ['visual_right']}
+            condition_eq_map = dict(auditory_left=['auditory_left'],
+                                    auditory_right=['auditory_right'],
+                                    visual_left=['visual_left'],
+                                    visual_right=['visual_right'])
+
+            cond_map = condition_map
+            if eq:
+                epochs.equalize_event_counts(list(condition_map))
+                cond_map = condition_eq_map
+
+            # apply band-pass filter to limit signal to desired frequency band
+            if lower is not None or upper is not None:
+                epochs = epochs.filter(lower, upper)
+
+            # perform resampling with specified sampling frequency
+            if sfreq is not None:
+                epochs = epochs.resample(sfreq)
+
+            data_bs_all = list()
+            events_bs_all = list()
+            for cond in sorted(cond_map.keys()):
+                print("    -> condition %s: bootstrapping" % cond, end='')
+                ep = epochs[cond_map[cond]]
+                dat = ep.get_data().copy()
+                ntrials, T, p = dat.shape
+
+                use_bootstrap = nbootstrap
+                if g_nsamples == -4:
+                    nsamples = 1
+                    use_bootstrap = ntrials
+                elif g_nsamples == -5:
+                    nsamples = nbootstrap
+                    use_bootstrap = ntrials // nsamples
+                elif independent:
+                    nsamples = (ntrials - 1) // use_bootstrap
+                elif g_nsamples in (-1, -2):
+                    nsamples = ntrials // 2
+                else:
+                    assert g_nsamples > 0
+                    nsamples = g_nsamples
+                print(" using %d samples (%d trials)"
+                      % (nsamples, use_bootstrap))
+
+                # bootstrap here
+                if independent:
+                    if nsamples == 1 and use_bootstrap == ntrials:
+                        inds = np.arange(ntrials)
+                    else:
+                        inds = np.random.choice(ntrials,
+                                                nsamples * use_bootstrap)
+                        inds.shape = (use_bootstrap, nsamples)
+                    dat_bs = np.mean(dat[inds], axis=1)
+                    events_bs = ep.events[inds[:, 0]]
+                    assert dat_bs.shape[0] == events_bs.shape[0]
+                else:
+                    dat_bs = np.empty((ntrials, T, p))
+                    events_bs = np.empty((ntrials, 3), dtype=int)
+                    for i in range(ntrials):
+
+                        inds = list(set(range(ntrials)).difference([i]))
+                        inds = np.random.choice(inds, size=nsamples,
+                                                replace=False)
+                        inds = np.append(inds, i)
+
+                        dat_bs[i] = np.mean(dat[inds], axis=0)
+                        events_bs[i] = ep.events[i]
+
+                    inds = np.random.choice(ntrials, size=use_bootstrap,
+                                            replace=False)
+                    dat_bs = dat_bs[inds]
+                    events_bs = events_bs[inds]
+
+                assert dat_bs.shape == (use_bootstrap, T, p)
+                assert events_bs.shape == (use_bootstrap, 3)
+                assert (events_bs[:, 2] == events_bs[0, 2]).all()
+
+                data_bs_all.append(dat_bs)
+                events_bs_all.append(events_bs)
+
+            # write bootstrap epochs
+            info_dict = epochs.info.copy()
+
+            dat_all = np.vstack(data_bs_all)
+            events_all = np.vstack(events_bs_all)
+            # replace first column with sequential list as we don't really care
+            # about the raw timings
+            events_all[:, 0] = np.arange(events_all.shape[0])
+
+            epochs_bs = mne.EpochsArray(
+                dat_all, info_dict, events=events_all, tmin=-0.2,
+                event_id=epochs.event_id.copy(), on_missing='ignore')
+
+        return epochs_bs
+
+    def add_subject(self, subject, condition, epochs, labels, fwd,
+                    cov):
+
+        epochs_bs = self.bootstrap_subject(epochs, subject)
+        epochs_bs = epochs_bs[condition]
+        epochs = epochs_bs
+
+        # ensure cov and fwd use correct channels
+        cov = cov.pick_channels(epochs.ch_names, ordered=True)
+        fwd = mne.convert_forward_solution(fwd, force_fixed=True)
+        fwd = fwd.pick_channels(epochs.ch_names, ordered=True)
+
+        if not self._model_initialized:
+            n_timepts = len(epochs.times)
+            num_roi = len(labels)
+            self._init_model(n_timepts, num_roi)
+            self._model_initialized = True
+            self.n_timepts = n_timepts
+            self.num_roi = num_roi
+            self.times = epochs.times
+        if len(epochs.times) != self._n_times:
+            raise ValueError(f'Number of time points ({len(epochs.times)}) '
+                             f'does not match original count '
+                             f'({self._n_times})')
+
+        # scale cov matrix according to number of bootstraps
+        cov_scale = 3  # equal to number of bootstrap trials
+        cov['data'] /= cov_scale
+        fwd, cov = apply_projs(epochs_bs, fwd, cov)
+
+        sdata = run_pca_on_subject(subject, epochs_bs, fwd, cov, labels,
+                                   dim_mode='pctvar', mean_center=True)
+        data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata
+        subjectdata = (data, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi)
+
+        self._all_subject_data.append(subjectdata)
+
+        self._subject_data[subject] = dict()
+        self._subject_data[subject]['epochs'] = data
+        self._subject_data[subject]['fwd_src_snsr'] = fwd_src_snsr
+        self._subject_data[subject]['fwd_roi_snsr'] = fwd_roi_snsr
+        self._subject_data[subject]['snsr_cov'] = snsr_cov
+        self._subject_data[subject]['labels'] = labels
+        self._subject_data[subject]['which_roi'] = which_roi
+
+    def _init_model(self, n_timepts, num_roi, A_t_=None, roi_cov=None,
+                    mu0=None, roi_cov_0=None, log_sigsq_lst=None):
+
+        self._n_times = n_timepts
+        self._subject_data = dict()
+
+        set_default = \
+            lambda prm, val, deflt: \
+                self.__setattr__(prm, val.copy() if val is not None else
+                                 deflt)
+
+        # initialize parameters
+        set_default("A_t_", A_t_,
+                    np.stack([rand_stable(num_roi, maxew=0.7) for _ in
+                              range(n_timepts)], axis=0))
+        set_default("roi_cov", roi_cov, rand_psd(num_roi))
+        set_default("mu0", mu0, np.zeros(num_roi))
+        set_default("roi_cov_0", roi_cov_0, rand_psd(num_roi))
+        set_default("log_sigsq_lst", log_sigsq_lst,
+                    [np.log(np.random.gamma(2, 1, size=num_roi+1))])
+
+        # initialize sufficient statistics
+        n_timepts, num_roi, _ = self.A_t_.shape
+        self._B0 = np.zeros((num_roi, num_roi))
+        self._B1 = np.zeros((n_timepts-1, num_roi, num_roi))
+        self._B3 = np.zeros((n_timepts-1, num_roi, num_roi))
+        self._B2 = np.zeros((n_timepts-1, num_roi, num_roi))
+        self._B4 = list()
+
+    def set_data(self, subjectdata):
+        # add subject data, re-generate log_sigsq_lst if necessary
+        super().set_data(subjectdata)
+        if len(self.log_sigsq_lst) != self._nsubjects:
+            num_roi = self.log_sigsq_lst[0].shape[0]
+            self.log_sigsq_lst = [np.log(np.random.gamma(2, 1, size=num_roi))
+                                  for _ in range(self._nsubjects)]
+
+        # reset smoothed estimates and log-likelihood (no longer valid if
+        # new data was added)
+        self._mus_smooth_lst = None
+        self._sigmas_smooth_lst = None
+        self._sigmas_tnt_smooth_lst = None
+        self._loglik = None
+        self._B4 = [None] * self._nsubjects
+
+    def _em_objective(self):
+
+        _, num_roi, _ = self.A_t_.shape
+
+        L_roi_cov_0 = np.linalg.cholesky(self.roi_cov_0)
+        L_roi_cov = np.linalg.cholesky(self.roi_cov)
+
+        L1 = 0.
+        L2 = 0.
+        L3 = 0.
+
+        obj = 0.
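+        # accumulate the three terms of the expected complete-data negative
+        # log-likelihood across subjects: L1 (initial state), L2 (dynamics
+        # under roi_cov), L3 (observation noise)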
+ for s, sdata in enumerate(self.unpack_all_subject_data()): + + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + + ntrials, n_timepts, _ = Y.shape + + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = LDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + L_R = np.linalg.cholesky(R) + + if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None + or self._sigmas_tnt_smooth_lst is None): + roi_cov_t = _ensure_ndim(self.roi_cov, n_timepts, 3) + with numpy_num_threads(1): + _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, + roi_cov_t, R, self.mu0, + self.roi_cov_0, + compute_lag1_cov=True) + + else: + mus_smooth = self._mus_smooth_lst[s] + sigmas_smooth = self._sigmas_smooth_lst[s] + sigmas_tnt_smooth = self._sigmas_tnt_smooth_lst[s] + + x_smooth_0_outer = einsum('ri,rj->rij', mus_smooth[:,0,:num_roi], + mus_smooth[:,0,:num_roi]) + B0 = w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + + x_smooth_0_outer, axis=0) + + x_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + B1 = w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + + x_smooth_outer, axis=0) + z_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,:-1,:], + mus_smooth[:,:-1,:]) + B3 = w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, axis=0) + + mus_smooth_outer_l1 = einsum('rti,rtj->rtij', + mus_smooth[:,1:,:num_roi], + mus_smooth[:,:-1,:]) + B2 = w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + + mus_smooth_outer_l1, axis=0) + + # obj += L1(roi_cov_0) + L_roi_cov_0_inv_B0 = solve_triangular(L_roi_cov_0, B0, lower=True) + L1 += (ntrials*2.*np.sum(np.log(np.diag(L_roi_cov_0))) + + np.trace(solve_triangular(L_roi_cov_0, L_roi_cov_0_inv_B0, + lower=True, trans='T'))) + + At = self.A_t_[:-1] + AtB2T = einsum('tik,tjk->tij', At, B2) + B2AtT = einsum('tik,tjk->tij', B2, At) + tmp = einsum('tik,tkl->til', At, B3) + AtB3AtT = einsum('tik,tjk->tij', tmp, At) + + tmp = np.sum(B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + + # obj += L2(roi_cov, At) + L_roi_cov_inv_tmp = solve_triangular(L_roi_cov, tmp, lower=True) + L2 += (ntrials*(n_timepts-1)*2.*np.sum(np.log(np.diag(L_roi_cov))) + + np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_tmp, + lower=True, trans='T'))) + + res = Y - einsum('ik,ntk->nti', fwd_roi_snsr, + mus_smooth[:,:,:num_roi]) + CP_smooth = einsum('ik,ntkj->ntij', fwd_roi_snsr, + sigmas_smooth[:,:,:num_roi,:num_roi]) + + B4 = w_s*(np.sum(einsum('nti,ntj->ntij', res, res), axis=(0,1)) + + np.sum(einsum('ntik,jk->ntij', CP_smooth, + fwd_roi_snsr), axis=(0,1))) + self._B4[s] = B4 + + # obj += L3(sigsq_vals) + L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) + L3 += (ntrials*n_timepts*2*np.sum(np.log(np.diag(L_R))) + + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, + trans='T'))) + + obj = (L1 + L2 + L3) / self._ntrials_all + + # obj += penalty + if self.lam0 > 0.: + if self._penalty == 'ridge': + obj += self.lam0*np.sum(At**2) + elif self._penalty == 'lasso': + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + sum_At_diag = np.sum(np.abs(At_diag)) + obj += self.lam0*(np.sum(np.abs(At)) - sum_At_diag) + elif self._penalty == 'group-lasso': + At_diag = np.diagonal(At, axis1=-2, axis2=-1) + norm_At_diag = np.sum(np.linalg.norm(At_diag, axis=0)) + norm_At = np.sum(np.linalg.norm(At, axis=0)) + obj += self.lam1*(norm_At - norm_At_diag) + if self.lam1 > 0.: + AtmAtm1_2 = (At[1:] - At[:-1])**2 + obj += self.lam1*np.sum(AtmAtm1_2) + + return obj + + def fit(self, niter=100, tol=1e-6, A_t_roi_cov_niter=100, + 
A_t_roi_cov_tol=1e-6, verbose=0, update_A_t_=True, + update_roi_cov=True, update_roi_cov_0=True, stationary_A_t_=False, + diag_roi_cov=False, update_sigsq=True, do_final_smoothing=True, + average_mus_smooth=True, Atrue=None, tau=0.1, c1=1e-4): + + self.set_data(self._all_subject_data) + + fxn_start = datetime.now() + + n_timepts, num_roi, _ = self.A_t_.shape + + # make initial A_t_ stationary if stationary_A_t_ option specified + if stationary_A_t_: + self.A_t_[:] = np.mean(self.A_t_, axis=0) + + # set parameters for (A_t_, roi_cov) optimization + self._A_t_roi_cov_niter = A_t_roi_cov_niter + self._A_t_roi_cov_tol = A_t_roi_cov_tol + + # make initial roi_cov, roi_cov_0 diagonal if diag_roi_cov specified + if diag_roi_cov: + self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) + self.roi_cov = np.diag(np.diag(self.roi_cov)) + + # keeping track of objective value and best parameters + objvals = np.zeros(niter+1) + converged = False + best_objval = np.finfo('float').max + best_params = (self.A_t_.copy(), self.roi_cov.copy(), self.mu0.copy(), + self.roi_cov_0.copy(), [l.copy() for l in + self.log_sigsq_lst]) + + # previous parameter values (for checking convergence) + At_prev = None + roi_cov_prev = None + roi_cov_0_prev = None + log_sigsq_lst_prev = None + + if Atrue is not None: + import matplotlib.pyplot as plt + fig_A_t_, ax_A_t_ = plt.subplots(num_roi, num_roi, sharex=True, + sharey=True) + plt.ion() + + # calculate initial objective value, check for updated best iterate + # have to do e-step here to initialize suff stats for _m_step + if (self._mus_smooth_lst is None or self._sigmas_smooth_lst is None + or self._sigmas_tnt_smooth_lst is None): + self._e_step(verbose=verbose-1) + + objval = self._em_objective() + objvals[0] = objval + + for it in range(1, niter+1): + + iter_start = datetime.now() + + if verbose > 0: + print("em: it %d / %d" % (it, niter)) + sys.stdout.flush() + sys.stderr.flush() + + # record values from previous M-step + At_prev = self.A_t_[:-1].copy() + roi_cov_prev = self.roi_cov.copy() + roi_cov_0_prev = self.roi_cov_0.copy() + log_sigsq_lst_prev = np.array(self.log_sigsq_lst).copy() + + self._m_step(update_A_t_=update_A_t_, + update_roi_cov=update_roi_cov, + update_roi_cov_0=update_roi_cov_0, + stationary_A_t_=stationary_A_t_, + diag_roi_cov=diag_roi_cov, + update_sigsq=update_sigsq, + tau=tau, c1=c1, verbose=verbose) + + if Atrue is not None: + for i in range(num_roi): + for j in range(num_roi): + ax_A_t_[i, j].cla() + ax_A_t_[i, j].plot(Atrue[:-1, i, j], color='green') + ax_A_t_[i, j].plot(self.A_t_[:-1, i, j], color='red', + alpha=0.7) + fig_A_t_.tight_layout() + fig_A_t_.canvas.draw() + plt.pause(1. / 60.) 
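+            # E-step under the parameters just updated by the M-step, then
+            # re-evaluate the EM objective to test convergence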
+ + self._e_step(verbose=verbose-1) + + # calculate objective value, check for updated best iterate + objval = self._em_objective() + objvals[it] = objval + + if verbose > 0: + print(" objective: %.4e" % objval) + At = self.A_t_[:-1] + maxAt = np.max(np.abs(np.triu(At, k=1) + np.tril(At, k=-1))) + print(" max |A_t|: %.4e" % (maxAt,)) + sys.stdout.flush() + sys.stderr.flush() + + if objval < best_objval: + best_objval = objval + best_params = (self.A_t_.copy(), self.roi_cov.copy(), + self.mu0.copy(), self.roi_cov_0.copy(), + [l.copy() for l in self.log_sigsq_lst]) + + # check for convergence + if it >= 1: + relnormdiff_At = relnormdiff(self.A_t_[:-1], At_prev) + relnormdiff_roi_cov = relnormdiff(self.roi_cov, roi_cov_prev) + relnormdiff_roi_cov_0 = relnormdiff(self.roi_cov_0, + roi_cov_0_prev) + relnormdiff_log_sigsq_lst = \ + np.array( + [relnormdiff(self.log_sigsq_lst[s], + log_sigsq_lst_prev[s]) + for s in range(len(self.log_sigsq_lst))]) + params_converged = (relnormdiff_At <= tol) and \ + (relnormdiff_roi_cov <= tol) and \ + (relnormdiff_roi_cov_0 <= tol) and \ + np.all(relnormdiff_log_sigsq_lst <= tol) + + relobjdiff = np.abs((objval - objvals[it-1]) / objval) + + if verbose > 0: + print(" relnormdiff_At: %.3e" % relnormdiff_At) + print(" relnormdiff_roi_cov: %.3e" % relnormdiff_roi_cov) + print(" relnormdiff_roi_cov_0: %.3e" % + relnormdiff_roi_cov_0) + print(" relnormdiff_log_sigsq_lst:", + relnormdiff_log_sigsq_lst) + print(" relobjdiff: %.3e" % relobjdiff) + + objdiff = objval - objvals[it-1] + if objdiff > 0: + print(" \033[0;31mEM objective increased\033[0m") + + sys.stdout.flush() + sys.stderr.flush() + + if params_converged or relobjdiff <= tol: + if verbose > 0: + print("EM objective converged") + sys.stdout.flush() + sys.stderr.flush() + converged = True + objvals = objvals[:it+1] + break + + # retrieve best parameters and load into instance variables. + A_t_, roi_cov, mu0, roi_cov_0, log_sigsq_lst = best_params + self.A_t_ = A_t_.copy() + self.roi_cov = roi_cov.copy() + self.mu0 = mu0.copy() + self.roi_cov_0 = roi_cov_0.copy() + self.log_sigsq_lst = [l.copy() for l in log_sigsq_lst] + + if verbose > 0: + print() + print("elapsed, iteration:", datetime.now() - iter_start) + print("=" * 34) + print() + + # perform final smoothing + mus_smooth_lst = None + St_lst = None + if do_final_smoothing: + if verbose >= 1: + print("performing final smoothing") + + mus_smooth_lst = list() + self._loglik = 0. 
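+            # re-run the smoother with the final parameters, accumulating the
+            # marginal log-likelihood over subjects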
+ if self._store_St: + St_lst = list() + for s, sdata in enumerate(self.unpack_all_subject_data()): + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = LDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) + with numpy_num_threads(1): + loglik_subject, mus_smooth, _, _, St = \ + rts_smooth(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, R, + self.mu0, self.roi_cov_0, + compute_lag1_cov=False, + store_St=self._store_St) + # just save the mean of the smoothed trials + if average_mus_smooth: + mus_smooth_lst.append(np.mean(mus_smooth, axis=0)) + else: + mus_smooth_lst.append(mus_smooth) + self._loglik += loglik_subject + # just save the diagonals of St b/c that's what we need for + # connectivity + if self._store_St: + St_lst.append(np.diagonal(St, axis1=-2, axis2=-1)) + + if verbose > 0: + print() + print("elapsed, function:", datetime.now() - fxn_start) + print("=" * 34) + print() + + return objvals, converged, mus_smooth_lst, self._loglik, St_lst + + def _e_step(self, verbose=0): + n_timepts, num_roi, _ = self.A_t_.shape + + # reset accumulation arrays + self._B0[:] = 0. + self._B1[:] = 0. + self._B3[:] = 0. + self._B2[:] = 0. + + self._mus_smooth_lst = list() + self._sigmas_smooth_lst = list() + self._sigmas_tnt_smooth_lst = list() + + if verbose > 0: + print(" e-step") + print(" subject", end="") + + for s, sdata in enumerate(self.unpack_all_subject_data()): + + if verbose > 0: + print(" %d" % (s+1,), end="") + sys.stdout.flush() + sys.stderr.flush() + + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + + sigsq_vals = np.exp(self.log_sigsq_lst[s]) + R = LDS.R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi) + roi_cov_t = _ensure_ndim(self.roi_cov, self._n_timepts, 3) + + with numpy_num_threads(1): + _, mus_smooth, sigmas_smooth, sigmas_tnt_smooth = \ + rts_smooth_fast(Y, self.A_t_, fwd_roi_snsr, roi_cov_t, + R, self.mu0, self.roi_cov_0, + compute_lag1_cov=True) + + self._mus_smooth_lst.append(mus_smooth) + self._sigmas_smooth_lst.append(sigmas_smooth) + self._sigmas_tnt_smooth_lst.append(sigmas_tnt_smooth) + + x_smooth_0_outer = einsum('ri,rj->rij', mus_smooth[:,0,:num_roi], + mus_smooth[:,0,:num_roi]) + self._B0 += w_s*np.sum(sigmas_smooth[:,0,:num_roi,:num_roi] + + x_smooth_0_outer, axis=0) + + x_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,1:,:num_roi], + mus_smooth[:,1:,:num_roi]) + self._B1 += w_s*np.sum(sigmas_smooth[:,1:,:num_roi,:num_roi] + + x_smooth_outer, axis=0) + + z_smooth_outer = einsum('rti,rtj->rtij', mus_smooth[:,:-1,:], + mus_smooth[:,:-1,:]) + self._B3 += w_s*np.sum(sigmas_smooth[:,:-1,:,:] + z_smooth_outer, + axis=0) + + mus_smooth_outer_l1 = einsum('rti,rtj->rtij', + mus_smooth[:,1:,:num_roi], + mus_smooth[:,:-1,:]) + self._B2 += w_s*np.sum(sigmas_tnt_smooth[:,:,:num_roi,:] + + mus_smooth_outer_l1, axis=0) + + if verbose > 0: + print("\n done") + + def _m_step(self, update_A_t_=True, update_roi_cov=True, + update_roi_cov_0=True, stationary_A_t_=False, + diag_roi_cov=False, update_sigsq=True, tau=0.1, c1=1e-4, + verbose=0): + self._loglik = None + if verbose > 0: + print(" m-step") + if update_roi_cov_0: + self.roi_cov_0 = (1. 
/ self._ntrials_all) * self._B0 + if diag_roi_cov: + self.roi_cov_0 = np.diag(np.diag(self.roi_cov_0)) + self.update_A_t_and_roi_cov(update_A_t_=update_A_t_, + update_roi_cov=update_roi_cov, + stationary_A_t_=stationary_A_t_, + diag_roi_cov=diag_roi_cov, tau=tau, + c1=c1, verbose=verbose) + if update_sigsq: + self.update_log_sigsq_lst(verbose=verbose) + + def update_A_t_and_roi_cov(self, update_A_t_=True, update_roi_cov=True, + stationary_A_t_=False, + diag_roi_cov=False, tau=0.1, c1=1e-4, verbose=0): + + if verbose > 1: + print(" update A_t_ and roi_cov") + + # gradient descent + At = self.A_t_[:-1] + At_init = At.copy() + L_roi_cov = np.linalg.cholesky(self.roi_cov) + At_L_roi_cov_obj = lambda x, y: self.L2_obj(x, y) + At_obj = lambda x: self.L2_obj(x, L_roi_cov) + grad_At_obj = grad(At_obj) + obj_diff = np.finfo('float').max + obj = At_L_roi_cov_obj(At, L_roi_cov) + inner_it = 0 + + # specify proximal operator to use + if self._penalty == 'ridge': + prox_op = lambda x, y: x + elif self._penalty == 'lasso': + prox_op = soft_thresh_At + elif self._penalty == 'group-lasso': + prox_op = block_thresh_At + + while np.abs(obj_diff / obj) > self._A_t_roi_cov_tol: + + if inner_it > self._A_t_roi_cov_niter: + break + + obj_start = At_L_roi_cov_obj(At, L_roi_cov) + + # update At using gradient descent with backtracking line search + if update_A_t_: + if stationary_A_t_: + B2_sum = np.sum(self._B2, axis=0) + B3_sum = np.sum(self._B3, axis=0) + At[:] = np.linalg.solve(B3_sum.T, B2_sum.T).T + else: + grad_At = grad_At_obj(At) + step_size = linesearch(At_obj, grad_At_obj, At, grad_At, + prox_op=prox_op, lam=self.lam0, + tau=tau, c1=c1) + At[:] = prox_op(At - step_size * grad_At, + self.lam0 * step_size) + + # update roi_cov using closed form + if update_roi_cov: + AtB2T = einsum('tik,tjk->tij', At, self._B2) + B2AtT = einsum('tik,tjk->tij', self._B2, At) + tmp = einsum('tik,tkl->til', At, self._B3) + AtB3AtT = einsum('til,tjl->tij', tmp, At) + elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + self.roi_cov = (1. 
/ (self._ntrials_all * self._n_timepts + )) * elbo_2 + if diag_roi_cov: + self.roi_cov = np.diag(np.diag(self.roi_cov)) + L_roi_cov = np.linalg.cholesky(self.roi_cov) + + obj = At_L_roi_cov_obj(At, L_roi_cov) + obj_diff = obj_start - obj + inner_it += 1 + + if verbose > 1: + if not stationary_A_t_ and update_A_t_: + grad_norm = np.linalg.norm(grad_At) + norm_change = np.linalg.norm(At - At_init) + print(" last step size: %.3e" % step_size) + print(" last gradient norm: %.3e" % grad_norm) + print(" norm of total change: %.3e" % norm_change) + print(" number of iterations: %d" % inner_it) + print(" done") + + def update_log_sigsq_lst(self, verbose=0): + + if verbose > 1: + print(" update subject log-sigmasq") + + n_timepts, num_roi, _ = self.A_t_.shape + + # update log_sigsq_vals for each subject and ROI + for s, sdata in enumerate(self.unpack_all_subject_data()): + + Y, w_s, fwd_roi_snsr, fwd_src_snsr, snsr_cov, which_roi = sdata + ntrials, n_timepts, _ = Y.shape + B4 = self._B4[s] + + log_sigsq = self.log_sigsq_lst[s].copy() + log_sigsq_obj = lambda x: \ + LDS.L3_obj(x, snsr_cov, fwd_src_snsr, which_roi, B4, + ntrials, n_timepts) + log_sigsq_val_and_grad = vgrad(log_sigsq_obj) + + options = {'maxiter': 500} + opt_res = spopt.minimize(log_sigsq_val_and_grad, log_sigsq, + method='L-BFGS-B', jac=True, + options=options) + if verbose > 1: + print(" subject %d - %d iterations" % (s+1, opt_res.nit)) + + if not opt_res.success: + print(" log_sigsq opt") + print(" %s" % opt_res.message) + + self.log_sigsq_lst[s] = opt_res.x + + if verbose > 1: + print("\n done") + + + @staticmethod + def R_(snsr_cov, fwd_src_snsr, sigsq_vals, which_roi): + return snsr_cov + np.dot(fwd_src_snsr, + sigsq_vals[which_roi][:,None]*fwd_src_snsr.T) + + def L2_obj(self, At, L_roi_cov): + AtB2T = einsum('tik,tjk->tij', At, self._B2) + B2AtT = einsum('tik,tjk->tij', self._B2, At) + tmp = einsum('tik,tkl->til', At, self._B3) + AtB3AtT = einsum('til,tjl->tij', tmp, At) + elbo_2 = np.sum(self._B1 - AtB2T - B2AtT + AtB3AtT, axis=0) + + L_roi_cov_inv_elbo_2 = solve_triangular(L_roi_cov, elbo_2, lower=True) + obj = np.trace(solve_triangular(L_roi_cov, L_roi_cov_inv_elbo_2, + lower=True, + trans='T')) + obj = obj / self._ntrials_all + + if self._penalty == 'ridge': + obj += self.lam0*np.sum(At**2) + AtmAtm1_2 = (At[1:] - At[:-1])**2 + obj += self.lam1*np.sum(AtmAtm1_2) + + return obj + + @staticmethod + def L3_obj(log_sigsq_vals, snsr_cov, fwd_src_snsr, which_roi, B4, ntrials, + n_timepts): + R = LDS.R_(snsr_cov, fwd_src_snsr, np.exp(log_sigsq_vals), + which_roi) + try: + L_R = np.linalg.cholesky(R) + except LinAlgError: + return np.finfo('float').max + L_R_inv_B4 = solve_triangular(L_R, B4, lower=True) + return (ntrials*n_timepts*2.*np.sum(np.log(np.diag(L_R))) + + np.trace(solve_triangular(L_R, L_R_inv_B4, lower=True, + trans='T'))) diff --git a/state_space/megssm/numpy_numthreads.py b/state_space/megssm/numpy_numthreads.py new file mode 100755 index 00000000..d632c21c --- /dev/null +++ b/state_space/megssm/numpy_numthreads.py @@ -0,0 +1,81 @@ +import contextlib +import ctypes +from ctypes.util import find_library + +# heavily based on: +# https://stackoverflow.com/questions/29559338/set-max-number-of-threads-at-runtime-on-numpy-openblas + +# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/ +# from Ubuntu repos +try_paths = [find_library('openblas')] +openblas_lib = None +mkl_rt = None +#if openblas_lib is None: + #raise EnvironmentError('Could not locate an OpenBLAS shared library', 2) + + +def 
+
+
+def set_num_threads(n):
+    """Set the current number of threads used by the OpenBLAS server."""
+    if mkl_rt:
+        mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(n)))
+    elif openblas_lib:
+        openblas_lib.openblas_set_num_threads(int(n))
+
+
+# At the time of writing these symbols were very new:
+# https://github.com/xianyi/OpenBLAS/commit/65a847c
+try:
+    if mkl_rt:
+        def get_num_threads():
+            return mkl_rt.mkl_get_max_threads()
+    elif openblas_lib:
+        # do this to throw exception if it doesn't exist
+        openblas_lib.openblas_get_num_threads()
+
+        def get_num_threads():
+            """Get the current number of threads used by the OpenBLAS server."""
+            return openblas_lib.openblas_get_num_threads()
+    else:
+        def get_num_threads():
+            """Dummy function (no BLAS library located), returns -1."""
+            return -1
+except AttributeError:
+    def get_num_threads():
+        """Dummy function (symbol not present), returns -1."""
+        return -1
+
+try:
+    if mkl_rt:
+        def get_num_procs():
+            # MKL has no direct analogue; max threads is the closest proxy
+            return mkl_rt.mkl_get_max_threads()
+    elif openblas_lib:
+        # do this to throw exception if it doesn't exist
+        openblas_lib.openblas_get_num_procs()
+
+        def get_num_procs():
+            """Get the total number of physical processors."""
+            return openblas_lib.openblas_get_num_procs()
+    else:
+        def get_num_procs():
+            """Dummy function (no BLAS library located), returns -1."""
+            return -1
+except AttributeError:
+    def get_num_procs():
+        """Dummy function (symbol not present), returns -1."""
+        return -1
+
+
+@contextlib.contextmanager
+def numpy_num_threads(n):
+    """Temporarily change the number of OpenBLAS threads.
+
+    Example usage:
+
+        print("Before: {}".format(get_num_threads()))
+        with numpy_num_threads(n):
+            print("In thread context: {}".format(get_num_threads()))
+        print("After: {}".format(get_num_threads()))
+    """
+    old_n = get_num_threads()
+    set_num_threads(n)
+    try:
+        yield
+    finally:
+        set_num_threads(old_n)
diff --git a/state_space/megssm/plotting.py b/state_space/megssm/plotting.py
new file mode 100644
index 00000000..8925660a
--- /dev/null
+++ b/state_space/megssm/plotting.py
@@ -0,0 +1,107 @@
+"""Plotting functions."""
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def plot_A_t_(A, ci='sd', times=None, ax=None, skipdiag=False, labels=None,
+              showticks=True, **kwargs):
+    """Plot traces of each entry of dynamics A in a square grid of subplots."""
+    if A.ndim == 3:
+        T, d, _ = A.shape
+    elif A.ndim == 4:
+        _, T, d, _ = A.shape
+    else:
+        raise ValueError("A must be 3- or 4-dimensional, got ndim=%d" % A.ndim)
+
+    if times is None:
+        times = np.arange(T)
+
+    if ax is None or ax.shape != (d, d):
+        fig, ax = plt.subplots(d, d, sharex=True, sharey=True, squeeze=False)
+    else:
+        fig = ax[0, 0].figure
+
+    for i in range(d):
+        for j in range(d):
+
+            # skip and hide subplots on diagonal
+            if skipdiag and i == j:
+                ax[i, j].set_visible(False)
+                continue
+
+            # plot A entry as trace with/without error band
+            if A.ndim == 3:
+                ax[i, j].plot(times[:-1], A[:-1, i, j], **kwargs)
+            elif A.ndim == 4:
+                plot_fill(A[:, :-1, i, j], ci=ci, times=times[:-1],
+                          ax=ax[i, j], **kwargs)
+
+            # add labels above first row and to the left of the first column
+            if labels is not None:
+                if i == 0 or (skipdiag and (i, j) == (1, 0)):
+                    ax[i, j].set_title(labels[j], fontsize=12)
+                if j == 0 or (skipdiag and (i, j) == (0, 1)):
+                    ax[i, j].set_ylabel(labels[i], fontsize=12)
+
+            # remove x- and y-ticks on subplot
+            if not showticks:
+                ax[i, j].set_xticks([])
+                ax[i, j].set_yticks([])
+
+    diag_lims = [0, 1]
+    off_lims = [-0.25, 0.25]
+    for ri, row in enumerate(ax):
+        for cj, a in enumerate(row):  # 'cj' avoids shadowing the 'ci' kwarg
+            ylim = diag_lims if ri == cj else off_lims
+            a.set(ylim=ylim, xlim=times[[0, -1]])
+            if ri == 0:
+                a.set_title(a.get_title(), fontsize='small')
+            if cj == 0:
+                a.set_ylabel(a.get_ylabel(), fontsize='small')
+            for line in a.lines:
+                line.set_clip_on(False)
+                line.set(lw=1.)
+            if cj != 0:
+                a.yaxis.set_major_formatter(plt.NullFormatter())
+            if ri != d - 1:
+                a.xaxis.set_major_formatter(plt.NullFormatter())
+            if ri == cj:
+                for spine in a.spines.values():
+                    spine.set(lw=2)
+            else:
+                a.axhline(0, color='k', ls=':', lw=1.)
+
+    return fig, ax
+
+
+def plot_fill(X, times=None, ax=None, ci='sd', **kwargs):
+    """Plot mean and error band across the first axis of X."""
+    N, T = X.shape
+
+    if times is None:
+        times = np.arange(T)
+    if ax is None:
+        fig, ax = plt.subplots(1, 1)
+
+    mu = np.mean(X, axis=0)
+
+    # define lower and upper band limits based on ci
+    if ci == 'sd':  # standard deviation
+        sigma = np.std(X, axis=0)
+        lower, upper = mu - sigma, mu + sigma
+    elif ci == 'se':  # standard error
+        stderr = np.std(X, axis=0) / np.sqrt(X.shape[0])
+        lower, upper = mu - stderr, mu + stderr
+    elif ci == '2sd':  # 2 standard deviations
+        sigma = np.std(X, axis=0)
+        lower, upper = mu - 2 * sigma, mu + 2 * sigma
+    elif ci == 'max':  # range (min to max)
+        lower, upper = np.min(X, axis=0), np.max(X, axis=0)
+    elif isinstance(ci, float) and 0 < ci < 1:
+        # quantile-based confidence interval
+        a = 1 - ci
+        lower, upper = np.quantile(X, [a / 2, 1 - a / 2], axis=0)
+    else:
+        raise ValueError("ci must be in ('sd', 'se', '2sd', 'max') "
+                         "or a float in (0, 1)")
+
+    lines = ax.plot(times, mu, **kwargs)
+    c = lines[0].get_color()
+    ax.fill_between(times, lower, upper, color=c, alpha=0.3, lw=0)
diff --git a/state_space/megssm/util.py b/state_space/megssm/util.py
new file mode 100755
index 00000000..63898be3
--- /dev/null
+++ b/state_space/megssm/util.py
@@ -0,0 +1,117 @@
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+
+import autograd.numpy as np
+from numpy.lib.stride_tricks import as_strided as ast
+
+
+# note: call as hs([a, b, ...]) -- a single sequence of arrays
+hs = lambda *args: np.concatenate(*args, axis=-1)
+
+def T_(X):
+    return np.swapaxes(X, -1, -2)
+
+def sym(X):
+    return 0.5*(X + T_(X))
+
+def dot3(A, B, C):
+    return np.dot(A, np.dot(B, C))
+
+def relnormdiff(A, B, min_denom=1e-9):
+    return np.linalg.norm(A - B) / np.maximum(np.linalg.norm(A), min_denom)
+
+def _ensure_ndim(X, T, ndim):
+    X = np.require(X, dtype=np.float64, requirements='C')
+    assert ndim-1 <= X.ndim <= ndim
+    if X.ndim == ndim:
+        assert X.shape[0] == T
+        return X
+    else:
+        # broadcast a time-invariant array across T time points without
+        # copying, via a zero stride on the new leading axis
+        return ast(X, shape=(T,) + X.shape, strides=(0,) + X.strides)
+
+def rand_psd(n, minew=0.1, maxew=1.):
+    # 'minew'/'maxew' are the smallest/largest eigenvalues of the result
+    if n == 1:
+        return maxew * np.eye(n)
+    X = np.random.randn(n, n)
+    S = np.dot(T_(X), X)
+    S = sym(S)
+    ew, ev = np.linalg.eigh(S)
+    ew -= np.min(ew)
+    ew /= np.max(ew)
+    ew *= (maxew - minew)
+    ew += minew
+    return dot3(ev, np.diag(ew), T_(ev))
+
+def rand_stable(n, maxew=0.9):
+    # scale a random matrix so its spectral radius is maxew (< 1 => stable)
+    A = np.random.randn(n, n)
+    A *= maxew / np.max(np.abs(np.linalg.eigvals(A)))
+    return A
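+
+# A quick sketch of component_matrix (below): for d = 1, nlags = 2 and
+# As = np.array([[0.5, 0.2]]), the companion matrix is
+#
+#     [[0.5, 0.2],
+#      [1. , 0. ]]
+#
+# i.e. a VAR(2) in d dimensions becomes a VAR(1) in 2*d dimensions.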
+
+def component_matrix(As, nlags):
+    """Compute the component (companion) form of a latent VAR process:
+
+        [A_1 A_2 ... A_p]
+        [ I   0  ...  0 ]
+        [ 0   I  ...  0 ]
+        [ 0  ...  I   0 ]
+    """
+    d = As.shape[0]
+    res = np.zeros((d*nlags, d*nlags))
+    res[:d] = As
+
+    if nlags > 1:
+        res[np.arange(d, d*nlags), np.arange(d*nlags - d)] = 1
+
+    return res
+
+def linesearch(f, grad_f, xk, pk, step_size=1., tau=0.1, c1=1e-4,
+               prox_op=None, lam=1.):
+    """Find a step size via backtracking line search with the Armijo
+    condition."""
+    obj_start = f(xk)
+    grad_xk = grad_f(xk)
+    obj_new = np.finfo('float').max
+    armijo_condition = 0
+
+    if prox_op is None:
+        prox_op = lambda x, y: x
+
+    while obj_new > armijo_condition:
+        x_new = prox_op(xk - step_size * pk, lam * step_size)
+        armijo_condition = obj_start - c1 * step_size * np.sum(pk * grad_xk)
+        obj_new = f(x_new)
+        step_size *= tau
+
+    return step_size / tau
+
+def soft_thresh_At(At, lam):
+    """Entrywise soft-thresholding (lasso prox) of At, sparing the diagonal."""
+    At = At.copy()
+    diag_inds = np.diag_indices(At.shape[1])
+    At_diag = np.diagonal(At, axis1=-2, axis2=-1)
+
+    At = np.sign(At) * np.maximum(np.abs(At) - lam, 0.)
+
+    # fill in diagonal with originally updated entries as we're not
+    # going to penalize them
+    for tt in range(At.shape[0]):
+        At[tt][diag_inds] = At_diag[tt]
+    return At
+
+def block_thresh_At(At, lam, min_norm=1e-16):
+    """Group-lasso thresholding of At across time, sparing the diagonal."""
+    At = At.copy()
+    diag_inds = np.diag_indices(At.shape[1])
+    At_diag = np.diagonal(At, axis1=-2, axis2=-1)
+
+    norms = np.linalg.norm(At, axis=0, keepdims=True)
+    norms = np.maximum(norms, min_norm)
+    scales = np.maximum(norms - lam, 0.)
+    At = scales * (At / norms)
+
+    # fill in diagonal with originally updated entries as we're not
+    # going to penalize them
+    for tt in range(At.shape[0]):
+        At[tt][diag_inds] = At_diag[tt]
+    return At
diff --git a/state_space/state_space_connectivity.py b/state_space/state_space_connectivity.py
new file mode 100644
index 00000000..2c464258
--- /dev/null
+++ b/state_space/state_space_connectivity.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Authors: Jordan Drew
+
+Example usage of LDS for 'mne-connectivity'; uses MNE-sample-data
+(auditory/left condition).
+"""
+
+## import necessary libraries
+import mne
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+
+from megssm.models import LDS
+from megssm.plotting import plot_A_t_
+
+## define paths to sample data
+data_path = mne.datasets.sample.data_path()
+sample_folder = data_path / 'MEG/sample'
+
+## import raw data and find events
+raw_fname = sample_folder / 'sample_audvis_raw.fif'
+raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
+events = mne.find_events(raw, stim_channel='STI 014')
+
+## define epochs using event_dict
+event_dict = {'auditory_left': 1, 'auditory_right': 2, 'visual_left': 3,
+              'visual_right': 4}
+epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict,
+                    preload=True).pick_types(meg=True, eeg=True)
+condition = 'auditory_left'
+
+## read forward solution
+fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
+fwd = mne.read_forward_solution(fwd_fname)
+
+## read in covariance
+cov_fname = sample_folder / 'sample_audvis-cov.fif'
+cov = mne.read_cov(cov_fname)
+
+## read labels for analysis
+label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh']
+labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
+                         subject='sample') for label in label_names]
+
+## initialize and fit model
+model = LDS(lam0=0, lam1=100)
+model.add_subject('sample', condition, epochs, labels, fwd, cov)
+model.fit(niter=100, verbose=2)
+
+## plot model output
+num_roi = model.num_roi
+n_timepts = model.n_timepts
+times = model.times
+A_t_ = model.A_t_
+assert A_t_.shape == (n_timepts, num_roi, num_roi)
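+
+## How to read the grid below: subplot (i, j) traces entry A_t_[t, i, j],
+## the coupling from ROI j (column) onto ROI i (row) across one time
+## step, so rows are "to" and columns are "from" (see Yang et al., 2016,
+## for the underlying state-space model).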
+with mpl.rc_context({'xtick.labelsize': 'x-small',
+                     'ytick.labelsize': 'x-small'}):
+    fig, ax = plt.subplots(num_roi, num_roi, constrained_layout=True,
+                           squeeze=False, figsize=(12, 10))
+    plot_A_t_(A_t_, labels=label_names, times=times, ax=ax)
+    fig.suptitle(condition)
+    diag_lims = [0, 1]
+    off_lims = [-0.6, 0.6]
+    for ri, row in enumerate(ax):
+        for ci, a in enumerate(row):
+            ylim = diag_lims if ri == ci else off_lims
+            a.set(ylim=ylim, xlim=times[[0, -1]])
+            if ri == 0:
+                a.set_title(a.get_title(), fontsize='small')
+            if ci == 0:
+                a.set_ylabel(a.get_ylabel(), fontsize='small')
+            for line in a.lines:
+                line.set_clip_on(False)
+                line.set(lw=1.)
+            if ci != 0:
+                a.yaxis.set_major_formatter(plt.NullFormatter())
+            if ri != len(label_names) - 1:
+                a.xaxis.set_major_formatter(plt.NullFormatter())
+            if ri == ci:
+                for spine in a.spines.values():
+                    spine.set(lw=2)
+            else:
+                a.axhline(0, color='k', ls=':', lw=1.)
diff --git a/state_space/state_space_connectivity_test.py b/state_space/state_space_connectivity_test.py
new file mode 100644
index 00000000..848997ea
--- /dev/null
+++ b/state_space/state_space_connectivity_test.py
@@ -0,0 +1,170 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Authors: Jordan Drew
+
+Multi-subject usage of LDS for 'mne-connectivity/examples/'; uses the
+(local, for now) SPS dataset.
+"""
+
+## import necessary libraries
+import numpy as np
+import mne
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+
+# where should these files live within the mne-connectivity repo?
+from megssm.models import MEGLDS as LDS
+from megssm.plotting import plot_A_t_
+from megssm import label_util
+
+## define path to the (local) SPS data
+data_path = '/Users/jordandrew/Documents/MEG/meglds-master/data/sps'
+
+subjects = ['eric_sps_03', 'eric_sps_04', 'eric_sps_05', 'eric_sps_06',
+            'eric_sps_07', 'eric_sps_09', 'eric_sps_10', 'eric_sps_15',
+            'eric_sps_17', 'eric_sps_18', 'eric_sps_19', 'eric_sps_21',
+            'eric_sps_25', 'eric_sps_26', 'eric_sps_31', 'eric_sps_32']
+
+label_names = ['ACC', 'DLPFC-lh', 'DLPFC-rh', 'AUD-lh', 'AUD-rh', 'FEF-lh',
+               'FEF-rh', 'Vis', 'IPS-lh', 'LIPSP', 'IPS-rh', 'RTPJ']
+label_func = 'sps_meglds_base_vision_extra'
+labels = getattr(label_util, label_func)()
+labels = sorted(labels, key=lambda x: x.name)
+
+def eq_trials(epochs, kind):
+    """Equalize trial counts across (combined) event types."""
+    assert kind in ('sub', 'big')
+    print('  equalizing trial counts', end='')
+    in_names = [
+        'LL3', 'LR3', 'LU3', 'LD3', 'RL3', 'RR3', 'RU3', 'RD3',
+        'UL3', 'UR3', 'UU3', 'UD3', 'DL3', 'DR3', 'DU3', 'DD3',
+        'LL4', 'LR4', 'LU4', 'LD4', 'RL4', 'RR4', 'RU4', 'RD4',
+        'UL4', 'UR4', 'UU4', 'UD4', 'DL4', 'DR4', 'DU4', 'DD4',
+        'VS_', 'VM_',
+        'Junk',
+    ]
+    out_names = ['LL', 'LR', 'LX', 'UX', 'UU', 'UD', 'VS', 'VM']
+
+    # strip the trailing 3/4 and combine, e.g. LL3+LL4 -> LL
+    clean_names = np.unique([ii[:2] for ii in in_names
+                             if not ii.startswith('V')])
+    for name in clean_names:
+        combs = [in_name for in_name in in_names if in_name.startswith(name)]
+        new_id = {name: epochs.event_id[combs[-1]] + 1}
+        mne.epochs.combine_event_ids(epochs, combs, new_id, copy=False)
+
+    # now we equalize LU+LD, RU+RD, UL+UR, DL+DR, and combine those
+    for n1, n2, oname in zip(('LU', 'RU', 'UL', 'DL'),
+                             ('LD', 'RD', 'UR', 'DR'),
+                             ('LX', 'RX', 'UX', 'DX')):
+        if kind == 'sub':
+            epochs.equalize_event_counts([n1, n2])
+        new_id = {oname: epochs.event_id[n1] + 1}
+        mne.epochs.combine_event_ids(epochs, [n1, n2], new_id, copy=False)
+
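+    # From here down the pattern repeats: equalize_event_counts()
+    # subsamples the named event types to a common trial count (in 'sub'
+    # mode only), and combine_event_ids() then merges them under a new
+    # event id, as LU+LD -> LX etc. above.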
+    # Now we equalize "sides"
+    cs = dict(L='R', R='L', U='D', D='U')
+    for n1 in ['L', 'R', 'U', 'D']:
+        # first equalize it with its complement in the second pos
+        if kind == 'sub':
+            epochs.equalize_event_counts([n1 + n1, n1 + cs[n1]])
+            epochs.equalize_event_counts([n1 + n1, cs[n1] + n1])
+            epochs.equalize_event_counts([n1 + 'X', cs[n1] + 'X'])
+
+    # now combine cross types
+    for n1 in ['L', 'U']:
+        # LR+RL=LR, UD+DU=UD
+        old_ids = [n1 + cs[n1], cs[n1] + n1]
+        if kind == 'sub':
+            epochs.equalize_event_counts(old_ids)
+        new_id = {n1 + cs[n1]: epochs.event_id[n1 + cs[n1]] + 1}
+        mne.epochs.combine_event_ids(epochs, old_ids, new_id, copy=False)
+        # LL+RR=LL, UU+DD=UU
+        old_ids = [n1 + n1, cs[n1] + cs[n1]]
+        if kind == 'sub':
+            epochs.equalize_event_counts(old_ids)
+        new_id = {n1 + n1: epochs.event_id[n1 + n1] + 1}
+        mne.epochs.combine_event_ids(epochs, old_ids, new_id, copy=False)
+        # LX+RX=LX, UX+DX=UX
+        old_ids = [n1 + 'X', cs[n1] + 'X']
+        if kind == 'sub':
+            epochs.equalize_event_counts(old_ids)
+        new_id = {n1 + 'X': epochs.event_id[n1 + 'X'] + 1}
+        mne.epochs.combine_event_ids(epochs, old_ids, new_id, copy=False)
+
+    mne.epochs.combine_event_ids(epochs, ['VM_'], dict(VM=96), copy=False)
+    assert 'Ju' in epochs.event_id
+    epochs.drop(np.where(epochs.events[:, 2] ==
+                         epochs.event_id['Ju'])[0])
+    mne.epochs.combine_event_ids(epochs, ['VS_', 'Ju'], dict(VS=97),
+                                 copy=False)
+
+    # at this point we only care about:
+    eq_names = ('LX', 'UX', 'LL', 'LR', 'UU', 'UD', 'VS')
+    assert set(eq_names + ('VM',)) == set(epochs.event_id.keys())
+    assert set(eq_names + ('VM',)) == set(out_names)
+    orig_len = len(epochs['LL'])
+    epochs.equalize_event_counts(eq_names)
+    new_len = len(epochs['LL'])
+    print(' (reduced LL %s -> %s)' % (orig_len, new_len))
+    for ni, out_name in enumerate(out_names):
+        idx = (epochs.events[:, 2] == epochs.event_id[out_name])
+        epochs.event_id[out_name] = ni + 1
+        epochs.events[idx, 2] = ni + 1
+    return epochs
+
+for subject in subjects:
+
+    subject_dir = f'{data_path}/{subject}'
+
+    epochs_fname = f'{subject_dir}/epochs/All_55-sss_{subject}-epo.fif'
+    epochs = mne.read_epochs(epochs_fname)
+    epochs = eq_trials(epochs, kind='sub')
+    epochs = epochs['LL']
+
+    fwd_fname = f'{subject_dir}/forward/{subject}-sss-fwd.fif'
+    fwd = mne.read_forward_solution(fwd_fname)
+
+    cov_fname = f'{subject_dir}/covariance/{subject}-55-sss-cov.fif'
+    cov = mne.read_cov(cov_fname)
+
+    if subject == subjects[0]:
+        num_rois = len(labels)
+        timepts = len(epochs.times)
+        model = LDS(num_rois, timepts, lam0=0, lam1=100)
+
+    model.add_subject(subject, subject_dir, epochs, labels, fwd, cov)  # not using subject_dir
+
+# fitting and plotting are left commented out until the multi-subject
+# API settles:
+# model.fit(niter=100, verbose=1)
+# A_t_ = model.A_t_
+# assert A_t_.shape == (timepts, num_rois, num_rois)
+#
+# with mpl.rc_context({'xtick.labelsize': 'x-small',
+#                      'ytick.labelsize': 'x-small'}):
+#     fig, ax = plt.subplots(num_rois, num_rois, constrained_layout=True,
+#                            squeeze=False, figsize=(12, 10))
+#     plot_A_t_(A_t_, labels=label_names, times=epochs.times, ax=ax)
+#     fig.suptitle('LL')
diff --git a/state_space/test_state_space.py b/state_space/test_state_space.py
new file mode 100644
index 00000000..ce206c26
--- /dev/null
+++ b/state_space/test_state_space.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Authors: Jordan Drew
+
+Test script to ensure the LDS API is functioning properly.
+"""
+
+import pickle
+
+import mne
+import numpy as np
+
+from megssm.models import LDS
+
+def test_state_space_output():
+
+    ## define paths to sample data
+    data_path = mne.datasets.sample.data_path()
+    sample_folder = data_path / 'MEG/sample'
+
+    ## import raw data and find events
+    raw_fname = sample_folder / 'sample_audvis_raw.fif'
+    raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
+    events = mne.find_events(raw, stim_channel='STI 014')
+
+    ## define epochs using event_dict
+    event_dict = {'auditory_left': 1, 'auditory_right': 2, 'visual_left': 3,
+                  'visual_right': 4}
+    epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict,
+                        preload=True).pick_types(meg=True, eeg=True)
+    condition = 'auditory_left'
+
+    ## read forward solution
+    fwd_fname = sample_folder / 'sample_audvis-meg-eeg-oct-6-fwd.fif'
+    fwd = mne.read_forward_solution(fwd_fname)
+
+    ## read in covariance
+    cov_fname = sample_folder / 'sample_audvis-cov.fif'
+    cov = mne.read_cov(cov_fname)
+
+    ## read labels for analysis
+    label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh']
+    labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
+                             subject='sample') for label in label_names]
+
+    ## initialize and fit model
+    model = LDS(lam0=0, lam1=100)
+    model.add_subject('sample', condition, epochs, labels, fwd, cov)
+    model.fit(niter=50, verbose=2)
+
+    ## compare against the pickled reference result
+    with open('sample A_t', 'rb') as f:
+        A_t_ = pickle.load(f)
+    np.testing.assert_allclose(A_t_, model.A_t_)
+    print('Model is working!')
+
+if __name__ == '__main__':
+    test_state_space_output()
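+
+# NOTE: 'sample A_t' is the pickled baseline loaded above, assumed to
+# have been generated from a trusted run of the model. After an
+# intentional model change it can be regenerated with something like:
+#
+#     with open('sample A_t', 'wb') as f:
+#         pickle.dump(model.A_t_, f)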