diff --git a/MANIFEST.in b/MANIFEST.in index 93b28376..5907df91 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -40,4 +40,6 @@ exclude .github/PULL_REQUEST_TEMPLATE.md exclude mne_connectivity/tests/data recursive-exclude mne_connectivity/tests/data * +exclude mne_connectivity/spectral/tests/data +recursive-exclude mne_connectivity/spectral/tests/data * recursive-exclude benchmarks * diff --git a/doc/authors.inc b/doc/authors.inc index 7360a788..21a3fad4 100644 --- a/doc/authors.inc +++ b/doc/authors.inc @@ -11,3 +11,6 @@ .. _Daniel McCloy: https://dan.mccloy.info .. _Sam Steingold: https://github.com/sam-s .. _Qianliang Li: https://github.com/Avoide +.. _Thomas Binns: https://github.com/tsbinns +.. _Tien Nguyen: https://github.com/nguyen-td +.. _Richard Köhler: https://github.com/richardkoehler diff --git a/doc/references.bib b/doc/references.bib index b8a4059b..f3a20b3f 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -11,6 +11,53 @@ @article{AvantsEtAl2008 year = {2008} } +@article{BarnettSeth2015, + title={Granger causality for state-space models}, + author={Barnett, Lionel and Seth, Anil K.}, + journal={Physical Review E}, + volume={91}, + number={4}, + pages={040101}, + year={2015}, + publisher={APS}, + doi={10.1103/PhysRevE.91.040101} +} + +@article{BrunaEtAl2018, + doi = {10.1088/1741-2552/aacfe4}, + year = {2018}, + publisher = {{IOP} Publishing}, + volume = {15}, + number = {5}, + pages = {056011}, + author = {Ricardo Bru{\~{n}}a, Fernando Maest{\'{u}}, Ernesto Pereda}, + title = {Phase locking value revisited: teaching new tricks to an old dog}, + journal = {Journal of Neural Engineering}, +} + +@article{ColcloughEtAl2015, + title = {A symmetric multivariate leakage correction for {MEG} connectomes}, + volume = {117}, + issn = {1053-8119}, + doi = {10.1016/j.neuroimage.2015.03.071}, + language = {en}, + journal = {NeuroImage}, + author = {Colclough, G. L. and Brookes, M. J. and Smith, S. M. and Woolrich, M. 
W.}, + month = aug, + year = {2015}, + pages = {439--448} +} + +@book{CrochiereRabiner1983, + address = {Englewood Cliffs, NJ}, + edition = {1 edition}, + title = {Multirate {Digital} {Signal} {Processing}}, + isbn = {978-0-13-605162-6}, + publisher = {Pearson}, + author = {Crochiere, Ronald E. and Rabiner, Lawrence R.}, + year = {1983} +} + @article{Dawson_2016, author={Dawson, Scott T. M. and Hemati, Maziar S. and Williams, Matthew O. and Rowley, Clarence W.}, DOI={10.1007/s00348-016-2127-7}, @@ -25,6 +72,105 @@ @article{Dawson_2016 year={2016}, } +@article{EwaldEtAl2012, + title={Estimating true brain connectivity from {EEG/MEG} data invariant to linear and static transformations in sensor space}, + author={Ewald, Arne and Marzetti, Laura and Zappasodi, Filippo and Meinecke, Frank C. and Nolte, Guido}, + journal={NeuroImage}, + volume={60}, + number={1}, + pages={476--488}, + year={2012}, + publisher={Elsevier}, + doi={10.1016/j.neuroimage.2011.11.084} +} + +@article{HaufeEtAl2013, + title={A critical assessment of connectivity measures for EEG data: a simulation study}, + author={Haufe, Stefan and Nikulin, Vadim V and M{\"u}ller, Klaus-Robert and Nolte, Guido}, + journal={NeuroImage}, + volume={64}, + pages={120--133}, + year={2013}, + publisher={Elsevier}, + doi={10.1016/j.neuroimage.2012.09.036} +} + +@article{HippEtAl2012, + author = {Hipp, Joerg F and Hawellek, David J and Corbetta, Maurizio and Siegel, Markus and Engel, Andreas K}, + doi = {10.1038/nn.3101}, + journal = {Nature Neuroscience}, + number = {6}, + pages = {884-890}, + title = {Large-Scale Cortical Correlation Structure of Spontaneous Oscillatory Activity}, + volume = {15}, + year = {2012} +} + +@article{KhanEtAl2018, + author = {Khan, Sheraz and Hashmi, Javeria A. and Mamashli, Fahimeh and Michmizos, Konstantinos and Kitzbichler, Manfred G. and Bharadwaj, Hari and Bekhti, Yousra and Ganesan, Santosh and Garel, Keri-Lee A. and {Whitfield-Gabrieli}, Susan and Gollub, Randy L. 
and Kong, Jian and Vaina, Lucia M. and Rana, Kunjan D. and Stufflebeam, Steven M. and Hämäläinen, Matti S. and Kenet, Tal}, + doi = {10.1016/j.neuroimage.2018.02.018}, + journal = {NeuroImage}, + pages = {57-68}, + title = {Maturation Trajectories of Cortical Resting-State Networks Depend on the Mediating Frequency Band}, + volume = {174}, + year = {2018} +} + +@article{KlimeschEtAl2004, + title = {Phase-locked alpha and theta oscillations generate the P1–N1 complex and are related to memory performance}, + journal = {Cognitive Brain Research}, + volume = {19}, + number = {3}, + pages = {302-316}, + year = {2004}, + issn = {0926-6410}, + doi = {https://doi.org/10.1016/j.cogbrainres.2003.11.016}, + author = {Wolfgang Klimesch and Bärbel Schack and Manuel Schabus and Michael Doppelmayr and Walter Gruber and Paul Sauseng} +} + +@article{LachauxEtAl1999, + author = {Lachaux, Jean-Philippe and Rodriguez, Eugenio and Martinerie, Jacques and Varela, Francisco J.}, + doi = {10.1002/(SICI)1097-0193(1999)8:4<194::AID-HBM4>3.0.CO;2-C}, + journal = {Human Brain Mapping}, + number = {4}, + pages = {194-208}, + title = {Measuring Phase Synchrony in Brain Signals}, + volume = {8}, + year = {1999} +} + +@INPROCEEDINGS{li_linear_2017, + author = {Li, Adam and Gunnarsdottir, Kristin M. 
and Inati, Sara and Zaghloul, Kareem and Gale, John and Bulacio, Juan and Martinez-Gonzalez, Jorge and Sarma, Sridevi V.}, + booktitle = {2017 39th Annual International Conference of the IEEE Engineering in Medicine and Biology Society (EMBC)}, + title = {Linear time-varying model characterizes invasive EEG signals generated from complex epileptic networks}, + year = {2017}, + volume = {}, + number = {}, + pages = {2802-2805}, + doi = {10.1109/EMBC.2017.8037439} +} + +@article{NolteEtAl2004, + author = {Nolte, Guido and Bai, Ou and Wheaton, Lewis and Mari, Zoltan and Vorbach, Sherry and Hallett, Mark}, + doi = {10.1016/j.clinph.2004.04.029}, + journal = {Clinical Neurophysiology}, + number = {10}, + pages = {2292-2307}, + title = {Identifying True Brain Interaction from {{EEG}} Data Using the Imaginary Part of Coherency}, + volume = {115}, + year = {2004} +} + +@article{NolteEtAl2008, + author = {Nolte, Guido and Ziehe, Andreas and Nikulin, Vadim V. and Schlögl, Alois and Krämer, Nicole and Brismar, Tom and Müller, Klaus-Robert}, + doi = {10.1103/PhysRevLett.100.234101}, + journal = {Physical Review Letters}, + number = {23}, + title = {Robustly Estimating the Flow Direction of Information in Complex Physical Systems}, + volume = {100}, + year = {2008} +} + @book{OppenheimEtAl1999, address = {Upper Saddle River, NJ}, edition = {2 edition}, @@ -59,6 +205,31 @@ @article{SmithNichols2009 year = {2009} } +@article{StamEtAl2007, + author = {Stam, Cornelis J. 
and Nolte, Guido and Daffertshofer, Andreas}, + doi = {10.1002/hbm.20346}, + journal = {Human Brain Mapping}, + number = {11}, + pages = {1178-1193}, + shorttitle = {Phase Lag Index}, + title = {Phase Lag Index: Assessment of Functional Connectivity from Multi Channel {{EEG}} and {{MEG}} with Diminished Bias from Common Sources}, + volume = {28}, + year = {2007} +} + +@article{StamEtAl2012, + title={Go with the flow: Use of a directed phase lag index (dPLI) to characterize patterns of phase relations in a large-scale model of brain dynamics}, + volume={62}, + ISSN={1053-8119}, + DOI={10.1016/j.neuroimage.2012.05.050}, + number={3}, + journal={NeuroImage}, + author={Stam, C. J. and van Straaten, E. C. W.}, + year={2012}, + month={Sep}, + pages={1415–1428} +} + @article{VanVeenEtAl1997, author = {Van Veen, Barry D. and {van Drongelen}, Wim and Yuchtman, Moshe and Suzuki, Akifumi}, doi = {10.1109/10.623056}, @@ -79,6 +250,29 @@ @article{vanVlietEtAl2018 year = {2018} } +@article{VidaurreEtAl2019, + title={Canonical maximization of coherence: a novel tool for investigation of neuronal interactions between two datasets}, + author={Vidaurre, Carmen and Nolte, Guido and de Vries, Ingmar E.J. and G{\'o}mez, M. and Boonstra, Tjeerd W. and M{\"u}ller, K.-R. 
and Villringer, Arno and Nikulin, Vadim V.}, + journal={NeuroImage}, + volume={201}, + pages={116009}, + year={2019}, + publisher={Elsevier}, + doi={10.1016/j.neuroimage.2019.116009} +} + +@article{VinckEtAl2010, + author = {Vinck, Martin and {van Wingerden}, Marijn and Womelsdorf, Thilo and Fries, Pascal and Pennartz, Cyriel M.A.}, + doi = {10.1016/j.neuroimage.2010.01.073}, + journal = {NeuroImage}, + number = {1}, + pages = {112-122}, + shorttitle = {The Pairwise Phase Consistency}, + title = {The Pairwise Phase Consistency: A Bias-Free Measure of Rhythmic Neuronal Synchronization}, + volume = {51}, + year = {2010} +} + @article{VinckEtAl2011, author = {Vinck, Martin and Oostenveld, Robert and {van Wingerden}, Marijn and Battaglia, Franscesco and Pennartz, Cyriel M.A.}, doi = {10.1016/j.neuroimage.2011.01.055}, @@ -90,14 +284,39 @@ @article{VinckEtAl2011 year = {2011} } -@book{CrochiereRabiner1983, - address = {Englewood Cliffs, NJ}, - edition = {1 edition}, - title = {Multirate {Digital} {Signal} {Processing}}, - isbn = {978-0-13-605162-6}, - publisher = {Pearson}, - author = {Crochiere, Ronald E. 
and Rabiner, Lawrence R.}, - year = {1983} +@article{VinckEtAl2015, + title={How to detect the Granger-causal flow direction in the presence of additive noise?}, + author={Vinck, Martin and Huurdeman, Lisanne and Bosman, Conrado A and Fries, Pascal and Battaglia, Francesco P and Pennartz, Cyriel MA and Tiesinga, Paul H}, + journal={NeuroImage}, + volume={108}, + pages={301--318}, + year={2015}, + publisher={Elsevier}, + doi={10.1016/j.neuroimage.2014.12.017} +} + +@article{Whittle1963, + title={On the fitting of multivariate autoregressions, and the approximate canonical factorization of a spectral density matrix}, + author={Whittle, Peter}, + journal={Biometrika}, + volume={50}, + number={1-2}, + pages={129--134}, + year={1963}, + publisher={Oxford University Press}, + doi={10.1093/biomet/50.1-2.129} +} + +@article{WinklerEtAl2016, + title={Validity of time reversal for testing Granger causality}, + author={Winkler, Irene and Panknin, Danny and Bartz, Daniel and M{\"u}ller, Klaus-Robert and Haufe, Stefan}, + journal={IEEE Transactions on Signal Processing}, + volume={64}, + number={11}, + pages={2746--2760}, + year={2016}, + publisher={IEEE}, + doi={10.1109/TSP.2016.2531628} } @article{Yao2001, @@ -113,146 +332,7 @@ @article{Yao2001 pages = {693--711} } -@article{HippEtAl2012, - author = {Hipp, Joerg F and Hawellek, David J and Corbetta, Maurizio and Siegel, Markus and Engel, Andreas K}, - doi = {10.1038/nn.3101}, - journal = {Nature Neuroscience}, - number = {6}, - pages = {884-890}, - title = {Large-Scale Cortical Correlation Structure of Spontaneous Oscillatory Activity}, - volume = {15}, - year = {2012} -} - -@article{KhanEtAl2018, - author = {Khan, Sheraz and Hashmi, Javeria A. and Mamashli, Fahimeh and Michmizos, Konstantinos and Kitzbichler, Manfred G. and Bharadwaj, Hari and Bekhti, Yousra and Ganesan, Santosh and Garel, Keri-Lee A. and {Whitfield-Gabrieli}, Susan and Gollub, Randy L. and Kong, Jian and Vaina, Lucia M. and Rana, Kunjan D. 
and Stufflebeam, Steven M. and Hämäläinen, Matti S. and Kenet, Tal}, - doi = {10.1016/j.neuroimage.2018.02.018}, - journal = {NeuroImage}, - pages = {57-68}, - title = {Maturation Trajectories of Cortical Resting-State Networks Depend on the Mediating Frequency Band}, - volume = {174}, - year = {2018} -} - -@article{NolteEtAl2008, - author = {Nolte, Guido and Ziehe, Andreas and Nikulin, Vadim V. and Schlögl, Alois and Krämer, Nicole and Brismar, Tom and Müller, Klaus-Robert}, - doi = {10.1103/PhysRevLett.100.234101}, - journal = {Physical Review Letters}, - number = {23}, - title = {Robustly Estimating the Flow Direction of Information in Complex Physical Systems}, - volume = {100}, - year = {2008} -} - - -@article{LachauxEtAl1999, - author = {Lachaux, Jean-Philippe and Rodriguez, Eugenio and Martinerie, Jacques and Varela, Francisco J.}, - doi = {10.1002/(SICI)1097-0193(1999)8:4<194::AID-HBM4>3.0.CO;2-C}, - journal = {Human Brain Mapping}, - number = {4}, - pages = {194-208}, - title = {Measuring Phase Synchrony in Brain Signals}, - volume = {8}, - year = {1999} -} - -@article{StamEtAl2007, - author = {Stam, Cornelis J. 
and Nolte, Guido and Daffertshofer, Andreas}, - doi = {10.1002/hbm.20346}, - journal = {Human Brain Mapping}, - number = {11}, - pages = {1178-1193}, - shorttitle = {Phase Lag Index}, - title = {Phase Lag Index: Assessment of Functional Connectivity from Multi Channel {{EEG}} and {{MEG}} with Diminished Bias from Common Sources}, - volume = {28}, - year = {2007} -} - -@article{VinckEtAl2010, - author = {Vinck, Martin and {van Wingerden}, Marijn and Womelsdorf, Thilo and Fries, Pascal and Pennartz, Cyriel M.A.}, - doi = {10.1016/j.neuroimage.2010.01.073}, - journal = {NeuroImage}, - number = {1}, - pages = {112-122}, - shorttitle = {The Pairwise Phase Consistency}, - title = {The Pairwise Phase Consistency: A Bias-Free Measure of Rhythmic Neuronal Synchronization}, - volume = {51}, - year = {2010} -} - -@article{BrunaEtAl2018, - doi = {10.1088/1741-2552/aacfe4}, - year = {2018}, - publisher = {{IOP} Publishing}, - volume = {15}, - number = {5}, - pages = {056011}, - author = {Ricardo Bru{\~{n}}a, Fernando Maest{\'{u}}, Ernesto Pereda}, - title = {Phase locking value revisited: teaching new tricks to an old dog}, - journal = {Journal of Neural Engineering}, -} - -@article{NolteEtAl2004, - author = {Nolte, Guido and Bai, Ou and Wheaton, Lewis and Mari, Zoltan and Vorbach, Sherry and Hallett, Mark}, - doi = {10.1016/j.clinph.2004.04.029}, - journal = {Clinical Neurophysiology}, - number = {10}, - pages = {2292-2307}, - title = {Identifying True Brain Interaction from {{EEG}} Data Using the Imaginary Part of Coherency}, - volume = {115}, - year = {2004} -} - -@INPROCEEDINGS{li_linear_2017, - author = {Li, Adam and Gunnarsdottir, Kristin M. 
and Inati, Sara and Zaghloul, Kareem and Gale, John and Bulacio, Juan and Martinez-Gonzalez, Jorge and Sarma, Sridevi V.}, - booktitle = {2017 39th Annual International Conference of the IEEE Engineering in Medicine and Biology Society (EMBC)}, - title = {Linear time-varying model characterizes invasive EEG signals generated from complex epileptic networks}, - year = {2017}, - volume = {}, - number = {}, - pages = {2802-2805}, - doi = {10.1109/EMBC.2017.8037439} -} - -@article{ColcloughEtAl2015, - title = {A symmetric multivariate leakage correction for {MEG} connectomes}, - volume = {117}, - issn = {1053-8119}, - doi = {10.1016/j.neuroimage.2015.03.071}, - language = {en}, - journal = {NeuroImage}, - author = {Colclough, G. L. and Brookes, M. J. and Smith, S. M. and Woolrich, M. W.}, - month = aug, - year = {2015}, - pages = {439--448} -} - - @article{StamEtAl2012, - title={Go with the flow: Use of a directed phase lag index (dPLI) to characterize patterns of phase relations in a large-scale model of brain dynamics}, - volume={62}, - ISSN={1053-8119}, - DOI={10.1016/j.neuroimage.2012.05.050}, - number={3}, - journal={NeuroImage}, - author={Stam, C. J. and van Straaten, E. C. 
W.}, - year={2012}, - month={Sep}, - pages={1415–1428} -} - - @article{KlimeschEtAl2004, - title = {Phase-locked alpha and theta oscillations generate the P1–N1 complex and are related to memory performance}, - journal = {Cognitive Brain Research}, - volume = {19}, - number = {3}, - pages = {302-316}, - year = {2004}, - issn = {0926-6410}, - doi = {https://doi.org/10.1016/j.cogbrainres.2003.11.016}, - author = {Wolfgang Klimesch and Bärbel Schack and Manuel Schabus and Michael Doppelmayr and Walter Gruber and Paul Sauseng} -} - - @article{Zimmermann2022, +@article{Zimmermann2022, author = {Zimmermann, Marius and Lomoriello, Arianna Schiano and Konvalinka, Ivana}, doi = {10.1098/rsos.211352}, issn = {20545703}, diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 5016010d..a844c282 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -25,6 +25,8 @@ Enhancements - Add the option to set the number of connections plotted in :func:`mne_connectivity.viz.plot_sensors_connectivity` by `Qianliang Li`_ (:pr:`133`). - Allow setting colormap via new parameter ``cmap`` in :func:`mne_connectivity.viz.plot_sensors_connectivity` by `Daniel McCloy`_ (:pr:`141`). +- Add support for multivariate connectivity methods in :func:`mne_connectivity.spectral_connectivity_epochs` and :func:`mne_connectivity.spectral_connectivity_time` by `Thomas Binns`_ and `Tien Nguyen`_ and `Richard Köhler`_ (:pr:`138`). 
+ Bug ~~~ diff --git a/examples/granger_causality.py b/examples/granger_causality.py new file mode 100644 index 00000000..f5d8316d --- /dev/null +++ b/examples/granger_causality.py @@ -0,0 +1,414 @@ +""" +========================================================================== +Compute directionality of connectivity with multivariate Granger causality +========================================================================== + +This example demonstrates how Granger causality based on state-space models +:footcite:`BarnettSeth2015` can be used to compute directed connectivity +between sensors in a multivariate manner. Furthermore, the use of time-reversal +for improving the robustness of directed connectivity estimates to noise in the +data is discussed :footcite:`WinklerEtAl2016`. +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) + +# %% + +import numpy as np +from matplotlib import pyplot as plt + +import mne +from mne.datasets.fieldtrip_cmc import data_path +from mne_connectivity import spectral_connectivity_epochs + +############################################################################### +# Background +# ---------- +# +# Multivariate forms of signal analysis allow you to simultaneously consider +# the activity of multiple signals. In the case of connectivity, the +# interaction between multiple sensors can be analysed at once, producing a +# single connectivity spectrum. This approach brings not only practical +# benefits (e.g. easier interpretability of results from the dimensionality +# reduction), but can also offer methodological improvements (e.g. enhanced +# signal-to-noise ratio and reduced bias). +# +# Additionally, it can be of interest to examine the directionality of +# connectivity between signals, providing additional clarity to how information +# flows in a system. One such directed measure of connectivity is Granger +# causality (GC). 
A signal, :math:`\boldsymbol{x}`, is said to Granger-cause +# another signal, :math:`\boldsymbol{y}`, if information from the past of +# :math:`\boldsymbol{x}` improves the prediction of the present of +# :math:`\boldsymbol{y}` over the case where only information from the past of +# :math:`\boldsymbol{y}` is used. Note: GC does not make any assertions about +# the true causality between signals. +# +# The degree to which :math:`\boldsymbol{x}` and :math:`\boldsymbol{y}` can be +# used to predict one another in a linear model can be quantified using vector +# autoregressive (VAR) models. Considering the simpler case of time domain +# connectivity, the VAR models are as follows: +# +# :math:`y_t = \sum_{k=1}^{K} a_k y_{t-k} + \xi_t^y` , +# :math:`Var(\xi_t^y) := \Sigma_y` , +# +# and :math:`\boldsymbol{z}_t = \sum_{k=1}^K \boldsymbol{A}_k +# \boldsymbol{z}_{t-k} + \boldsymbol{\epsilon}_t` , +# :math:`\boldsymbol{\Sigma} := \langle \boldsymbol{\epsilon}_t +# \boldsymbol{\epsilon}_t^T \rangle = \begin{bmatrix} \Sigma_{xx} & \Sigma_{xy} +# \\ \Sigma_{yx} & \Sigma_{yy} \end{bmatrix}` , +# +# representing the reduced and full VAR models, respectively, where: :math:`K` +# is the order of the VAR model, determining the number of lags, :math:`k`, +# used; :math:`\boldsymbol{Z} := \begin{bmatrix} \boldsymbol{x} \\ +# \boldsymbol{y} \end{bmatrix}`; :math:`\boldsymbol{A}` is a matrix of +# coefficients explaining the contribution of past entries of +# :math:`\boldsymbol{Z}` to its current value; and :math:`\xi` and +# :math:`\boldsymbol{\epsilon}` are the residuals of the VAR models. In this +# way, the information of the signals at time :math:`t` can be represented as a +# weighted form of the information from the previous timepoints, plus some +# residual information not encoded in the signals' past. In practice, VAR model +# parameters are computed from an autocovariance sequence generated from the +# time-series data using the Yule-Walker equations :footcite:`Whittle1963`. 
+# +# The residuals, or errors, represent how much information about the present +# state of the signals is not explained by their past. We can therefore +# estimate how much :math:`\boldsymbol{x}` Granger-causes +# :math:`\boldsymbol{y}` by comparing the variance of the residuals of the +# reduced VAR model (:math:`\Sigma_y`; i.e. how much the present of +# :math:`\boldsymbol{y}` is not explained by its own past) and of the full VAR +# model (:math:`\Sigma_{yy}`; i.e. how much the present of +# :math:`\boldsymbol{y}` is not explained by both its own past and that of +# :math:`\boldsymbol{x}`): +# +# :math:`F_{x \rightarrow y} = ln \Large{(\frac{\Sigma_y}{\Sigma_{yy}})}` , +# +# where :math:`F` is the Granger score. For example, if :math:`\boldsymbol{x}` +# contains no information about :math:`\boldsymbol{y}`, the residuals of the +# reduced and full VAR models will be identical, and +# :math:`F_{x \rightarrow y}` will naturally be 0, indicating that +# information from :math:`\boldsymbol{x}` does not flow to +# :math:`\boldsymbol{y}`. In contrast, if :math:`\boldsymbol{x}` does help to +# predict :math:`\boldsymbol{y}`, the residual of the full model will be +# smaller than that of the reduced model. :math:`\Large{\frac{\Sigma_y} +# {\Sigma_{yy}}}` will therefore be greater than 1, leading to a Granger score +# > 0. Granger scores are bound between :math:`[0, \infty)`. +# +# These same principles apply to spectral GC, which provides information about +# the directionality of connectivity for individual frequencies. For spectral +# GC, the autocovariance sequence is generated from an inverse Fourier +# transform applied to the cross-spectral density of the signals. Additionally, +# a spectral transfer function is used to translate information from the VAR +# models back into the frequency domain before computing the final Granger +# scores. 
+# +# Barnett and Seth (2015) :footcite:`BarnettSeth2015` have defined a +# multivariate form of spectral GC based on state-space models, enabling the +# estimation of information flow between whole sets of signals simultaneously: +# +# :math:`F_{A \rightarrow B}(f) = \Re ln \Large{(\frac{ +# det(\boldsymbol{S}_{BB}(f))}{det(\boldsymbol{S}_{BB}(f) - +# \boldsymbol{H}_{BA}(f) \boldsymbol{\Sigma}_{AA \lvert B} +# \boldsymbol{H}_{BA}^*(f))})}` , +# +# where: :math:`A` and :math:`B` are the seeds and targets, respectively; +# :math:`f` is a given frequency; :math:`\boldsymbol{H}` is the spectral +# transfer function; :math:`\boldsymbol{\Sigma}` is the innovations form +# residuals' covariance matrix of the state-space model; :math:`\boldsymbol{S}` +# is :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`; and +# :math:`\boldsymbol{\Sigma}_{IJ \lvert K} := \boldsymbol{\Sigma}_{IJ} - +# \boldsymbol{\Sigma}_{IK} \boldsymbol{\Sigma}_{KK}^{-1} +# \boldsymbol{\Sigma}_{KJ}`, representing a partial covariance matrix. The same +# principles apply as before: a numerator greater than the denominator means +# that information from the seed signals aids the prediction of activity in the +# target signals, leading to a Granger score > 0. +# +# There are several benefits to a state-space approach for computing GC: +# compared to traditional autoregressive-based approaches, the use of +# state-space models offers reduced statistical bias and increased statistical +# power; furthermore, the dimensionality reduction offered by the multivariate +# nature of the approach can aid in the interpretability and subsequent +# analysis of the results. +# +# To demonstrate the use of GC for estimating directed connectivity, we start +# by loading some example MEG data and dividing it into two-second-long epochs. 
+ +# %% + +raw = mne.io.read_raw_ctf(data_path() / 'SubjectCMC.ds') +raw.pick('mag') +raw.crop(50., 110.).load_data() +raw.notch_filter(50) +raw.resample(100) + +epochs = mne.make_fixed_length_epochs(raw, duration=2.0).load_data() + +############################################################################### +# We will focus on connectivity between sensors over the parietal and occipital +# cortices, with 20 parietal sensors designated as group A, and 20 occipital +# sensors designated as group B. + +# %% + +# parietal sensors +signals_a = [idx for idx, ch_info in enumerate(epochs.info['chs']) if + ch_info['ch_name'][2] == 'P'] +# occipital sensors +signals_b = [idx for idx, ch_info in enumerate(epochs.info['chs']) if + ch_info['ch_name'][2] == 'O'] + +# XXX: Currently ragged indices are not supported, so we only consider a single +# list of indices with an equal number of seeds and targets +min_n_chs = min(len(signals_a), len(signals_b)) +signals_a = signals_a[:min_n_chs] +signals_b = signals_b[:min_n_chs] + +indices_ab = (np.array(signals_a), np.array(signals_b)) # A => B +indices_ba = (np.array(signals_b), np.array(signals_a)) # B => A + +signals_a_names = [epochs.info['ch_names'][idx] for idx in signals_a] +signals_b_names = [epochs.info['ch_names'][idx] for idx in signals_b] + +# compute Granger causality +gc_ab = spectral_connectivity_epochs( + epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, + rank=(np.array([5]), np.array([5])), gc_n_lags=20) # A => B +gc_ba = spectral_connectivity_epochs( + epochs, method=['gc'], indices=indices_ba, fmin=5, fmax=30, + rank=(np.array([5]), np.array([5])), gc_n_lags=20) # B => A +freqs = gc_ab.freqs + + +############################################################################### +# Plotting the results, we see that there is a flow of information from our +# parietal sensors (group A) to our occipital sensors (group B) with noticeable +# peaks at around 8, 18, and 26 Hz. 
+ +# %% + +fig, axis = plt.subplots(1, 1) +axis.plot(freqs, gc_ab.get_data()[0], linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Connectivity (A.U.)') +fig.suptitle('GC: [A => B]') + + +############################################################################### +# Drivers and receivers: analysing the net direction of information flow +# ---------------------------------------------------------------------- +# +# Although analysing connectivity in a given direction can be of interest, +# there may exist a bidirectional relationship between signals. In such cases, +# identifying the signals that dominate information flow (the drivers) may be +# desired. For this, we can simply subtract the Granger scores in the opposite +# direction, giving us the net GC score: +# +# :math:`F_{A \rightarrow B}^{net} := F_{A \rightarrow B} - +# F_{B \rightarrow A}`. +# +# Doing so, we see that the flow of information across the spectrum remains +# dominant from parietal to occipital sensors (indicated by the positive-valued +# Granger scores). However, the pattern of connectivity is altered, such as +# around 10 and 12 Hz where peaks of net information flow are now present. + +# %% + +net_gc = gc_ab.get_data() - gc_ba.get_data() # [A => B] - [B => A] + +fig, axis = plt.subplots(1, 1) +axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle='--', + color='k') +axis.plot(freqs, net_gc[0], linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Connectivity (A.U.)') +fig.suptitle('Net GC: [A => B] - [B => A]') + + +############################################################################### +# Improving the robustness of connectivity estimates with time-reversal +# --------------------------------------------------------------------- +# +# One limitation of GC methods is the risk of connectivity estimates being +# contaminated with noise. 
For instance, consider the case where, due to
+# volume conduction, multiple sensors detect activity from the same source.
+# Naturally, information recorded at these sensors mutually help to predict
+# the activity of one another, leading to spurious estimates of directed
+# connectivity which one may incorrectly attribute to information flow between
+# different brain regions. On the other hand, even if there is no source
+# mixing, the presence of correlated noise between sensors can similarly bias
+# directed connectivity estimates.
+#
+# To address this issue, Haufe *et al.* (2013) :footcite:`HaufeEtAl2013`
+# propose contrasting causality scores obtained on the original time-series to
+# those obtained on the reversed time-series. The idea behind this approach is
+# as follows: if temporal order is crucial in distinguishing a driver from a
+# recipient, then reversing the temporal order should reduce, if not flip, an
+# estimate of directed connectivity. In practice, time-reversal is implemented
+# as a transposition of the autocovariance sequence used to compute GC. Several
+# studies have shown that such an approach can reduce the degree of
+# false-positive connectivity estimates (even performing favourably against
+# other methods such as the phase slope index) :footcite:`VinckEtAl2015` and
+# retain the ability to correctly identify the net direction of information
+# flow akin to net GC :footcite:`WinklerEtAl2016,HaufeEtAl2013`. This approach
+# is termed time-reversed GC (TRGC):
+#
+# :math:`\tilde{D}_{A \rightarrow B}^{net} := F_{A \rightarrow B}^{net} -
+# F_{\tilde{A} \rightarrow \tilde{B}}^{net}` ,
+#
+# where :math:`\sim` represents time-reversal, and:
+#
+# :math:`F_{\tilde{A} \rightarrow \tilde{B}}^{net} := F_{\tilde{A} \rightarrow
+# \tilde{B}} - F_{\tilde{B} \rightarrow \tilde{A}}`.
+#
+# GC on time-reversed signals can be computed simply with ``method=['gc_tr']``,
+# which will perform the time-reversal of the signals for the end-user. 
Note
+# that **time-reversed results should only be interpreted in the context of net
+# results**, i.e. with :math:`\tilde{D}_{A \rightarrow B}^{net}`. In the
+# example below, notice how the outputs are not used directly, but rather used
+# to produce net scores of the time-reversed signals. The net scores of the
+# time-reversed signals can then be subtracted from the net scores of the
+# original signals to produce the final TRGC scores.
+
+# %%
+
+# compute GC on time-reversed signals
+gc_tr_ab = spectral_connectivity_epochs(
+    epochs, method=['gc_tr'], indices=indices_ab, fmin=5, fmax=30,
+    rank=(np.array([5]), np.array([5])), gc_n_lags=20)  # TR[A => B]
+gc_tr_ba = spectral_connectivity_epochs(
+    epochs, method=['gc_tr'], indices=indices_ba, fmin=5, fmax=30,
+    rank=(np.array([5]), np.array([5])), gc_n_lags=20)  # TR[B => A]
+
+# compute net GC on time-reversed signals (TR[A => B] - TR[B => A])
+net_gc_tr = gc_tr_ab.get_data() - gc_tr_ba.get_data()
+
+# compute TRGC
+trgc = net_gc - net_gc_tr
+
+###############################################################################
+# Plotting the TRGC results reveals a very different picture compared to net
+# GC. For one, there is now a dominance of information flow ~6 Hz from
+# occipital to parietal sensors (indicated by the negative-valued Granger
+# scores). Additionally, the peaks ~10 Hz are less dominant in the spectrum,
+# with parietal to occipital information flow between 13-20 Hz being much more
+# prominent. The stark difference between net GC and TRGC results indicates
+# that the net GC spectrum was contaminated by spurious connectivity resulting
+# from source mixing or correlated noise in the recordings. Altogether, the use
+# of TRGC instead of net GC is generally advised. 
+ +# %% + +fig, axis = plt.subplots(1, 1) +axis.plot((freqs[0], freqs[-1]), (0, 0), linewidth=2, linestyle='--', + color='k') +axis.plot(freqs, trgc[0], linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Connectivity (A.U.)') +fig.suptitle('TRGC: net[A => B] - net time-reversed[A => B]') + + +############################################################################### +# Controlling spectral smoothing with the number of lags +# ------------------------------------------------------ +# +# One important parameter when computing GC is the number of lags used when +# computing the VAR model. A lower number of lags reduces the computational +# cost, but in the context of spectral GC, leads to a smoothing of Granger +# scores across frequencies. The number of lags can be specified using the +# ``gc_n_lags`` parameter. The default value is 40, however there is no correct +# number of lags to use when computing GC. Instead, you have to use your own +# best judgement of whether or not your Granger scores look overly smooth. +# +# Below is a comparison of Granger scores computed with a different number of +# lags. In the above examples we used 20 lags, which we will compare to Granger +# scores computed with 60 lags. As you can see, the spectra of Granger scores +# computed with 60 lags is noticeably less smooth, but it does share the same +# overall pattern. 
+ +# %% + +gc_ab_60 = spectral_connectivity_epochs( + epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, + rank=(np.array([5]), np.array([5])), gc_n_lags=60) # A => B + +fig, axis = plt.subplots(1, 1) +axis.plot(freqs, gc_ab.get_data()[0], linewidth=2, label='20 lags') +axis.plot(freqs, gc_ab_60.get_data()[0], linewidth=2, label='60 lags') +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Connectivity (A.U.)') +axis.legend() +fig.suptitle('GC: [A => B]') + + +############################################################################### +# Handling high-dimensional data +# ------------------------------ +# +# An important issue to consider when computing multivariate GC is that the +# data GC is computed on should not be rank deficient (i.e. must have full +# rank). More specifically, the autocovariance matrix must not be singular or +# close to singular. +# +# In the case that your data is not full rank and ``rank`` is left as ``None``, +# an automatic rank computation is performed and an appropriate degree of +# dimensionality reduction will be enforced. The rank of the data is determined +# by computing the singular values of the data and finding those within a +# factor of :math:`1e^{-10}` relative to the largest singular value. +# +# In some circumstances, this threshold may be too lenient, in which case you +# should inspect the singular values of your data to identify an appropriate +# degree of dimensionality reduction to perform, which you can then specify +# manually using the ``rank`` argument. The code below shows one possible +# approach for finding an appropriate rank of close-to-singular data with a +# more conservative threshold of :math:`1e^{-5}`. 
+ +# %% + +# gets the singular values of the data +s = np.linalg.svd(raw.get_data(), compute_uv=False) +# finds how many singular values are "close" to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria + +############################################################################### +# Nonethless, even in situations where you specify an appropriate rank, it is +# not guaranteed that the subsequently-computed autocovariance sequence will +# retain this non-singularity (this can depend on, e.g. the number of lags). +# Hence, you may also encounter situations where you have to specify a rank +# less than that of your data to ensure that the autocovariance sequence is +# non-singular. +# +# In the above examples, notice how a rank of 5 was given, despite there being +# 20 channels in the seeds and targets. Attempting to compute GC on the +# original data would not succeed, given that the resulting autocovariance +# sequence is singular, as the example below shows. + +# %% + +try: + spectral_connectivity_epochs( + epochs, method=['gc'], indices=indices_ab, fmin=5, fmax=30, rank=None, + gc_n_lags=20) # A => B + print('Success!') +except RuntimeError as error: + print('\nCaught the following error:\n' + repr(error)) + +############################################################################### +# Rigorous checks are implemented to identify any such instances which would +# otherwise cause the GC computation to produce erroneous results. You can +# therefore be confident as an end-user that these cases will be caught. +# +# Finally, when comparing GC scores across recordings, **it is highly +# recommended to estimate connectivity from the same number of channels (or +# equally from the same degree of rank subspace projection)** to avoid biases +# in connectivity estimates. 
Bias can be avoided by specifying a consistent +# rank subspace to project to using the ``rank`` argument, standardising your +# connectivity estimates regardless of changes in e.g. the number of channels +# across recordings. Note that this does not refer to the number of seeds and +# targets *within* a connection being identical, rather to the number of seeds +# and targets *across* connections. + + +############################################################################### +# References +# ---------- +# .. footbibliography:: diff --git a/examples/mic_mim.py b/examples/mic_mim.py new file mode 100644 index 00000000..179ea620 --- /dev/null +++ b/examples/mic_mim.py @@ -0,0 +1,432 @@ +""" +================================================================ +Compute multivariate measures of the imaginary part of coherency +================================================================ + +This example demonstrates how multivariate methods based on the imaginary part +of coherency :footcite:`EwaldEtAl2012` can be used to compute connectivity +between whole sets of sensors, and how spatial patterns of this connectivity +can be interpreted. + +The methods in question are: the maximised imaginary part of coherency (MIC); +and the multivariate interaction measure (MIM; as well as its extension, the +global interaction measure, GIM). +""" + +# Author: Thomas S. Binns +# License: BSD (3-clause) + +# %% + +import numpy as np +from matplotlib import pyplot as plt +from matplotlib import patheffects as pe + +import mne +from mne import EvokedArray, make_fixed_length_epochs +from mne.datasets.fieldtrip_cmc import data_path +from mne_connectivity import seed_target_indices, spectral_connectivity_epochs + +############################################################################### +# Background +# ---------- +# +# Multivariate forms of signal analysis allow you to simultaneously consider +# the activity of multiple signals. 
In the case of connectivity, the +# interaction between multiple sensors can be analysed at once, producing a +# single connectivity spectrum. This approach brings not only practical +# benefits (e.g. easier interpretability of results from the dimensionality +# reduction), but can also offer methodological improvements (e.g. enhanced +# signal-to-noise ratio and reduced bias). +# +# A popular bivariate measure of connectivity is the imaginary part of +# coherency, which looks at the correlation between two signals in the +# frequency domain and is immune to spurious connectivity arising from volume +# conduction artefacts :footcite:`NolteEtAl2004`. However, depending on the +# degree of source mixing, this measure is susceptible to biased estimates of +# connectivity based on the spatial proximity of sensors +# :footcite:`EwaldEtAl2012`. +# +# To overcome this limitation, spatial filters can be used to estimate +# connectivity free from this source mixing-dependent bias, which additionally +# increases the signal-to-noise ratio and allows signals to be analysed in a +# multivariate manner :footcite:`EwaldEtAl2012`. This approach leads to the +# following methods: the maximised imaginary part of coherency (MIC); and the +# multivariate interaction measure (MIM). +# +# We start by loading some example MEG data and dividing it into +# two-second-long epochs. + +# %% + +raw = mne.io.read_raw_ctf(data_path() / 'SubjectCMC.ds') +raw.pick('mag') +raw.crop(50., 110.).load_data() +raw.notch_filter(50) +raw.resample(100) + +epochs = make_fixed_length_epochs(raw, duration=2.0).load_data() + +############################################################################### +# We will focus on connectivity between sensors over the left and right +# hemispheres, with 75 sensors in the left hemisphere designated as seeds, and +# 75 sensors in the right hemisphere designated as targets. 
+ +# %% + +# left hemisphere sensors +seeds = [idx for idx, ch_info in enumerate(epochs.info['chs']) if + ch_info['loc'][0] < 0] +# right hemisphere sensors +targets = [idx for idx, ch_info in enumerate(epochs.info['chs']) if + ch_info['loc'][0] > 0] + +# XXX: Currently ragged indices are not supported, so we only consider a single +# list of indices with an equal number of seeds and targets +min_n_chs = min(len(seeds), len(targets)) +seeds = seeds[:min_n_chs] +targets = targets[:min_n_chs] + +multivar_indices = (np.array(seeds), np.array(targets)) + +seed_names = [epochs.info['ch_names'][idx] for idx in seeds] +target_names = [epochs.info['ch_names'][idx] for idx in targets] + +# multivariate imaginary part of coherency +(mic, mim) = spectral_connectivity_epochs( + epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, + rank=None) + +# bivariate imaginary part of coherency (for comparison) +bivar_indices = seed_target_indices(seeds, targets) +imcoh = spectral_connectivity_epochs( + epochs, method='imcoh', indices=bivar_indices, fmin=5, fmax=30) + +############################################################################### +# By averaging across each connection between the seeds and targets, we can see +# that the bivariate measure of the imaginary part of coherency estimates a +# strong peak in connectivity between seeds and targets around 13-18 Hz, with a +# weaker peak around 27 Hz. + +# %% +fig, axis = plt.subplots(1, 1) +axis.plot(imcoh.freqs, np.mean(np.abs(imcoh.get_data()), axis=0), + linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Absolute connectivity (A.U.)') +fig.suptitle('Imaginary part of coherency') + + +############################################################################### +# Maximised imaginary part of coherency (MIC) +# ------------------------------------------- +# +# For MIC, a set of spatial filters are found that will maximise the estimated +# connectivity between the seed and target signals. 
These maximising filters +# correspond to the eigenvectors with the largest eigenvalue, derived from an +# eigendecomposition of information from the cross-spectral density (Eq. 7 of +# :footcite:`EwaldEtAl2012`): +# +# :math:`MIC=\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}}{\parallel +# \boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta}\parallel}`, +# +# where :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are the +# spatial filters for the seeds and targets, respectively, and +# :math:`\boldsymbol{E}` is the imaginary part of the transformed +# cross-spectral density between the seeds and targets. All elements are +# frequency-dependent, however this is omitted for readability. MIC is bound +# between :math:`[-1, 1]` where the absolute value reflects connectivity +# strength and the sign reflects the phase angle difference between signals. +# +# MIC can also be computed between identical sets of seeds and targets, +# allowing connectivity within a single set of signals to be estimated. This is +# possible as a result of the exclusion of zero phase lag components from the +# connectivity estimates, which would otherwise return a perfect correlation. +# +# In this instance, we see MIC reveal that in addition to the 13-18 Hz peak, a +# previously unobserved peak in connectivity around 9 Hz is present. +# Furthermore, the previous peak around 27 Hz is much less pronounced. This may +# indicate that the connectivity was the result of some distal interaction +# exacerbated by strong source mixing, which biased the bivariate connectivity +# estimate. 
+ +# %% + +fig, axis = plt.subplots(1, 1) +axis.plot(mic.freqs, np.abs(mic.get_data()[0]), linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Absolute connectivity (A.U.)') +fig.suptitle('Maximised imaginary part of coherency') + + +############################################################################### +# Furthermore, spatial patterns of connectivity can be constructed from the +# spatial filters to give a picture of the location of the sources involved in +# the connectivity. This information is stored under ``attrs['patterns']`` of +# the connectivity class, with one value per frequency for each channel in the +# seeds and targets. As with MIC, the absolute value of the patterns reflect +# the strength, however the sign differences can be used to visualise the +# orientation of the underlying dipole sources. The spatial patterns are +# **not** bound between :math:`[-1, 1]`. +# +# Here, we average across the patterns in the 13-18 Hz range. Plotting the +# patterns shows that the greatest connectivity between the left and right +# hemispheres occurs at the posteromedial regions, based on the regions with +# the largest absolute values. Using the signs of the values, we can infer the +# existence of a dipole source in the central regions of the left hemisphere +# which may account for the connectivity contributions seen for the left +# posteromedial and frontolateral areas (represented on the plot as a green +# line). 
+ +# %% + +# compute average of patterns in desired frequency range +fband = [13, 18] +fband_idx = [mic.freqs.index(freq) for freq in fband] + +# patterns have shape [seeds/targets x cons x channels x freqs (x times)] +patterns = np.array(mic.attrs["patterns"]) +seed_pattern = patterns[0] +target_pattern = patterns[1] +# average across frequencies +seed_pattern = np.mean(seed_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], + axis=1) +target_pattern = np.mean(target_pattern[0, :, fband_idx[0]:fband_idx[1] + 1], + axis=1) + +# store the patterns for plotting +seed_info = epochs.copy().pick(seed_names).info +target_info = epochs.copy().pick(target_names).info +seed_pattern = EvokedArray(seed_pattern[:, np.newaxis], seed_info) +target_pattern = EvokedArray(target_pattern[:, np.newaxis], target_info) + +# plot the patterns +fig, axes = plt.subplots(1, 4) +seed_pattern.plot_topomap( + times=0, sensors='m.', units=dict(mag='A.U.'), cbar_fmt='%.1E', + axes=axes[0:2], time_format='', show=False) +target_pattern.plot_topomap( + times=0, sensors='m.', units=dict(mag='A.U.'), cbar_fmt='%.1E', + axes=axes[2:], time_format='', show=False) +axes[0].set_position((0.1, 0.1, 0.35, 0.7)) +axes[1].set_position((0.4, 0.3, 0.02, 0.3)) +axes[2].set_position((0.5, 0.1, 0.35, 0.7)) +axes[3].set_position((0.9, 0.3, 0.02, 0.3)) +axes[0].set_title('Seed spatial pattern\n13-18 Hz') +axes[2].set_title('Target spatial pattern\n13-18 Hz') + +# plot the left hemisphere dipole example +axes[0].plot( + [-0.1, -0.05], [-0.075, -0.03], color='lime', linewidth=2, + path_effects=[pe.Stroke(linewidth=4, foreground='k'), pe.Normal()]) + +plt.show() + + +############################################################################### +# Multivariate interaction measure (MIM) +# -------------------------------------- +# +# Although it can be useful to analyse the single, largest connectivity +# component with MIC, multiple such components exist and can be examined with +# MIM. 
MIM can be thought of as an average of all connectivity components +# between the seeds and targets, and can be useful for an exploration of all +# available components. It is unnecessary to use the spatial filters of each +# component explicitly, and instead the desired result can be achieved from +# :math:`E` alone (Eq. 14 of :footcite:`EwaldEtAl2012`): +# +# :math:`MIM=tr(\boldsymbol{EE}^T)`, +# +# where again the frequency dependence is omitted. Unlike MIC, MIM is +# positive-valued and can be > 1. Without normalisation, MIM can be +# thought of as reflecting the total interaction between the seeds and targets. +# MIM can be normalised to lie in the range :math:`[0, 1]` by dividing the +# scores by the number of unique channels in the seeds and targets. Normalised +# MIM represents the interaction *per channel*, which can be biased by factors +# such as the presence of channels with little to no interaction. In line with +# the preferences of the method's authors :footcite:`EwaldEtAl2012`, since +# normalisation alters the interpretability of the results, **normalisation is +# not performed by default**. +# +# Here we see MIM reveal the strongest connectivity component to be around 10 +# Hz, with the higher frequency 13-18 Hz connectivity no longer being so +# prominent. This suggests that, across all components in the data, there may +# be more lower frequency connectivity sources than higher frequency sources. +# Thus, when combining these different components in MIM, the peak around 10 Hz +# remains, but the 13-18 Hz connectivity is diminished relative to the single, +# largest connectivity component of MIC. +# +# Looking at the values for normalised MIM, we see it has a maximum of ~0.1. +# The relatively small connectivity values thus indicate that many of the +# channels show little to no interaction. 
+ +# %% + +fig, axis = plt.subplots(1, 1) +axis.plot(mim.freqs, mim.get_data()[0], linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Absolute connectivity (A.U.)') +fig.suptitle('Multivariate interaction measure') + +n_channels = len(np.unique([*multivar_indices[0], *multivar_indices[1]])) +normalised_mim = mim.get_data()[0] / n_channels +print(f'Normalised MIM has a maximum value of {normalised_mim.max():.2f}') + + +############################################################################### +# Additionally, the instance where the seeds and targets are identical can be +# considered as a special case of MIM: the global interaction measure (GIM; Eq. +# 15 of :footcite:`EwaldEtAl2012`). Again, this allows connectivity within a +# single set of signals to be estimated. Computing GIM follows from Eq. 14, +# however since each interaction is considered twice, correcting the +# connectivity by a factor of :math:`\frac{1}{2}` is necessary (**the +# correction is performed automatically in this implementation**). Like MIM, +# GIM can also be > 1, but it can again be normalised to lie in the range +# :math:`[0, 1]` by dividing by the number of unique channels in the seeds and +# targets. However, since normalisation alters the interpretability of the +# results (i.e. interaction per channel for normalised GIM vs. total +# interaction for standard GIM), **GIM is not normalised by default**. +# +# With GIM, we find a broad connectivity peak around 10 Hz, with an additional +# peak around 20 Hz. The differences observed with GIM highlight the presence +# of interactions within each hemisphere that are absent for MIC or MIM. +# Furthermore, the values for normalised GIM are higher than for MIM, with a +# maximum of ~0.2, again indicating the presence of interactions across +# channels within each hemisphere. 
+ +# %% + +indices = (np.array([*seeds, *targets]), np.array([*seeds, *targets])) +gim = spectral_connectivity_epochs( + epochs, method='mim', indices=indices, fmin=5, fmax=30, rank=None, + verbose=False) + +fig, axis = plt.subplots(1, 1) +axis.plot(gim.freqs, gim.get_data()[0], linewidth=2) +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Connectivity (A.U.)') +fig.suptitle('Global interaction measure') + +n_channels = len(np.unique([*indices[0], *indices[1]])) +normalised_gim = gim.get_data()[0] / n_channels +print(f'Normalised GIM has a maximum value of {normalised_gim.max():.2f}') + + +############################################################################### +# Handling high-dimensional data +# ------------------------------ +# +# An important issue to consider when using these multivariate methods is +# overfitting, which risks biasing connectivity estimates to maximise noise in +# the data. This risk can be reduced by performing a preliminary dimensionality +# reduction prior to estimating the connectivity with a singular value +# decomposition (Eqs. 32 & 33 of :footcite:`EwaldEtAl2012`). The degree of this +# dimensionality reduction can be specified using the ``rank`` argument, which +# by default will not perform any dimensionality reduction (assuming your data +# is full rank; see below if not). Choosing an expected rank of the data +# requires *a priori* knowledge about the number of components you expect to +# observe in the data. +# +# When comparing MIC/MIM scores across recordings, **it is highly recommended +# to estimate connectivity from the same number of channels (or equally from +# the same degree of rank subspace projection)** to avoid biases in +# connectivity estimates. Bias can be avoided by specifying a consistent rank +# subspace to project to using the ``rank`` argument, standardising your +# connectivity estimates regardless of changes in e.g. the number of channels +# across recordings. 
Note that this does not refer to the number of seeds and +# targets *within* a connection being identical, rather to the number of seeds +# and targets *across* connections. +# +# Here, we will project our seed and target data to only the first 25 +# components of our rank subspace. Results for MIM show that the general +# spectral pattern of connectivity is retained in the rank subspace-projected +# data, suggesting that a fair degree of redundant connectivity information is +# contained in the remaining 50 components of the seed and target data. We also +# assert that the spatial patterns of MIC are returned in the original sensor +# space despite this rank subspace projection, being reconstructed using the +# products of the singular value decomposition (Eqs. 46 & 47 of +# :footcite:`EwaldEtAl2012`). + +# %% + +(mic_red, mim_red) = spectral_connectivity_epochs( + epochs, method=['mic', 'mim'], indices=multivar_indices, fmin=5, fmax=30, + rank=([25], [25])) + +# subtract mean of scores for comparison +mim_red_meansub = mim_red.get_data()[0] - mim_red.get_data()[0].mean() +mim_meansub = mim.get_data()[0] - mim.get_data()[0].mean() + +# compare standard and rank subspace-projected MIM +fig, axis = plt.subplots(1, 1) +axis.plot(mim_red.freqs, mim_red_meansub, linewidth=2, + label='rank subspace (25) MIM') +axis.plot(mim.freqs, mim_meansub, linewidth=2, label='standard MIM') +axis.set_xlabel('Frequency (Hz)') +axis.set_ylabel('Mean-corrected connectivity (A.U.)') +axis.legend() +fig.suptitle('Multivariate interaction measure (non-normalised)') + +# no. 
channels equal with and without projecting to rank subspace for patterns +assert (patterns[0, 0].shape[0] == + np.array(mic_red.attrs["patterns"])[0, 0].shape[0]) +assert (patterns[1, 0].shape[0] == + np.array(mic_red.attrs["patterns"])[1, 0].shape[0]) + + +############################################################################### +# In the case that your data is not full rank and ``rank`` is left as ``None``, +# an automatic rank computation is performed and an appropriate degree of +# dimensionality reduction will be enforced. The rank of the data is determined +# by computing the singular values of the data and finding those within a +# factor of :math:`1e^{-10}` relative to the largest singular value. +# +# In some circumstances, this threshold may be too lenient, in which case you +# should inspect the singular values of your data to identify an appropriate +# degree of dimensionality reduction to perform, which you can then specify +# manually using the ``rank`` argument. The code below shows one possible +# approach for finding an appropriate rank of close-to-singular data with a +# more conservative threshold of :math:`1e^{-5}`. + +# %% + +# gets the singular values of the data +s = np.linalg.svd(raw.get_data(), compute_uv=False) +# finds how many singular values are "close" to the largest singular value +rank = np.count_nonzero(s >= s[0] * 1e-5) # 1e-5 is the "closeness" criteria + + +############################################################################### +# Limitations +# ----------- +# +# These multivariate methods offer many benefits in the form of dimensionality +# reduction, signal-to-noise ratio improvements, and invariance to +# estimate-biasing source mixing; however, no method is perfect. The immunity +# of the imaginary part of coherency to volume conduction comes from the fact +# that these artefacts have zero phase lag, and hence a zero-valued imaginary +# component. 
By projecting the complex-valued coherency to the imaginary axis, +# signals of a given magnitude with phase lag differences close to 90° or 270° +# see their contributions to the connectivity estimate increased relative to +# comparable signals with phase lag differences close to 0° or 180°. Therefore, +# the imaginary part of coherency is biased towards connectivity involving 90° +# and 270° phase lag difference components. +# +# Whilst this is not a limitation specific to the multivariate extension of +# this measure, these multivariate methods can introduce further bias: when +# maximising the imaginary part of coherency, components with phase lag +# differences close to 90° and 270° will likely give higher connectivity +# estimates, and so may be prioritised by the spatial filters. +# +# Such a limitation should be kept in mind when estimating connectivity using +# these methods. Possible sanity checks can involve comparing the spectral +# profiles of MIC/MIM to coherence and the imaginary part of coherency +# computed on the same data, as well as comparing to other multivariate +# measures, such as canonical coherence :footcite:`VidaurreEtAl2019`. + +############################################################################### +# References +# ---------- +# .. footbibliography:: + +# %% diff --git a/mne_connectivity/base.py b/mne_connectivity/base.py index 76ee01b0..88951529 100644 --- a/mne_connectivity/base.py +++ b/mne_connectivity/base.py @@ -667,7 +667,8 @@ def get_data(self, output='compact'): ``(n_nodes_in * n_nodes_out,)`` list. If 'dense', then will return each connectivity matrix as a 2D array. If 'compact' (default) then will return 'raveled' if ``indices`` were defined as - a list of tuples, or ``dense`` if indices is 'all'. + a list of tuples, or ``dense`` if indices is 'all'. Multivariate + connectivity data cannot be returned in a dense form. 
Returns ------- @@ -685,6 +686,14 @@ def get_data(self, output='compact'): if output == 'raveled': data = self._data else: + if self.method in ['mic', 'mim', 'gc', 'gc_tr']: + # multivariate results cannot be returned in a dense form as a + # single set of results would correspond to multiple entries in + # the matrix, and there could also be cases where multiple + # results correspond to the same entries in the matrix. + raise ValueError('cannot return multivariate connectivity ' + 'data in a dense form') + # get the new shape of the data array if self.is_epoched: new_shape = [self.n_epochs] diff --git a/mne_connectivity/spectral/epochs.py b/mne_connectivity/spectral/epochs.py index 1ee2d1a2..eb766f06 100644 --- a/mne_connectivity/spectral/epochs.py +++ b/mne_connectivity/spectral/epochs.py @@ -1,6 +1,9 @@ # Authors: Martin Luessi # Denis A. Engemann # Adam Li +# Thomas S. Binns +# Tien D. Nguyen +# Richard M. Köhler # # License: BSD (3-clause) @@ -8,6 +11,7 @@ import inspect import numpy as np +import scipy as sp from mne.epochs import BaseEpochs from mne.parallel import parallel_func from mne.source_estimate import _BaseSourceEstimate @@ -16,8 +20,8 @@ _psd_from_mt_adaptive) from mne.time_frequency.tfr import cwt, morlet from mne.time_frequency.multitaper import _compute_mt_params -from mne.utils import (_arange_div, _check_option, logger, warn, _time_mask, - verbose) +from mne.utils import ( + ProgressBar, _arange_div, _check_option, _time_mask, logger, warn, verbose) from ..base import (SpectralConnectivity, SpectroTemporalConnectivity) from ..utils import fill_doc, check_indices @@ -61,7 +65,7 @@ def _compute_freq_mask(freqs_all, fmin, fmax, fskip): def _prepare_connectivity(epoch_block, times_in, tmin, tmax, fmin, fmax, sfreq, indices, - mode, fskip, n_bands, + method, mode, fskip, n_bands, cwt_freqs, faverage): """Check and precompute dimensions of results data.""" first_epoch = epoch_block[0] @@ -89,14 +93,39 @@ def _prepare_connectivity(epoch_block, 
times_in, tmin, tmax, n_times = len(times) if indices is None: - logger.info('only using indices for lower-triangular matrix') - # only compute r for lower-triangular region - indices_use = np.tril_indices(n_signals, -1) + if any(this_method in _multivariate_methods for this_method in method): + if any(this_method in _gc_methods for this_method in method): + raise ValueError( + 'indices must be specified when computing Granger ' + 'causality, as all-to-all connectivity is not supported') + else: + logger.info('using all indices for multivariate connectivity') + indices_use = (np.arange(n_signals, dtype=int), + np.arange(n_signals, dtype=int)) + else: + logger.info('only using indices for lower-triangular matrix') + # only compute r for lower-triangular region + indices_use = np.tril_indices(n_signals, -1) else: + if any(this_method in _gc_methods for this_method in method): + if set(indices[0]).intersection(indices[1]): + raise ValueError( + 'seed and target indices must not intersect when computing' + 'Granger causality') indices_use = check_indices(indices) # number of connectivities to compute - n_cons = len(indices_use[0]) + if any(this_method in _multivariate_methods for this_method in method): + if ( + len(np.unique(indices_use[0])) != len(indices_use[0]) or + len(np.unique(indices_use[1])) != len(indices_use[1]) + ): + raise ValueError( + 'seed and target indices cannot contain repeated channels for ' + 'multivariate connectivity') + n_cons = 1 # UNTIL RAGGED ARRAYS SUPPORTED + else: + n_cons = len(indices_use[0]) logger.info(' computing connectivity for %d connections' % n_cons) @@ -222,6 +251,8 @@ def compute_con(self, con_idx, n_epochs): class _EpochMeanConEstBase(_AbstractConEstBase): """Base class for methods that estimate connectivity as mean epoch-wise.""" + patterns = None + def __init__(self, n_cons, n_freqs, n_times): self.n_cons = n_cons self.n_freqs = n_freqs @@ -243,9 +274,76 @@ def combine(self, other): self._acc += other._acc +class 
class _EpochMeanMultivariateConEstBase(_AbstractConEstBase):
    """Base class for mean epoch-wise multivar. con. estimation methods.

    Subclasses accumulate the cross-spectral density (CSD) over epochs and
    derive their connectivity scores from the epoch-mean CSD.
    """

    n_steps = None  # number of frequency blocks shown in the progress bar
    patterns = None  # spatial patterns; only filled by methods that have them

    def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1):
        self.n_signals = n_signals
        self.n_cons = n_cons
        self.n_freqs = n_freqs
        self.n_times = n_times
        self.n_jobs = n_jobs

        # include time dimension, even when unused for indexing flexibility
        if n_times == 0:
            self.csd_shape = (n_signals**2, n_freqs)
            self.con_scores = np.zeros((n_cons, n_freqs, 1))
        else:
            self.csd_shape = (n_signals**2, n_freqs, n_times)
            self.con_scores = np.zeros((n_cons, n_freqs, n_times))

        # allocate space for accumulation of CSD
        self._acc = np.zeros(self.csd_shape, dtype=np.complex128)

        self._compute_n_progress_bar_steps()

    def start_epoch(self):  # noqa: D401
        """Called at the start of each epoch."""
        pass  # for this type of con. method we don't do anything

    def combine(self, other):
        """Include con. accumulated for some epochs in this estimate."""
        self._acc += other._acc

    def accumulate(self, con_idx, csd_xy):
        """Accumulate CSD for some connections."""
        self._acc[con_idx] += csd_xy

    def _compute_n_progress_bar_steps(self):
        """Calculate the number of steps to include in the progress bar."""
        self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs))

    def _log_connection_number(self, con_i):
        """Log the number of the connection being computed."""
        logger.info('Computing %s for connection %i of %i'
                    % (self.name, con_i + 1, self.n_cons, ))

    def _get_block_indices(self, block_i, limit):
        """Get indices for a computation block capped by a limit."""
        indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs)

        return indices[np.nonzero(indices < limit)]

    def reshape_csd(self):
        """Reshape CSD into a matrix of times x freqs x signals x signals."""
        if self.n_times == 0:
            return (np.reshape(self._acc, (
                self.n_signals, self.n_signals, self.n_freqs, 1)
            ).transpose(3, 2, 0, 1))

        return (np.reshape(self._acc, (
            self.n_signals, self.n_signals, self.n_freqs, self.n_times)
        ).transpose(3, 2, 0, 1))


class _MultivariateCohEstBase(_EpochMeanMultivariateConEstBase):
    """Base estimator for multivariate imag. part of coherency methods.

    See Ewald et al. (2012). NeuroImage. DOI: 10.1016/j.neuroimage.2011.11.084
    for equation references.
    """

    name = None
    accumulate_psd = False

    def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1):
        super(_MultivariateCohEstBase, self).__init__(
            n_signals, n_cons, n_freqs, n_times, n_jobs)

    def compute_con(self, indices, ranks, n_epochs=1):
        """Compute multivariate imag. part of coherency between signals."""
        assert self.name in ['MIC', 'MIM'], (
            'the class name is not recognised, please contact the '
            'mne-connectivity developers')

        csd = self.reshape_csd() / n_epochs
        n_times = csd.shape[0]
        times = np.arange(n_times)
        freqs = np.arange(self.n_freqs)

        if self.name == 'MIC':
            self.patterns = np.full(
                (2, self.n_cons, len(indices[0]), self.n_freqs, n_times),
                np.nan)

        con_i = 0
        for seed_idcs, target_idcs, seed_rank, target_rank in zip(
                [indices[0]], [indices[1]], ranks[0], ranks[1]):
            self._log_connection_number(con_i)

            n_seeds = len(seed_idcs)
            con_idcs = [*seed_idcs, *target_idcs]

            C = csd[np.ix_(times, freqs, con_idcs, con_idcs)]

            # Eqs. 32 & 33
            C_bar, U_bar_aa, U_bar_bb = self._csd_svd(
                C, n_seeds, seed_rank, target_rank)

            # Eqs. 3 & 4
            E = self._compute_e(C_bar, n_seeds=U_bar_aa.shape[3])

            if self.name == 'MIC':
                self._compute_mic(E, C, seed_idcs, target_idcs, n_times,
                                  U_bar_aa, U_bar_bb, con_i)
            else:
                self._compute_mim(E, seed_idcs, target_idcs, con_i)

            con_i += 1

        self.reshape_results()

    def _csd_svd(self, csd, n_seeds, seed_rank, target_rank):
        """Dimensionality reduction of CSD with SVD."""
        n_times = csd.shape[0]
        n_targets = csd.shape[2] - n_seeds

        C_aa = csd[..., :n_seeds, :n_seeds]
        C_ab = csd[..., :n_seeds, n_seeds:]
        C_bb = csd[..., n_seeds:, n_seeds:]
        C_ba = csd[..., n_seeds:, :n_seeds]

        # Eq. 32
        if seed_rank != n_seeds:
            U_aa = np.linalg.svd(np.real(C_aa), full_matrices=False)[0]
            U_bar_aa = U_aa[..., :seed_rank]
        else:
            U_bar_aa = np.broadcast_to(
                np.identity(n_seeds),
                (n_times, self.n_freqs) + (n_seeds, n_seeds))

        if target_rank != n_targets:
            U_bb = np.linalg.svd(np.real(C_bb), full_matrices=False)[0]
            U_bar_bb = U_bb[..., :target_rank]
        else:
            U_bar_bb = np.broadcast_to(
                np.identity(n_targets),
                (n_times, self.n_freqs) + (n_targets, n_targets))

        # Eq. 33
        C_bar_aa = np.matmul(
            U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_aa, U_bar_aa))
        C_bar_ab = np.matmul(
            U_bar_aa.transpose(0, 1, 3, 2), np.matmul(C_ab, U_bar_bb))
        C_bar_bb = np.matmul(
            U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_bb, U_bar_bb))
        C_bar_ba = np.matmul(
            U_bar_bb.transpose(0, 1, 3, 2), np.matmul(C_ba, U_bar_aa))
        C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3),
                          np.append(C_bar_ba, C_bar_bb, axis=3), axis=2)

        return C_bar, U_bar_aa, U_bar_bb

    def _compute_e(self, csd, n_seeds):
        """Compute E from the CSD."""
        C_r = np.real(csd)

        parallel, parallel_compute_t, _ = parallel_func(
            _mic_mim_compute_t, self.n_jobs, verbose=False)

        # imag. part of T filled when data is rank-deficient
        T = np.zeros(csd.shape, dtype=np.complex128)
        for block_i in ProgressBar(
                range(self.n_steps), mesg="frequency blocks"):
            freqs = self._get_block_indices(block_i, self.n_freqs)
            parallel(parallel_compute_t(
                C_r[:, f], T[:, f], n_seeds) for f in freqs)

        if not np.isreal(T).all() or not np.isfinite(T).all():
            raise RuntimeError(
                'the transformation matrix of the data must be real-valued '
                'and contain no NaN or infinity values; check that you are '
                'using full rank data or specify an appropriate rank for the '
                'seeds and targets that is less than or equal to their ranks')
        T = np.real(T)  # make T real if check passes

        # Eq. 4
        D = np.matmul(T, np.matmul(csd, T))

        # E as imag. part of D between seeds and targets
        return np.imag(D[..., :n_seeds, n_seeds:])

    def _compute_mic(self, E, C, seed_idcs, target_idcs, n_times, U_bar_aa,
                     U_bar_bb, con_i):
        """Compute MIC and the associated spatial patterns."""
        n_seeds = len(seed_idcs)
        times = np.arange(n_times)
        freqs = np.arange(self.n_freqs)

        # Eigendecomp. to find spatial filters for seeds and targets
        w_seeds, V_seeds = np.linalg.eigh(
            np.matmul(E, E.transpose(0, 1, 3, 2)))
        w_targets, V_targets = np.linalg.eigh(
            np.matmul(E.transpose(0, 1, 3, 2), E))
        if np.all(seed_idcs == target_idcs):
            # strange edge-case where the eigenvectors returned should be a set
            # of identity matrices with one rotated by 90 degrees, but are
            # instead identical (i.e. are not rotated versions of one another).
            # This leads to the case where the spatial filters are incorrectly
            # applied, resulting in connectivity estimates of e.g. ~0 when they
            # should be perfectly correlated ~1. Accordingly, we manually
            # create a set of rotated identity matrices to use as the filters.
            create_filter = False
            stop = False
            while not create_filter and not stop:
                for time_i in range(n_times):
                    for freq_i in range(self.n_freqs):
                        if np.all(V_seeds[time_i, freq_i] ==
                                  V_targets[time_i, freq_i]):
                            create_filter = True
                            break
                stop = True
            if create_filter:
                n_chans = E.shape[2]
                eye_4d = np.zeros_like(V_seeds)
                eye_4d[:, :, np.arange(n_chans), np.arange(n_chans)] = 1
                V_seeds = eye_4d
                V_targets = np.rot90(eye_4d, axes=(2, 3))

        # Spatial filters with largest eigval. for seeds and targets
        alpha = V_seeds[times[:, None], freqs, :, w_seeds.argmax(axis=2)]
        beta = V_targets[times[:, None], freqs, :, w_targets.argmax(axis=2)]

        # Eq. 46 (seed spatial patterns)
        self.patterns[0, con_i] = (np.matmul(
            np.real(C[..., :n_seeds, :n_seeds]),
            np.matmul(U_bar_aa, np.expand_dims(alpha, axis=3))))[..., 0].T

        # Eq. 47 (target spatial patterns)
        self.patterns[1, con_i] = (np.matmul(
            np.real(C[..., n_seeds:, n_seeds:]),
            np.matmul(U_bar_bb, np.expand_dims(beta, axis=3))))[..., 0].T

        # Eq. 7: alpha^T E beta normalised by ||alpha|| * ||beta||.
        # NOTE: the score must be divided by BOTH filter norms; the previous
        # form divided by ||alpha|| but, due to operator precedence,
        # *multiplied* by ||beta||.
        self.con_scores[con_i] = (np.einsum(
            'ijk,ijk->ij', alpha, np.matmul(E, np.expand_dims(
                beta, axis=3))[..., 0]
        ) / (np.linalg.norm(alpha, axis=2) *
             np.linalg.norm(beta, axis=2))).T

    def _compute_mim(self, E, seed_idcs, target_idcs, con_i):
        """Compute MIM (a.k.a. GIM if seeds == targets)."""
        # Eq. 14
        self.con_scores[con_i] = np.matmul(
            E, E.transpose(0, 1, 3, 2)).trace(axis1=2, axis2=3).T

        # Eq. 15
        if all(np.unique(seed_idcs) == np.unique(target_idcs)):
            self.con_scores[con_i] *= 0.5

    def reshape_results(self):
        """Remove time dimension from results, if necessary."""
        if self.n_times == 0:
            self.con_scores = self.con_scores[..., 0]
            if self.patterns is not None:
                self.patterns = self.patterns[..., 0]


def _mic_mim_compute_t(C, T, n_seeds):
    """Compute T in place for a single frequency (used for MIC and MIM)."""
    for time_i in range(C.shape[0]):
        T[time_i, :n_seeds, :n_seeds] = sp.linalg.fractional_matrix_power(
            C[time_i, :n_seeds, :n_seeds], -0.5
        )
        T[time_i, n_seeds:, n_seeds:] = sp.linalg.fractional_matrix_power(
            C[time_i, n_seeds:, n_seeds:], -0.5
        )


class _MICEst(_MultivariateCohEstBase):
    """Multivariate imaginary part of coherency (MIC) estimator."""

    name = "MIC"


class _MIMEst(_MultivariateCohEstBase):
    """Multivariate interaction measure (MIM) estimator."""

    name = "MIM"
class _GCEstBase(_EpochMeanMultivariateConEstBase):
    """Base multivariate state-space Granger causality estimator.

    See Barnett & Seth (2015). Physical Review E. DOI:
    10.1103/PhysRevE.91.040101 for equation references.
    """

    accumulate_psd = False

    def __init__(self, n_signals, n_cons, n_freqs, n_times, n_lags, n_jobs=1):
        super(_GCEstBase, self).__init__(
            n_signals, n_cons, n_freqs, n_times, n_jobs)

        self.freq_res = (self.n_freqs - 1) * 2
        if n_lags >= self.freq_res:
            raise ValueError(
                'the number of lags (%i) must be less than double the '
                'frequency resolution (%i)' % (n_lags, self.freq_res, ))
        self.n_lags = n_lags

    def compute_con(self, indices, ranks, n_epochs=1):
        """Compute multivariate state-space Granger causality."""
        assert self.name in ['GC', 'GC time-reversed'], (
            'the class name is not recognised, please contact the '
            'mne-connectivity developers')

        csd = self.reshape_csd() / n_epochs

        n_times = csd.shape[0]
        times = np.arange(n_times)
        freqs = np.arange(self.n_freqs)

        con_i = 0
        for seed_idcs, target_idcs, seed_rank, target_rank in zip(
                [indices[0]], [indices[1]], ranks[0], ranks[1]):
            self._log_connection_number(con_i)

            con_idcs = [*seed_idcs, *target_idcs]
            C = csd[np.ix_(times, freqs, con_idcs, con_idcs)]

            con_seeds = np.arange(len(seed_idcs))
            con_targets = np.arange(len(target_idcs)) + len(seed_idcs)

            C_bar = self._csd_svd(
                C, con_seeds, con_targets, seed_rank, target_rank)
            n_signals = seed_rank + target_rank
            con_seeds = np.arange(seed_rank)
            con_targets = np.arange(target_rank) + seed_rank

            autocov = self._compute_autocov(C_bar)
            if self.name == "GC time-reversed":
                # time-reversal == transposition of the autocov. sequence
                autocov = autocov.transpose(0, 1, 3, 2)

            A_f, V = self._autocov_to_full_var(autocov)
            A_f_3d = np.reshape(
                A_f, (n_times, n_signals, n_signals * self.n_lags),
                order="F")
            A, K = self._full_var_to_iss(A_f_3d)

            self.con_scores[con_i] = self._iss_to_ugc(
                A, A_f_3d, K, V, con_seeds, con_targets)

            con_i += 1

        self.reshape_results()

    def _csd_svd(self, csd, seeds, targets, seed_rank, target_rank):
        """Dimensionality reduction of CSD with SVD on the covariance."""
        # sum over times and epochs to get cov. from CSD
        cov = csd.sum(axis=(0, 1))

        n_seeds = len(seeds)
        n_targets = len(targets)

        cov_aa = cov[:n_seeds, :n_seeds]
        cov_bb = cov[n_seeds:, n_seeds:]

        if seed_rank != n_seeds:
            U_aa = np.linalg.svd(np.real(cov_aa), full_matrices=False)[0]
            U_bar_aa = U_aa[:, :seed_rank]
        else:
            U_bar_aa = np.identity(n_seeds)

        if target_rank != n_targets:
            U_bb = np.linalg.svd(np.real(cov_bb), full_matrices=False)[0]
            U_bar_bb = U_bb[:, :target_rank]
        else:
            U_bar_bb = np.identity(n_targets)

        C_aa = csd[..., :n_seeds, :n_seeds]
        C_ab = csd[..., :n_seeds, n_seeds:]
        C_bb = csd[..., n_seeds:, n_seeds:]
        C_ba = csd[..., n_seeds:, :n_seeds]

        C_bar_aa = np.matmul(
            U_bar_aa.transpose(1, 0), np.matmul(C_aa, U_bar_aa))
        C_bar_ab = np.matmul(
            U_bar_aa.transpose(1, 0), np.matmul(C_ab, U_bar_bb))
        C_bar_bb = np.matmul(
            U_bar_bb.transpose(1, 0), np.matmul(C_bb, U_bar_bb))
        C_bar_ba = np.matmul(
            U_bar_bb.transpose(1, 0), np.matmul(C_ba, U_bar_aa))
        C_bar = np.append(np.append(C_bar_aa, C_bar_ab, axis=3),
                          np.append(C_bar_ba, C_bar_bb, axis=3), axis=2)

        return C_bar

    def _compute_autocov(self, csd):
        """Compute autocovariance from the CSD."""
        n_times = csd.shape[0]
        n_signals = csd.shape[2]

        circular_shifted_csd = np.concatenate(
            [np.flip(np.conj(csd[:, 1:]), axis=1), csd[:, :-1]], axis=1)
        ifft_shifted_csd = self._block_ifft(
            circular_shifted_csd, self.freq_res)
        lags_ifft_shifted_csd = np.reshape(
            ifft_shifted_csd[:, :self.n_lags + 1],
            (n_times, self.n_lags + 1, n_signals ** 2), order="F")

        # alternate signs of the odd lags
        signs = np.repeat([1], self.n_lags + 1).tolist()
        signs[1::2] = [x * -1 for x in signs[1::2]]
        sign_matrix = np.repeat(
            np.tile(np.array(signs), (n_signals ** 2, 1))[np.newaxis],
            n_times, axis=0).transpose(0, 2, 1)

        return np.real(np.reshape(
            sign_matrix * lags_ifft_shifted_csd,
            (n_times, self.n_lags + 1, n_signals, n_signals), order="F"))

    def _block_ifft(self, csd, n_points):
        """Compute block iFFT with n points."""
        shape = csd.shape
        csd_3d = np.reshape(
            csd, (shape[0], shape[1], shape[2] * shape[3]), order="F")

        csd_ifft = np.fft.ifft(csd_3d, n=n_points, axis=1)

        return np.reshape(csd_ifft, shape, order="F")

    def _autocov_to_full_var(self, autocov):
        """Compute full VAR model using Whittle's LWR recursion."""
        if np.any(np.linalg.det(autocov) == 0):
            raise RuntimeError(
                'the autocovariance matrix is singular; check if your data is '
                'rank deficient and specify an appropriate rank argument <= '
                'the rank of the seeds and targets')

        A_f, V = self._whittle_lwr_recursion(autocov)

        if not np.isfinite(A_f).all():
            raise RuntimeError('at least one VAR model coefficient is '
                               'infinite or NaN; check the data you are using')

        try:
            np.linalg.cholesky(V)
        except np.linalg.LinAlgError as np_error:
            raise RuntimeError(
                'the covariance matrix of the residuals is not '
                'positive-definite; check the singular values of your data '
                'and specify an appropriate rank argument <= the rank of the '
                'seeds and targets') from np_error

        return A_f, V

    def _whittle_lwr_recursion(self, G):
        """Solve Yule-Walker eqs. for full VAR params. with LWR recursion.

        See: Whittle P., 1963. Biometrika, DOI: 10.1093/biomet/50.1-2.129
        """
        # Initialise recursion
        n = G.shape[2]  # number of signals
        q = G.shape[1] - 1  # number of lags
        t = G.shape[0]  # number of times
        qn = n * q

        cov = G[:, 0, :, :]  # covariance
        G_f = np.reshape(
            G[:, 1:, :, :].transpose(0, 3, 1, 2), (t, qn, n),
            order="F")  # forward autocov
        G_b = np.reshape(
            np.flip(G[:, 1:, :, :], 1).transpose(0, 3, 2, 1), (t, n, qn),
            order="F").transpose(0, 2, 1)  # backward autocov

        A_f = np.zeros((t, n, qn))  # forward coefficients
        A_b = np.zeros((t, n, qn))  # backward coefficients

        k = 1  # model order
        r = q - k
        k_f = np.arange(k * n)  # forward indices
        k_b = np.arange(r * n, qn)  # backward indices

        try:
            A_f[:, :, k_f] = np.linalg.solve(
                cov, G_b[:, k_b, :].transpose(0, 2, 1)).transpose(0, 2, 1)
            A_b[:, :, k_b] = np.linalg.solve(
                cov, G_f[:, k_f, :].transpose(0, 2, 1)).transpose(0, 2, 1)

            # Perform recursion
            for k in np.arange(2, q + 1):
                var_A = (G_b[:, (r - 1) * n: r * n, :] -
                         np.matmul(A_f[:, :, k_f], G_b[:, k_b, :]))
                var_B = cov - np.matmul(A_b[:, :, k_b], G_b[:, k_b, :])
                AA_f = np.linalg.solve(
                    var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1)

                var_A = (G_f[:, (k - 1) * n: k * n, :] -
                         np.matmul(A_b[:, :, k_b], G_f[:, k_f, :]))
                var_B = cov - np.matmul(A_f[:, :, k_f], G_f[:, k_f, :])
                AA_b = np.linalg.solve(
                    var_B, var_A.transpose(0, 2, 1)).transpose(0, 2, 1)

                A_f_previous = A_f[:, :, k_f]
                A_b_previous = A_b[:, :, k_b]

                r = q - k
                k_f = np.arange(k * n)
                k_b = np.arange(r * n, qn)

                A_f[:, :, k_f] = np.dstack(
                    (A_f_previous - np.matmul(AA_f, A_b_previous), AA_f))
                A_b[:, :, k_b] = np.dstack(
                    (AA_b, A_b_previous - np.matmul(AA_b, A_f_previous)))
        except np.linalg.LinAlgError as np_error:
            raise RuntimeError(
                'the autocovariance matrix is singular; check if your data is '
                'rank deficient and specify an appropriate rank argument <= '
                'the rank of the seeds and targets') from np_error

        V = cov - np.matmul(A_f, G_f)
        A_f = np.reshape(A_f, (t, n, n, q), order="F")

        return A_f, V

    def _full_var_to_iss(self, A_f):
        """Compute innovations-form parameters for a state-space model.

        Parameters computed from a full VAR model using Aoki's method. For a
        non-moving-average full VAR model, the state-space parameter C
        (observation matrix) is identical to AF of the VAR model.

        See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI:
        10.1103/PhysRevE.91.040101.
        """
        t = A_f.shape[0]
        m = A_f.shape[1]  # number of signals
        p = A_f.shape[2] // m  # number of autoregressive lags

        I_p = np.dstack(t * [np.eye(m * p)]).transpose(2, 0, 1)
        A = np.hstack((A_f, I_p[:, : (m * p - m), :]))  # state transition
        # matrix
        K = np.hstack((
            np.dstack(t * [np.eye(m)]).transpose(2, 0, 1),
            np.zeros((t, (m * (p - 1)), m))))  # Kalman gain matrix

        return A, K

    def _iss_to_ugc(self, A, C, K, V, seeds, targets):
        """Compute unconditional GC from innovations-form state-space params.

        See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI:
        10.1103/PhysRevE.91.040101.
        """
        times = np.arange(A.shape[0])
        freqs = np.arange(self.n_freqs)
        z = np.exp(-1j * np.pi * np.linspace(0, 1, self.n_freqs))  # points
        # on a unit circle in the complex plane, one for each frequency

        H = self._iss_to_tf(A, C, K, z)  # spectral transfer function
        V_22_1 = np.linalg.cholesky(self._partial_covar(V, seeds, targets))
        HV = np.matmul(H, np.linalg.cholesky(V))
        S = np.matmul(HV, HV.conj().transpose(0, 1, 3, 2))  # Eq. 6
        S_11 = S[np.ix_(freqs, times, targets, targets)]
        HV_12 = np.matmul(H[np.ix_(freqs, times, targets, seeds)], V_22_1)
        HVH = np.matmul(HV_12, HV_12.conj().transpose(0, 1, 3, 2))

        # Eq. 11
        return np.real(
            np.log(np.linalg.det(S_11)) - np.log(np.linalg.det(S_11 - HVH)))

    def _iss_to_tf(self, A, C, K, z):
        """Compute transfer function for innovations-form state-space params.

        In the frequency domain, the back-shift operator, z, is a vector of
        points on a unit circle in the complex plane. z = e^-iw, where -pi < w
        <= pi.

        A note on efficiency: solving over the 4D time-freq. tensor is slower
        than looping over times and freqs when n_times and n_freqs high, and
        when n_times and n_freqs low, looping over times and freqs very fast
        anyway (plus tensor solving doesn't allow for parallelisation).

        See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI:
        10.1103/PhysRevE.91.040101.
        """
        t = A.shape[0]
        h = self.n_freqs
        n = C.shape[1]
        m = A.shape[1]
        I_n = np.eye(n)
        I_m = np.eye(m)

        parallel, parallel_compute_H, _ = parallel_func(
            _gc_compute_H, self.n_jobs, verbose=False
        )
        # single allocation of H (previously allocated twice in a row)
        H = np.zeros((h, t, n, n), dtype=np.complex128)
        for block_i in ProgressBar(
                range(self.n_steps), mesg="frequency blocks"
        ):
            freqs = self._get_block_indices(block_i, self.n_freqs)
            H[freqs] = parallel(
                parallel_compute_H(A, C, K, z[k], I_n, I_m) for k in freqs)

        return H

    def _partial_covar(self, V, seeds, targets):
        """Compute partial covariance of a matrix.

        Given a covariance matrix V, the partial covariance matrix of V
        between indices i and j, given k (V_ij|k), is equivalent to V_ij -
        V_ik * V_kk^-1 * V_kj. In this case, i and j are seeds, and k are
        targets.

        See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI:
        10.1103/PhysRevE.91.040101.
        """
        times = np.arange(V.shape[0])
        W = np.linalg.solve(
            np.linalg.cholesky(V[np.ix_(times, targets, targets)]),
            V[np.ix_(times, targets, seeds)],
        )
        W = np.matmul(W.transpose(0, 2, 1), W)

        return V[np.ix_(times, seeds, seeds)] - W

    def reshape_results(self):
        """Remove time dimension from con. scores, if necessary."""
        if self.n_times == 0:
            self.con_scores = self.con_scores[:, :, 0]


def _gc_compute_H(A, C, K, z_k, I_n, I_m):
    """Compute transfer function for innovations-form state-space params.

    See: Barnett, L. & Seth, A.K., 2015, Physical Review, DOI:
    10.1103/PhysRevE.91.040101, Eq. 4.
    """
    # local import so the function is self-contained when dispatched to
    # parallel workers; NOTE(review): presumably needed for joblib
    # serialisation -- confirm before removing
    from scipy import linalg
    H = np.zeros((A.shape[0], C.shape[1], C.shape[1]), dtype=np.complex128)
    for t in range(A.shape[0]):
        H[t] = I_n + np.matmul(
            C[t], linalg.lu_solve(linalg.lu_factor(z_k * I_m - A[t]), K[t]))

    return H


class _GCEst(_GCEstBase):
    """[seeds -> targets] state-space GC estimator."""

    name = "GC"


class _GCTREst(_GCEstBase):
    """time-reversed[seeds -> targets] state-space GC estimator."""

    name = "GC time-reversed"

###############################################################################


_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr']
_gc_methods = ['gc', 'gc_tr']
for mtype in con_method_types] + con_methods = [] + for mtype in con_method_types: + method_params = list(inspect.signature(mtype).parameters) + if "n_signals" in method_params: + # if it's a multivariate connectivity method + if "n_lags" in method_params: + # if it's a Granger causality method + con_methods.append( + mtype(n_signals_use, n_cons, n_freqs, n_times_spectrum, + gc_n_lags) + ) + else: + # if it's a coherence method + con_methods.append( + mtype(n_signals_use, n_cons, n_freqs, n_times_spectrum) + ) + else: + con_methods.append(mtype(n_cons, n_freqs, n_times_spectrum)) _check_option('mode', mode, ('cwt_morlet', 'multitaper', 'fourier')) if len(sig_idx) == n_signals: @@ -624,8 +1345,9 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, # accumulate connectivity scores if mode in ['multitaper', 'fourier']: - for i in range(0, n_cons, block_size): - con_idx = slice(i, i + block_size) + for i in range(0, n_con_signals, block_size): + n_extra = max(0, i + block_size - n_con_signals) + con_idx = slice(i, i + block_size - n_extra) if mt_adaptive: csd = _csd_from_mt(x_t[idx_map[0][con_idx]], x_t[idx_map[1][con_idx]], @@ -639,8 +1361,9 @@ def _epoch_spectral_connectivity(data, sig_idx, tmin_idx, tmax_idx, sfreq, for method in con_methods: method.accumulate(con_idx, csd) else: # mode == 'cwt_morlet' # reminder to add alternative TFR methods - for i_block, i in enumerate(range(0, n_cons, block_size)): - con_idx = slice(i, i + block_size) + for i in range(0, n_con_signals, block_size): + n_extra = max(0, i + block_size - n_con_signals) + con_idx = slice(i, i + block_size - n_extra) # this codes can be very slow csd = (x_t[idx_map[0][con_idx]] * x_t[idx_map[1][con_idx]].conjugate()) @@ -727,10 +1450,11 @@ def _get_and_verify_data_sizes(data, sfreq, n_signals=None, n_times=None, 'plv': _PLVEst, 'ciplv': _ciPLVEst, 'ppc': _PPCEst, 'pli': _PLIEst, 'pli2_unbiased': _PLIUnbiasedEst, 'dpli': _DPLIEst, 'wpli': _WPLIEst, - 'wpli2_debiased': 
_WPLIDebiasedEst} + 'wpli2_debiased': _WPLIDebiasedEst, 'mic': _MICEst, + 'mim': _MIMEst, 'gc': _GCEst, 'gc_tr': _GCTREst} -def _check_estimators(method, mode): +def _check_estimators(method): """Check construction of connectivity estimators.""" n_methods = len(method) con_method_types = list() @@ -748,30 +1472,24 @@ def _check_estimators(method, mode): 'not have the method %s' % msg) con_method_types.append(this_method) - # determine how many arguments the compute_con_function needs - n_comp_args = [len(inspect.signature(mtype.compute_con).parameters) - for mtype in con_method_types] - - # we currently only support 3 arguments - if any(n not in (3, 5) for n in n_comp_args): - raise ValueError('The .compute_con method needs to have either ' - '3 or 5 arguments') # if none of the comp_con functions needs the PSD, we don't estimate it - accumulate_psd = any(n == 5 for n in n_comp_args) - return con_method_types, n_methods, accumulate_psd, n_comp_args + accumulate_psd = any( + this_method.accumulate_psd for this_method in con_method_types) + return con_method_types, n_methods, accumulate_psd -@verbose -@fill_doc + +@ verbose +@ fill_doc def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, sfreq=None, mode='multitaper', fmin=None, fmax=np.inf, fskip=0, faverage=False, tmin=None, tmax=None, mt_bandwidth=None, mt_adaptive=False, mt_low_bias=True, cwt_freqs=None, - cwt_n_cycles=7, block_size=1000, n_jobs=1, - verbose=None): - """Compute frequency- and time-frequency-domain connectivity measures. + cwt_n_cycles=7, gc_n_lags=40, rank=None, + block_size=1000, n_jobs=1, verbose=None): + r"""Compute frequency- and time-frequency-domain connectivity measures. The connectivity method(s) are specified using the "method" parameter. 
All methods are based on estimates of the cross- and power spectral @@ -790,11 +1508,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, %(names)s method : str | list of str Connectivity measure(s) to compute. These can be ``['coh', 'cohy', - 'imcoh', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', - 'wpli2_debiased']``. + 'imcoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli', 'wpli', + 'wpli2_debiased', 'gc', 'gc_tr']``. Multivariate methods (``['mic', + 'mim', 'gc', 'gc_tr]``) cannot be called with the other methods. indices : tuple of array | None Two arrays with indices of connections for which to compute - connectivity. If None, all connections are computed. + connectivity. If a multivariate method is called, the indices are for a + single connection between all seeds and all targets. If None, all + connections are computed, unless a Granger causality method is called, + in which case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -804,8 +1526,6 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, fmin : float | tuple of float The lower frequency of interest. Multiple bands are defined using a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq. - If None the frequency corresponding to an epoch length of 5 cycles - is used. fmax : float | tuple of float The upper frequency of interest. Multiple bands are dedined using a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq. @@ -840,11 +1560,21 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, cwt_n_cycles : float | array of float Number of cycles. Fixed number or one per frequency. Only used in 'cwt_morlet' mode. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. Higher values increase computational cost, + but reduce the degree of spectral smoothing in the results. 
Only used + if ``method`` contains any of ``['gc', 'gc_tr']``. + rank : tuple of array | None + Two arrays with the rank to project the seed and target data to, + respectively, using singular value decomposition. If None, the rank of + the data is computed and projected to. Only used if ``method`` contains + any of ``['mic', 'mim', 'gc', 'gc_tr']``. block_size : int How many connections to compute at once (higher numbers are faster but require more memory). n_jobs : int - How many epochs to process in parallel. + How many samples to process in parallel. %(verbose)s Returns @@ -858,7 +1588,8 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, when "indices" is None, or (n_con, n_freqs) mode: 'multitaper' or 'fourier' (n_con, n_freqs, n_times) mode: 'cwt_morlet' - when "indices" is specified and "n_con = len(indices[0])". + when "indices" is specified and "n_con = len(indices[0])". If a + multivariate method is called "n_con = 1" even if "indices" is None. See Also -------- @@ -888,11 +1619,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, "mode" parameter. By default, the connectivity between all signals is computed (only - connections corresponding to the lower-triangular part of the - connectivity matrix). If one is only interested in the connectivity - between some signals, the "indices" parameter can be used. For example, - to compute the connectivity between the signal with index 0 and signals - "2, 3, 4" (a total of 3 connections) one can use the following:: + connections corresponding to the lower-triangular part of the connectivity + matrix). If one is only interested in the connectivity between some + signals, the "indices" parameter can be used. 
For example, to compute the + connectivity between the signal with index 0 and signals "2, 3, 4" (a total + of 3 connections) one can use the following:: indices = (np.array([0, 0, 0]), # row indices np.array([2, 3, 4])) # col indices @@ -903,6 +1634,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, In this case con.get_data().shape = (3, n_freqs). The connectivity scores are in the same order as defined indices. + For multivariate methods, this is handled differently. If "indices" is + None, connectivity between all signals will attempt to be computed (this is + not possible if a Granger causality method is called). If "indices" is + specified, the seeds and targets are treated as a single connection. For + example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, + one would use the same approach as above, however the signals would all be + considered for a single connection and the connectivity scores would have + the shape (1, n_freqs). + **Supported Connectivity Measures** The connectivity method(s) is specified using the "method" parameter. The @@ -928,12 +1668,31 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, C = ---------------------- sqrt(E[Sxx] * E[Syy]) + 'mic' : Maximised Imaginary part of Coherency (MIC) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} + {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} + \parallel}}` + + where: :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets; and + :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are + eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises + connectivity between the seeds and targets. 
+ + 'mim' : Multivariate Interaction Measure (MIM) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIM=tr(\boldsymbol{EE}^T)` + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given by:: PLV = |E[Sxy/|Sxy|]| - 'ciplv' : corrected imaginary PLV (icPLV) + 'ciplv' : corrected imaginary PLV (ciPLV) :footcite:`BrunaEtAl2018` given by:: |E[Im(Sxy/|Sxy|)]| @@ -965,14 +1724,32 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, 'wpli2_debiased' : Debiased estimator of squared WPLI :footcite:`VinckEtAl2011`. + 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` + given by: + + :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert + \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + + where: :math:`s` and :math:`t` represent the seeds and targets, + respectively; :math:`\boldsymbol{H}` is the spectral transfer + function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of + the autoregressive model; and :math:`\boldsymbol{S}` is + :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. + + 'gc_tr' : State-space GC on time-reversed signals + :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation + as for 'gc', but where the autocovariance sequence from which the + autoregressive model is produced is transposed to mimic the reversal of + the original signal in time. + References ---------- .. 
footbibliography:: """ if n_jobs != 1: - parallel, my_epoch_spectral_connectivity, _ = \ - parallel_func(_epoch_spectral_connectivity, n_jobs, - verbose=verbose) + parallel, my_epoch_spectral_connectivity, _ = parallel_func( + _epoch_spectral_connectivity, n_jobs, verbose=verbose) # format fmin and fmax and check inputs if fmin is None: @@ -991,9 +1768,24 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, if not isinstance(method, (list, tuple)): method = [method] # make it a list so we can iterate over it + if n_bands != 1 and any( + this_method in _gc_methods for this_method in method + ): + raise ValueError('computing Granger causality on multiple frequency ' + 'bands is not yet supported') + + if any(this_method in _multivariate_methods for this_method in method): + if not all(this_method in _multivariate_methods for + this_method in method): + raise ValueError( + 'bivariate and multivariate connectivity methods cannot be ' + 'used in the same function call') + multivariate_con = True + else: + multivariate_con = False + # handle connectivity estimators - (con_method_types, n_methods, accumulate_psd, - n_comp_args) = _check_estimators(method=method, mode=mode) + (con_method_types, n_methods, accumulate_psd) = _check_estimators(method) events = None event_id = None @@ -1037,8 +1829,15 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, n_signals, indices_use, warn_times) = _prepare_connectivity( epoch_block=epoch_block, times_in=times_in, tmin=tmin, tmax=tmax, fmin=fmin, fmax=fmax, sfreq=sfreq, - indices=indices, mode=mode, fskip=fskip, n_bands=n_bands, - cwt_freqs=cwt_freqs, faverage=faverage) + indices=indices, method=method, mode=mode, fskip=fskip, + n_bands=n_bands, cwt_freqs=cwt_freqs, faverage=faverage) + + # check rank input and compute data ranks if necessary + if multivariate_con: + rank = _check_rank_input(rank, data, sfreq, indices_use) + else: + rank = None + gc_n_lags = None # get the 
window function, wavelets, etc for different modes (spectral_params, mt_adaptive, n_times_spectrum, @@ -1050,23 +1849,36 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # unique signals for which we actually need to compute PSD etc. sig_idx = np.unique(np.r_[indices_use[0], indices_use[1]]) + n_signals_use = len(sig_idx) # map indices to unique indices idx_map = [np.searchsorted(sig_idx, ind) for ind in indices_use] + if multivariate_con: + indices_use = idx_map + idx_map = np.unique([*idx_map[0], *idx_map[1]]) + idx_map = [np.sort(np.repeat(idx_map, len(sig_idx))), + np.tile(idx_map, len(sig_idx))] # allocate space to accumulate PSD if accumulate_psd: if n_times_spectrum == 0: - psd_shape = (len(sig_idx), n_freqs) + psd_shape = (n_signals_use, n_freqs) else: - psd_shape = (len(sig_idx), n_freqs, n_times_spectrum) + psd_shape = (n_signals_use, n_freqs, n_times_spectrum) psd = np.zeros(psd_shape) else: psd = None # create instances of the connectivity estimators - con_methods = [mtype(n_cons, n_freqs, n_times_spectrum) - for mtype in con_method_types] + con_methods = [] + for mtype_i, mtype in enumerate(con_method_types): + method_params = dict(n_cons=n_cons, n_freqs=n_freqs, + n_times=n_times_spectrum) + if method[mtype_i] in _multivariate_methods: + method_params.update(dict(n_signals=n_signals_use)) + if method[mtype_i] in _gc_methods: + method_params.update(dict(n_lags=gc_n_lags)) + con_methods.append(mtype(**method_params)) sep = ', ' metrics_str = sep.join([meth.name for meth in con_methods]) @@ -1080,33 +1892,35 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, warn_times=warn_times) call_params = dict( - sig_idx=sig_idx, tmin_idx=tmin_idx, - tmax_idx=tmax_idx, sfreq=sfreq, mode=mode, - freq_mask=freq_mask, idx_map=idx_map, block_size=block_size, + sig_idx=sig_idx, tmin_idx=tmin_idx, tmax_idx=tmax_idx, sfreq=sfreq, + method=method, mode=mode, freq_mask=freq_mask, idx_map=idx_map, + 
block_size=block_size, psd=psd, accumulate_psd=accumulate_psd, mt_adaptive=mt_adaptive, con_method_types=con_method_types, con_methods=con_methods if n_jobs == 1 else None, - n_signals=n_signals, n_times=n_times, + n_signals=n_signals, n_signals_use=n_signals_use, n_times=n_times, + gc_n_lags=gc_n_lags, accumulate_inplace=True if n_jobs == 1 else False) call_params.update(**spectral_params) if n_jobs == 1: # no parallel processing for this_epoch in epoch_block: - logger.info(' computing connectivity for epoch %d' + logger.info(' computing cross-spectral density for epoch %d' % (epoch_idx + 1)) # con methods and psd are updated inplace _epoch_spectral_connectivity(data=this_epoch, **call_params) epoch_idx += 1 else: # process epochs in parallel - logger.info(' computing connectivity for epochs %d..%d' - % (epoch_idx + 1, epoch_idx + len(epoch_block))) + logger.info( + ' computing cross-spectral density for epochs %d..%d' + % (epoch_idx + 1, epoch_idx + len(epoch_block))) out = parallel(my_epoch_spectral_connectivity( - data=this_epoch, **call_params) - for this_epoch in epoch_block) + data=this_epoch, **call_params) + for this_epoch in epoch_block) # do the accumulation for this_out in out: for _method, parallel_method in zip(con_methods, this_out[0]): @@ -1123,12 +1937,11 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, # compute final connectivity scores con = list() - for conn_method, n_args in zip(con_methods, n_comp_args): + patterns = list() + for method_i, conn_method in enumerate(con_methods): + # future estimators will need to be handled here - if n_args == 3: - # compute all scores at once - conn_method.compute_con(slice(0, n_cons), n_epochs) - elif n_args == 5: + if conn_method.accumulate_psd: # compute scores block-wise to save memory for i in range(0, n_cons, block_size): con_idx = slice(i, i + block_size) @@ -1136,26 +1949,47 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, psd_yy = 
psd[idx_map[1][con_idx]] conn_method.compute_con(con_idx, n_epochs, psd_xx, psd_yy) else: - raise RuntimeError('This should never happen.') + # compute all scores at once + if method[method_i] in _multivariate_methods: + conn_method.compute_con(indices_use, rank, n_epochs) + else: + conn_method.compute_con(slice(0, n_cons), n_epochs) # get the connectivity scores this_con = conn_method.con_scores + this_patterns = conn_method.patterns if this_con.shape[0] != n_cons: - raise ValueError('First dimension of connectivity scores must be ' - 'the same as the number of connections') + raise RuntimeError( + 'first dimension of connectivity scores does not match the ' + 'number of connections; please contact the mne-connectivity ' + 'developers') if faverage: if this_con.shape[1] != n_freqs: - raise ValueError('2nd dimension of connectivity scores must ' - 'be the same as the number of frequencies') + raise RuntimeError( + 'second dimension of connectivity scores does not match ' + 'the number of frequencies; please contact the ' + 'mne-connectivity developers') con_shape = (n_cons, n_bands) + this_con.shape[2:] this_con_bands = np.empty(con_shape, dtype=this_con.dtype) for band_idx in range(n_bands): - this_con_bands[:, band_idx] =\ - np.mean(this_con[:, freq_idx_bands[band_idx]], axis=1) + this_con_bands[:, band_idx] = np.mean( + this_con[:, freq_idx_bands[band_idx]], axis=1) this_con = this_con_bands + if this_patterns is not None: + patterns_shape = ((2, n_cons, len(indices[0]), n_bands) + + this_patterns.shape[4:]) + this_patterns_bands = np.empty(patterns_shape, + dtype=this_patterns.dtype) + for band_idx in range(n_bands): + this_patterns_bands[:, :, :, band_idx] = np.mean( + this_patterns[:, :, :, freq_idx_bands[band_idx]], + axis=3) + this_patterns = this_patterns_bands + con.append(this_con) + patterns.append(this_patterns) freqs_used = freqs if faverage: @@ -1169,7 +2003,7 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, 
freqs_used = freqs_bands freqs_used = [[np.min(band), np.max(band)] for band in freqs_used] - if indices is None: + if indices is None and not multivariate_con: # return all-to-all connectivity matrices # raveled into a 1D array logger.info(' assembling connectivity matrix') @@ -1186,27 +2020,24 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, this_con = this_con.reshape((n_signals ** 2,) + this_con_flat.shape[1:]) con.append(this_con) - # number of nodes in the original data, + # number of nodes in the original data n_nodes = n_signals + if multivariate_con: + # UNTIL RAGGED ARRAYS SUPPORTED + indices = tuple( + [[np.array(indices_use[0])], [np.array(indices_use[1])]]) + # create a list of connectivity containers conn_list = [] - for _con in con: - kwargs = dict(data=_con, - names=names, - freqs=freqs, - method=method, - n_nodes=n_nodes, - spec_method=mode, - indices=indices, - n_epochs_used=n_epochs, - freqs_used=freqs_used, - times_used=times, - n_tapers=n_tapers, - metadata=metadata, - events=events, - event_id=event_id - ) + for _con, _patterns, _method in zip(con, patterns, method): + kwargs = dict( + data=_con, patterns=_patterns, names=names, freqs=freqs, + method=_method, n_nodes=n_nodes, spec_method=mode, indices=indices, + n_epochs_used=n_epochs, freqs_used=freqs_used, times_used=times, + n_tapers=n_tapers, metadata=metadata, events=events, + event_id=event_id, rank=rank, + n_lags=gc_n_lags if _method in _gc_methods else None) # create the connectivity container if mode in ['multitaper', 'fourier']: klass = SpectralConnectivity @@ -1223,3 +2054,46 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None, conn_list = conn_list[0] return conn_list + + +def _check_rank_input(rank, data, sfreq, indices): + """Check the rank argument is appropriate and compute rank if missing.""" + # UNTIL RAGGED ARRAYS SUPPORTED + indices = np.array([[indices[0]], [indices[1]]]) + + if rank is None: + + rank = 
np.zeros((2, len(indices[0])), dtype=int) + + if isinstance(data, BaseEpochs): + data_arr = data.get_data() + else: + data_arr = data + + for group_i in range(2): + for con_i, con_idcs in enumerate(indices[group_i]): + s = np.linalg.svd(data_arr[:, con_idcs], compute_uv=False) + rank[group_i][con_i] = np.min( + [np.count_nonzero(epoch >= epoch[0] * 1e-10) + for epoch in s]) + + logger.info('Estimated data ranks:') + con_i = 1 + for seed_rank, target_rank in zip(rank[0], rank[1]): + logger.info(' connection %i - seeds (%i); targets (%i)' + % (con_i, seed_rank, target_rank, )) + con_i += 1 + + rank = tuple((np.array(rank[0]), np.array(rank[1]))) + + else: + for seed_idcs, target_idcs, seed_rank, target_rank in zip( + indices[0], indices[1], rank[0], rank[1]): + if not (0 < seed_rank <= len(seed_idcs) and + 0 < target_rank <= len(target_idcs)): + raise ValueError( + 'ranks for seeds and targets must be > 0 and <= the ' + 'number of channels in the seeds and targets, ' + 'respectively, for each connection') + + return rank diff --git a/mne_connectivity/spectral/tests/data/README.md b/mne_connectivity/spectral/tests/data/README.md new file mode 100644 index 00000000..ea9da2bd --- /dev/null +++ b/mne_connectivity/spectral/tests/data/README.md @@ -0,0 +1,30 @@ +Author: Thomas S. Binns + +The files found here are used for the regression test of the multivariate +connectivity methods for MIC, MIM, GC, and TRGC +(`test_multivariate_spectral_connectivity_epochs_regression()` of +`test_spectral.py`). + +`example_multivariate_data.pkl` consists of four channels of randomly-generated +data with 15 epochs and 200 timepoints per epoch. Connectivity was computed in +MATLAB using the original implementations of these methods and saved as a +dictionary in `example_multivariate_matlab_results.pkl`. A publicly-available +implementation of the methods in MATLAB can be found here: +https://github.com/sccn/roiconnect. 
+ +As the MNE code for computing the cross-spectral density matrix is not +available in MATLAB, the CSD matrix was computed using MNE and then loaded into +MATLAB to compute the connectivity from the original implementations using the +same processing settings in MATLAB and Python. That is: a sampling frequency of +100 Hz; method='multitaper'; fskip=0; faverage=False; tmin=0; tmax=None; +mt_bandwidth=4; mt_low_bias=True; mt_adaptive=False; gc_n_lags=20; +rank=([2], [2]) - i.e. no rank subspace projection; indices=([0, 1], [2, 3]) - +i.e. connection from first two channels to last two channels. It is +important that no changes are made to the settings for computing the CSD or the +final connectivity scores, otherwise this test will be invalid! + +One key difference is that the MATLAB implementation for computing MIC returns +the absolute value of the results, so we must take the absolute value of the +results returned from the MNE function to make the comparison. We do not return +the absolute values of the results, as relevant information such as phase angle +differences are lost. 
\ No newline at end of file diff --git a/mne_connectivity/spectral/tests/data/example_multivariate_data.pkl b/mne_connectivity/spectral/tests/data/example_multivariate_data.pkl new file mode 100644 index 00000000..ccd9f2a3 Binary files /dev/null and b/mne_connectivity/spectral/tests/data/example_multivariate_data.pkl differ diff --git a/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl b/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl new file mode 100644 index 00000000..7b1de665 Binary files /dev/null and b/mne_connectivity/spectral/tests/data/example_multivariate_matlab_results.pkl differ diff --git a/mne_connectivity/spectral/tests/test_spectral.py b/mne_connectivity/spectral/tests/test_spectral.py index 3286bba8..fa8cf44d 100644 --- a/mne_connectivity/spectral/tests/test_spectral.py +++ b/mne_connectivity/spectral/tests/test_spectral.py @@ -1,6 +1,8 @@ +import os import numpy as np from numpy.testing import (assert_allclose, assert_array_almost_equal, assert_array_less) +import pandas as pd import pytest from mne import (EpochsArray, SourceEstimate, create_info) from mne.filter import filter_data @@ -408,6 +410,301 @@ def test_spectral_connectivity(method, mode): assert (out_lens[0] == 10) +@pytest.mark.parametrize('method', ['mic', 'mim', 'gc']) +def test_spectral_connectivity_epochs_multivariate(method): + """Test over-epoch multivariate connectivity methods.""" + mode = 'multitaper' # stick with single mode in interest of time + + sfreq = 100.0 # Hz + n_signals = 4 # should be even! 
+ n_seeds = n_signals // 2 + n_epochs = 10 + n_times = 200 # samples + trans_bandwidth = 2.0 # Hz + delay = 10 # samples (non-zero delay needed for ImCoh and GC to be >> 0) + + indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + + # 15-25 Hz connectivity + fstart, fend = 15.0, 25.0 + rng = np.random.RandomState(0) + data = rng.randn(n_signals, n_epochs * n_times + delay) + # simulate connectivity from fstart to fend + data[n_seeds:, :] = filter_data( + data[:n_seeds, :], sfreq, fstart, fend, filter_length='auto', + fir_design='firwin2', l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth) + # add some noise, so the spectrum is not exactly zero + data[n_seeds:, :] += 1e-2 * rng.randn(n_seeds, n_times * n_epochs + delay) + # shift the seeds to that the targets are a delayed version of them + data[:n_seeds, :n_epochs * n_times] = data[:n_seeds, delay:] + data = data[:, :n_times * n_epochs] + data = data.reshape(n_signals, n_epochs, n_times) + data = np.transpose(data, [1, 0, 2]) + + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, sfreq=sfreq, + gc_n_lags=20) + freqs = con.freqs + gidx = (freqs.index(fstart), freqs.index(fend) + 1) + bidx = (freqs.index(fstart - trans_bandwidth * 2), + freqs.index(fend + trans_bandwidth * 2) + 1) + + if method in ['mic', 'mim']: + lower_t = 0.2 + upper_t = 0.5 + + assert np.abs(con.get_data())[0, gidx[0]:gidx[1]].mean() > upper_t + assert np.abs(con.get_data())[0, :bidx[0]].mean() < lower_t + assert np.abs(con.get_data())[0, bidx[1]:].mean() < lower_t + + elif method == 'gc': + lower_t = 0.2 + upper_t = 0.8 + + assert con.get_data()[0, gidx[0]:gidx[1]].mean() > upper_t + assert con.get_data()[0, :bidx[0]].mean() < lower_t + assert con.get_data()[0, bidx[1]:].mean() < lower_t + + # check that target -> seed connectivity is low + indices_ts = (indices[1], indices[0]) + con_ts = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices_ts, 
sfreq=sfreq, + gc_n_lags=20) + assert con_ts.get_data()[0, gidx[0]:gidx[1]].mean() < lower_t + + # check that TRGC is positive (i.e. net seed -> target connectivity not + # due to noise) + con_tr = spectral_connectivity_epochs( + data, method='gc_tr', mode=mode, indices=indices, sfreq=sfreq, + gc_n_lags=20) + con_ts_tr = spectral_connectivity_epochs( + data, method='gc_tr', mode=mode, indices=indices_ts, sfreq=sfreq, + gc_n_lags=20) + trgc = ((con.get_data() - con_ts.get_data()) - + (con_tr.get_data() - con_ts_tr.get_data())) + # checks that TRGC is positive and >> 0 (for 15-25 Hz) + assert np.all(trgc[0, gidx[0]:gidx[1]] > 0) + assert np.all(trgc[0, gidx[0]:gidx[1]] > upper_t) + # checks that TRGC is ~ 0 for other frequencies + assert np.allclose(trgc[0, :bidx[0]].mean(), 0, atol=lower_t) + assert np.allclose(trgc[0, bidx[1]:].mean(), 0, atol=lower_t) + + # check all-to-all conn. computed for MIC/MIM when no indices given + if method in ['mic', 'mim']: + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=None, sfreq=sfreq) + assert (np.array(con.indices).tolist() == + [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + + # check shape of MIC patterns + if method == 'mic': + for mode in ['multitaper', 'cwt_morlet']: + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, sfreq=sfreq, + fmin=10, fmax=25, cwt_freqs=np.arange(10, 25), + faverage=True) + + if mode == 'cwt_morlet': + patterns_shape = ( + (len(indices[0]), len(con.freqs), len(con.times)), + (len(indices[1]), len(con.freqs), len(con.times))) + else: + patterns_shape = ( + (len(indices[0]), len(con.freqs)), + (len(indices[1]), len(con.freqs))) + assert np.shape(con.attrs["patterns"][0][0]) == patterns_shape[0] + assert np.shape(con.attrs["patterns"][1][0]) == patterns_shape[1] + + # only check these once for speed + if mode == 'multitaper': + # check patterns averaged over freqs + fmin = (5., 15.) + fmax = (15., 30.) 
+ con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, + sfreq=sfreq, fmin=fmin, fmax=fmax, faverage=True) + assert np.shape(con.attrs["patterns"][0][0])[1] == len(fmin) + assert np.shape(con.attrs["patterns"][1][0])[1] == len(fmin) + + # check patterns shape matches input data, not rank + rank = (np.array([1]), np.array([1])) + con = spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=rank) + assert (np.shape(con.attrs["patterns"][0][0])[0] == + len(indices[0])) + assert (np.shape(con.attrs["patterns"][1][0])[0] == + len(indices[1])) + + +def test_multivariate_spectral_connectivity_epochs_regression(): + """Test multivar. spectral connectivity over epochs for regression. + + The multivariate methods were originally implemented in MATLAB by their + respective authors. To show that this Python implementation is identical + and to avoid any future regressions, we compare the results of the Python + and MATLAB implementations on some example data (randomly generated). + + As the MNE code for computing the cross-spectral density matrix is not + available in MATLAB, the CSD matrix was computed using MNE and then loaded + into MATLAB to compute the connectivity from the original implementations + using the same processing settings in MATLAB and Python. + + It is therefore important that no changes are made to the settings for + computing the CSD or the final connectivity scores! 
+ """ + fpath = os.path.dirname(os.path.realpath(__file__)) + data = pd.read_pickle( + os.path.join(fpath, 'data', 'example_multivariate_data.pkl')) + sfreq = 100 + indices = tuple([[0, 1], [2, 3]]) + methods = ['mic', 'mim', 'gc', 'gc_tr'] + con = spectral_connectivity_epochs( + data, method=methods, indices=indices, mode='multitaper', sfreq=sfreq, + fskip=0, faverage=False, tmin=0, tmax=None, mt_bandwidth=4, + mt_low_bias=True, mt_adaptive=False, gc_n_lags=20, + rank=tuple([[2], [2]]), n_jobs=1) + + # should take the absolute of the MIC scores, as the MATLAB implementation + # returns the absolute values. + mne_results = {this_con.method: np.abs(this_con.get_data()) + for this_con in con} + matlab_results = pd.read_pickle( + os.path.join(fpath, 'data', 'example_multivariate_matlab_results.pkl')) + for method in methods: + assert_allclose(matlab_results[method], mne_results[method], 1e-5) + + +@pytest.mark.parametrize( + 'method', ['mic', 'mim', 'gc', 'gc_tr', ['mic', 'mim', 'gc', 'gc_tr']]) +@pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet']) +def test_multivar_spectral_connectivity_epochs_error_catch(method, mode): + """Test error catching for multivar. freq.-domain connectivity methods.""" + sfreq = 50. + n_signals = 4 # Do not change! 
+ n_epochs = 8 + n_times = 256 + rng = np.random.RandomState(0) + data = rng.randn(n_epochs, n_signals, n_times) + indices = (np.arange(0, 2), np.arange(2, 4)) + cwt_freqs = np.arange(10, 25 + 1) + + # check bad indices with repeated channels + with pytest.raises(ValueError, + match='seed and target indices cannot contain'): + repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=repeated_indices, + sfreq=sfreq, gc_n_lags=10) + + # check mixed methods caught + with pytest.raises(ValueError, + match='bivariate and multivariate connectivity'): + if isinstance(method, str): + mixed_methods = [method, 'coh'] + elif isinstance(method, list): + mixed_methods = [*method, 'coh'] + spectral_connectivity_epochs(data, method=mixed_methods, mode=mode, + indices=indices, sfreq=sfreq, + cwt_freqs=cwt_freqs) + + # check bad rank args caught + too_low_rank = (np.array([0]), np.array([0])) + with pytest.raises(ValueError, + match='ranks for seeds and targets must be'): + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=too_low_rank, cwt_freqs=cwt_freqs) + too_high_rank = (np.array([3]), np.array([3])) + with pytest.raises(ValueError, + match='ranks for seeds and targets must be'): + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=too_high_rank, cwt_freqs=cwt_freqs) + + # check rank-deficient data caught + bad_data = data.copy() + bad_data[:, 1] = bad_data[:, 0] + bad_data[:, 3] = bad_data[:, 2] + assert np.all(np.linalg.matrix_rank(bad_data[:, (0, 1), :]) == 1) + assert np.all(np.linalg.matrix_rank(bad_data[:, (2, 3), :]) == 1) + if isinstance(method, str): + rank_con = spectral_connectivity_epochs( + bad_data, method=method, mode=mode, indices=indices, sfreq=sfreq, + gc_n_lags=10, cwt_freqs=cwt_freqs) + assert rank_con.attrs["rank"] == (np.array([1]), np.array([1])) + + if method in ['mic', 'mim']: + # check 
rank-deficient transformation matrix caught + with pytest.raises(RuntimeError, + match='the transformation matrix'): + spectral_connectivity_epochs( + bad_data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=(np.array([2]), np.array([2])), + cwt_freqs=cwt_freqs) + + # only check these once for speed + if method == 'gc' and mode == 'multitaper': + # check bad n_lags caught + frange = (5, 10) + n_lags = 200 # will be far too high + with pytest.raises(ValueError, match='the number of lags'): + spectral_connectivity_epochs( + data, method=method, mode=mode, indices=indices, sfreq=sfreq, + fmin=frange[0], fmax=frange[1], gc_n_lags=n_lags, + cwt_freqs=cwt_freqs) + + # check no indices caught + with pytest.raises(ValueError, match='indices must be specified'): + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=None, sfreq=sfreq, + cwt_freqs=cwt_freqs) + + # check intersecting indices caught + bad_indices = (np.array([0, 1]), np.array([0, 2])) + with pytest.raises(ValueError, + match='seed and target indices must not intersect'): + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=bad_indices, sfreq=sfreq, + cwt_freqs=cwt_freqs) + + # check bad fmin/fmax caught + with pytest.raises(ValueError, + match='computing Granger causality on multiple'): + spectral_connectivity_epochs(data, method=method, mode=mode, + indices=indices, sfreq=sfreq, + fmin=(10., 15.), fmax=(15., 20.), + cwt_freqs=cwt_freqs) + + # check rank-deficient autocovariance caught + with pytest.raises(RuntimeError, + match='the autocovariance matrix is singular'): + spectral_connectivity_epochs( + bad_data, method=method, mode=mode, indices=indices, + sfreq=sfreq, rank=(np.array([2]), np.array([2])), + cwt_freqs=cwt_freqs) + + +@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +def test_multivar_spectral_connectivity_parallel(method): + """Test multivar. freq.-domain connectivity methods run in parallel.""" + sfreq = 50. 
+ n_signals = 4 # Do not change! + n_epochs = 8 + n_times = 256 + rng = np.random.RandomState(0) + data = rng.randn(n_epochs, n_signals, n_times) + indices = (np.arange(0, 2), np.arange(2, 4)) + + spectral_connectivity_epochs( + data, method=method, mode="multitaper", indices=indices, sfreq=sfreq, + gc_n_lags=10, n_jobs=2) + spectral_connectivity_time( + data, freqs=np.arange(10, 25), method=method, mode="multitaper", + indices=indices, sfreq=sfreq, gc_n_lags=10, n_jobs=2) + + @ pytest.mark.parametrize('kind', ('epochs', 'ndarray', 'stc', 'combo')) def test_epochs_tmin_tmax(kind): """Test spectral.spectral_connectivity_epochs with epochs and arrays.""" @@ -472,9 +769,9 @@ def test_epochs_tmin_tmax(kind): assert len(w) == 1 # just one even though there were multiple epochs -@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @pytest.mark.parametrize( - 'mode', ['cwt_morlet', 'multitaper']) + 'method', ['coh', 'mic', 'mim', 'plv', 'pli', 'wpli', 'ciplv']) +@pytest.mark.parametrize('mode', ['cwt_morlet', 'multitaper']) @pytest.mark.parametrize('data_option', ['sync', 'random']) def test_spectral_connectivity_time_phaselocked(method, mode, data_option): """Test time-resolved spectral connectivity with simulated phase-locked @@ -500,30 +797,109 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option): wave_freq * epoch_length * np.pi + phase, n_times) data[i, c] = np.squeeze(np.sin(x)) + + multivar_methods = ['mic', 'mim'] + # the frequency band should contain the frequency at which there is a # hypothesized "connection" freq_band_low_limit = (8.) freq_band_high_limit = (13.) 
freqs = np.arange(freq_band_low_limit, freq_band_high_limit + 1) - con = spectral_connectivity_time(data, freqs, method=method, mode=mode, - sfreq=sfreq, fmin=freq_band_low_limit, - fmax=freq_band_high_limit, - n_jobs=1, - faverage=True, average=True, sm_times=0) - assert con.shape == (n_channels ** 2, len(con.freqs)) - con_matrix = con.get_data('dense')[..., 0] + con = spectral_connectivity_time( + data, freqs, method=method, mode=mode, sfreq=sfreq, + fmin=freq_band_low_limit, fmax=freq_band_high_limit, n_jobs=1, + faverage=True if method != 'mic' else False, + average=True if method != 'mic' else False, sm_times=0) + con_matrix = con.get_data() + + # MIC values can be pos. and neg., so must be averaged after taking the + # absolute values for the test to work + if method in multivar_methods: + if method == 'mic': + con_matrix = np.mean(np.abs(con_matrix), axis=(0, 2)) + assert con.shape == (n_epochs, 1, len(con.freqs)) + else: + assert con.shape == (1, len(con.freqs)) + else: + assert con.shape == (n_channels ** 2, len(con.freqs)) + con_matrix = np.reshape(con_matrix, (n_channels, n_channels))[ + np.tril_indices(n_channels, -1)] + if data_option == 'sync': # signals are perfectly phase-locked, connectivity matrix should be - # a lower triangular matrix of ones - assert np.allclose(con_matrix, - np.tril(np.ones(con_matrix.shape), - k=-1), - atol=0.01) + # a matrix of ones + assert np.allclose(con_matrix, np.ones(con_matrix.shape), atol=0.01) if data_option == 'random': # signals are random, all connectivity values should be small # 0.5 is picked rather arbitrarily such that the obsolete wrong # implementation fails - assert np.all(con_matrix) <= 0.5 + assert np.all(con_matrix <= 0.5) + + +def test_spectral_connectivity_time_delayed(): + """Test per-epoch Granger causality with time-delayed data. + + N.B.: the spectral_connectivity_time method seems to be more unstable than + spectral_connectivity_epochs for GC estimation. 
Accordingly, we assess + Granger scores only in the context of the noise-corrected TRGC metric, + where the true directionality of the connections seems to identified. + """ + mode = 'multitaper' # stick with single mode in interest of time + + sfreq = 100.0 # Hz + n_signals = 4 # should be even! + n_seeds = n_signals // 2 + n_epochs = 10 + n_times = 200 # samples + trans_bandwidth = 2.0 # Hz + delay = 5 # samples (non-zero delay needed for GC to be >> 0) + + indices = tuple([np.arange(n_seeds), np.arange(n_seeds) + n_seeds]) + + # 20-30 Hz connectivity + fstart, fend = 20.0, 30.0 + rng = np.random.RandomState(0) + data = rng.randn(n_signals, n_epochs * n_times + delay) + # simulate connectivity from fstart to fend + data[n_seeds:, :] = filter_data( + data[:n_seeds, :], sfreq, fstart, fend, filter_length='auto', + fir_design='firwin2', l_trans_bandwidth=trans_bandwidth, + h_trans_bandwidth=trans_bandwidth) + # add some noise, so the spectrum is not exactly zero + data[n_seeds:, :] += 1e-2 * rng.randn(n_seeds, n_times * n_epochs + delay) + # shift the seeds to that the targets are a delayed version of them + data[:n_seeds, :n_epochs * n_times] = data[:n_seeds, delay:] + data = data[:, :n_times * n_epochs] + data = data.reshape(n_signals, n_epochs, n_times) + data = np.transpose(data, [1, 0, 2]) + + freqs = np.arange(2.5, 50, 0.5) + con_st = spectral_connectivity_time( + data, freqs, method=['gc', 'gc_tr'], indices=indices, mode=mode, + sfreq=sfreq, n_jobs=1, gc_n_lags=20, n_cycles=5, average=True) + con_ts = spectral_connectivity_time( + data, freqs, method=['gc', 'gc_tr'], indices=(indices[1], indices[0]), + mode=mode, sfreq=sfreq, n_jobs=1, gc_n_lags=20, n_cycles=5, + average=True) + st = con_st[0].get_data() + st_tr = con_st[1].get_data() + ts = con_ts[0].get_data() + ts_tr = con_ts[1].get_data() + trgc = (st - ts) - (st_tr - ts_tr) + + freqs = con_st[0].freqs + gidx = (freqs.index(fstart), freqs.index(fend) + 1) + bidx = (freqs.index(fstart - trans_bandwidth * 
2), + freqs.index(fend + trans_bandwidth * 2) + 1) + + # assert that TRGC (i.e. net, noise-corrected connectivity) is positive and + # >> 0 (i.e. that there is indeed a flow of info. from seeds to targets, + # as simulated) + assert np.all(trgc[:, gidx[0]:gidx[1]] > 0) + assert trgc[:, gidx[0]:gidx[1]].mean() > 0.4 + # check that non-interacting freqs. have close to zero connectivity + assert np.allclose(trgc[0, :bidx[0]].mean(), 0, atol=0.1) + assert np.allclose(trgc[0, bidx[1]:].mean(), 0, atol=0.1) @pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv']) @@ -671,6 +1047,115 @@ def test_spectral_connectivity_time_padding(method, mode, padding): for idx, jdx in triu_inds) +@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr']) +@pytest.mark.parametrize('average', [True, False]) +@pytest.mark.parametrize('faverage', [True, False]) +def test_multivar_spectral_connectivity_time_shapes(method, average, faverage): + """Test result shapes of time-resolved multivar. connectivity methods.""" + sfreq = 50. + n_signals = 4 # Do not change! 
+ n_epochs = 8 + n_times = 500 + rng = np.random.RandomState(0) + data = rng.randn(n_epochs, n_signals, n_times) + indices = (np.arange(0, 2), np.arange(2, 4)) + freqs = np.arange(10, 25 + 1) + + con_shape = [1] + if faverage: + con_shape.append(1) + else: + con_shape.append(len(freqs)) + if not average: + con_shape = [n_epochs, *con_shape] + + # check shape of results when averaging across epochs + con = spectral_connectivity_time( + data, freqs, indices=indices, method=method, sfreq=sfreq, + faverage=faverage, average=average, gc_n_lags=10) + assert con.shape == tuple(con_shape) + + # check shape of MIC patterns are correct + if method == 'mic': + patterns_shape = [len(indices[0])] + if faverage: + patterns_shape.append(1) + else: + patterns_shape.append(len(freqs)) + if not average: + patterns_shape = [n_epochs, *patterns_shape] + patterns_shape = [2, *patterns_shape] + assert np.array(con.attrs['patterns']).shape == tuple(patterns_shape) + + +@pytest.mark.parametrize( + 'method', ['mic', 'mim', 'gc', 'gc_tr']) +def test_multivar_spectral_connectivity_time_error_catch(method): + """Test error catching for time-resolved multivar. connectivity methods.""" + sfreq = 50. + n_signals = 4 # Do not change! 
+ n_epochs = 8 + n_times = 256 + data = np.random.rand(n_epochs, n_signals, n_times) + indices = (np.arange(0, 2), np.arange(2, 4)) + freqs = np.arange(10, 25 + 1) + + # check bad indices with repeated channels + with pytest.raises(ValueError, + match='seed and target indices cannot contain'): + repeated_indices = tuple([[0, 1, 1], [2, 2, 3]]) + spectral_connectivity_time(data, freqs, method=method, + indices=repeated_indices, sfreq=sfreq) + + # check mixed methods caught + with pytest.raises(ValueError, + match='bivariate and multivariate connectivity'): + mixed_methods = [method, 'coh'] + spectral_connectivity_time(data, freqs, method=mixed_methods, + indices=indices, sfreq=sfreq) + + # check bad rank args caught + too_low_rank = (np.array([0]), np.array([0])) + with pytest.raises(ValueError, + match='ranks for seeds and targets must be'): + spectral_connectivity_time( + data, freqs, method=method, indices=indices, sfreq=sfreq, + rank=too_low_rank) + too_high_rank = (np.array([3]), np.array([3])) + with pytest.raises(ValueError, + match='ranks for seeds and targets must be'): + spectral_connectivity_time( + data, freqs, method=method, indices=indices, sfreq=sfreq, + rank=too_high_rank) + + # check all-to-all conn. 
computed for MIC/MIM when no indices given + if method in ['mic', 'mim']: + con = spectral_connectivity_epochs( + data, freqs, method=method, indices=None, sfreq=sfreq) + assert (np.array(con.indices).tolist() == + [[[0, 1, 2, 3]], [[0, 1, 2, 3]]]) + + if method in ['gc', 'gc_tr']: + # check no indices caught + with pytest.raises(ValueError, match='indices must be specified'): + spectral_connectivity_time(data, freqs, method=method, + indices=None, sfreq=sfreq) + + # check intersecting indices caught + bad_indices = (np.array([0, 1]), np.array([0, 2])) + with pytest.raises(ValueError, + match='seed and target indices must not intersect'): + spectral_connectivity_time(data, freqs, method=method, + indices=bad_indices, sfreq=sfreq) + + # check bad fmin/fmax caught + with pytest.raises(ValueError, + match='computing Granger causality on multiple'): + spectral_connectivity_time(data, freqs, method=method, + indices=indices, sfreq=sfreq, + fmin=(5., 15.), fmax=(15., 30.)) + + def test_save(tmp_path): """Test saving results of spectral connectivity.""" rng = np.random.RandomState(0) diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index b11a9640..7c1aabe6 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -1,5 +1,6 @@ # Authors: Adam Li # Santeri Ruuskanen +# Thomas S. 
Binns # # License: BSD (3-clause) @@ -12,11 +13,16 @@ from mne.utils import (logger, verbose) from ..base import (SpectralConnectivity, EpochSpectralConnectivity) -from .epochs import _compute_freq_mask +from .epochs import (_MICEst, _MIMEst, _GCEst, _GCTREst, _compute_freq_mask, + _check_rank_input) from .smooth import _create_kernel, _smooth_spectra from ..utils import check_indices, fill_doc +_multivariate_methods = ['mic', 'mim', 'gc', 'gc_tr'] +_gc_methods = ['gc', 'gc_tr'] + + @verbose @fill_doc def spectral_connectivity_time(data, freqs, method='coh', average=False, @@ -24,8 +30,9 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, fmax=None, fskip=0, faverage=False, sm_times=0, sm_freqs=1, sm_kernel='hanning', padding=0, mode='cwt_morlet', mt_bandwidth=None, - n_cycles=7, decim=1, n_jobs=1, verbose=None): - """Compute time-frequency-domain connectivity measures. + n_cycles=7, gc_n_lags=40, rank=None, decim=1, + n_jobs=1, verbose=None): + r"""Compute time-frequency-domain connectivity measures. This function computes spectral connectivity over time from epoched data. The data may consist of a single epoch. @@ -44,20 +51,29 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, ``fmax`` are used. method : str | list of str Connectivity measure(s) to compute. These can be - ``['coh', 'plv', 'ciplv', 'pli', 'wpli']``. These are: + ``['coh', 'mic', 'mim', 'plv', 'ciplv', 'pli', 'wpli', 'gc', + 'gc_tr']``. These are: * 'coh' : Coherence + * 'mic' : Maximised Imaginary part of Coherency (MIC) + * 'mim' : Multivariate Interaction Measure (MIM) * 'plv' : Phase-Locking Value (PLV) * 'ciplv' : Corrected imaginary Phase-Locking Value * 'pli' : Phase-Lag Index * 'wpli' : Weighted Phase-Lag Index + * 'gc' : State-space Granger Causality (GC) + * 'gc_tr' : State-space GC on time-reversed signals + Multivariate methods (``['mic', 'mim', 'gc', 'gc_tr]``) cannot be + called with the other methods. 
average : bool Average connectivity scores over epochs. If ``True``, output will be an instance of :class:`SpectralConnectivity`, otherwise :class:`EpochSpectralConnectivity`. indices : tuple of array_like | None Two arrays with indices of connections for which to compute - connectivity. I.e. it is a ``(n_pairs, 2)`` array essentially. - If `None`, all connections are computed. + connectivity. If a multivariate method is called, the indices are for a + single connection between all seeds and all targets. If None, all + connections are computed, unless a Granger causality method is called, + in which case an error is raised. sfreq : float The sampling frequency. Required if data is not :class:`Epochs `. @@ -103,6 +119,16 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, frequency. The number of cycles ``n_cycles`` and the frequencies of interest ``cwt_freqs`` define the temporal window length. For details, see :func:`mne.time_frequency.tfr_array_morlet` documentation. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. Higher values increase computational cost, + but reduce the degree of spectral smoothing in the results. Only used + if ``method`` contains any of ``['gc', 'gc_tr']``. + rank : tuple of array | None + Two arrays with the rank to project the seed and target data to, + respectively, using singular value decomposition. If `None`, the rank + of the data is computed and projected to. Only used if ``method`` + contains any of ``['mic', 'mim', 'gc', 'gc_tr']``. decim : int To reduce memory usage, decimation factor after time-frequency decomposition. Returns ``tfr[…, ::decim]``. @@ -119,9 +145,10 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, or a list of instances corresponding to connectivity measures if several connectivity measures are specified. 
The shape of each connectivity dataset is - (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None` - and (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified - and ``n_nodes = len(indices[0])``. + (n_epochs, n_signals, n_signals, n_freqs) when ``indices`` is `None`, + (n_epochs, n_nodes, n_nodes, n_freqs) when ``indices`` is specified + and ``n_nodes = len(indices[0])``, or (n_epochs, 1, 1, n_freqs) when a + multi-variate method is called regardless of "indices". See Also -------- @@ -159,11 +186,11 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, (i.e., time window length). By default, the connectivity between all signals is computed (only - connections corresponding to the lower-triangular part of the - connectivity matrix). If one is only interested in the connectivity - between some signals, the ``indices`` parameter can be used. For example, - to compute the connectivity between the signal with index 0 and signals - 2, 3, 4 (a total of 3 connections), one can use the following:: + connections corresponding to the lower-triangular part of the connectivity + matrix). If one is only interested in the connectivity between some + signals, the "indices" parameter can be used. For example, to compute the + connectivity between the signal with index 0 and signals "2, 3, 4" (a total + of 3 connections) one can use the following:: indices = (np.array([0, 0, 0]), # row indices np.array([2, 3, 4])) # col indices @@ -174,6 +201,15 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, In this case ``con.get_data().shape = (3, n_freqs)``. The connectivity scores are in the same order as defined indices. + For multivariate methods, this is handled differently. If "indices" is + None, connectivity between all signals will attempt to be computed (this is + not possible if a Granger causality method is called). If "indices" is + specified, the seeds and targets are treated as a single connection. 
For + example, to compute the connectivity between signals 0, 1, 2 and 3, 4, 5, + one would use the same approach as above, however the signals would all be + considered for a single connection and the connectivity scores would have + the shape (1, n_freqs). + **Supported Connectivity Measures** The connectivity method(s) is specified using the ``method`` parameter. The @@ -187,12 +223,31 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, C = --------------------- sqrt(E[Sxx] * E[Syy]) + 'mic' : Maximised Imaginary part of Coherency (MIC) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIC=\Large{\frac{\boldsymbol{\alpha}^T \boldsymbol{E \beta}} + {\parallel\boldsymbol{\alpha}\parallel \parallel\boldsymbol{\beta} + \parallel}}` + + where: :math:`\boldsymbol{E}` is the imaginary part of the + transformed cross-spectral density between seeds and targets; and + :math:`\boldsymbol{\alpha}` and :math:`\boldsymbol{\beta}` are + eigenvectors for the seeds and targets, such that + :math:`\boldsymbol{\alpha}^T \boldsymbol{E \beta}` maximises + connectivity between the seeds and targets. 
+ + 'mim' : Multivariate Interaction Measure (MIM) + :footcite:`EwaldEtAl2012` given by: + + :math:`MIM=tr(\boldsymbol{EE}^T)` + 'plv' : Phase-Locking Value (PLV) :footcite:`LachauxEtAl1999` given by:: PLV = |E[Sxy/|Sxy|]| - 'ciplv' : Corrected imaginary PLV (icPLV) :footcite:`BrunaEtAl2018` + 'ciplv' : Corrected imaginary PLV (ciPLV) :footcite:`BrunaEtAl2018` given by:: |E[Im(Sxy/|Sxy|)]| @@ -210,6 +265,25 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, WPLI = ------------------ E[|Im(Sxy)|] + 'gc' : State-space Granger Causality (GC) :footcite:`BarnettSeth2015` + given by: + + :math:`GC = ln\Large{(\frac{\lvert\boldsymbol{S}_{tt}\rvert}{\lvert + \boldsymbol{S}_{tt}-\boldsymbol{H}_{ts}\boldsymbol{\Sigma}_{ss + \lvert t}\boldsymbol{H}_{ts}^*\rvert}})`, + + where: :math:`s` and :math:`t` represent the seeds and targets, + respectively; :math:`\boldsymbol{H}` is the spectral transfer + function; :math:`\boldsymbol{\Sigma}` is the residuals matrix of + the autoregressive model; and :math:`\boldsymbol{S}` is + :math:`\boldsymbol{\Sigma}` transformed by :math:`\boldsymbol{H}`. + + 'gc_tr' : State-space GC on time-reversed signals + :footcite:`BarnettSeth2015,WinklerEtAl2016` given by the same equation + as for 'gc', but where the autocovariance sequence from which the + autoregressive model is produced is transposed to mimic the reversal of + the original signal in time. + Parallel computation can be activated by setting the ``n_jobs`` parameter. Under the hood, this utilizes the ``joblib`` library. 
For effective parallelization, you should activate memory mapping in MNE-Python by @@ -284,6 +358,22 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, if np.any(fmin > fmax): raise ValueError('fmax must be larger than fmin') + if len(fmin) != 1 and any( + this_method in _gc_methods for this_method in method + ): + raise ValueError('computing Granger causality on multiple frequency ' + 'bands is not yet supported') + + if any(this_method in _multivariate_methods for this_method in method): + if not all(this_method in _multivariate_methods for + this_method in method): + raise ValueError( + 'bivariate and multivariate connectivity methods cannot be ' + 'used in the same function call') + multivariate_con = True + else: + multivariate_con = False + # convert kernel width in time to samples if isinstance(sm_times, (int, float)): sm_times = int(np.round(sm_times * sfreq)) @@ -302,12 +392,45 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, # get indices of pairs of (group) regions if indices is None: - indices_use = np.tril_indices(n_signals, k=-1) + if multivariate_con: + if any(this_method in _gc_methods for this_method in method): + raise ValueError( + 'indices must be specified when computing Granger ' + 'causality, as all-to-all connectivity is not supported') + logger.info('using all indices for multivariate connectivity') + indices_use = (np.arange(n_signals, dtype=int), + np.arange(n_signals, dtype=int)) + else: + logger.info('only using indices for lower-triangular matrix') + indices_use = np.tril_indices(n_signals, k=-1) else: + if multivariate_con: + if ( + len(np.unique(indices[0])) != len(indices[0]) or + len(np.unique(indices[1])) != len(indices[1]) + ): + raise ValueError( + 'seed and target indices cannot contain repeated ' + 'channels for multivariate connectivity') + if any(this_method in _gc_methods for this_method in method): + if set(indices[0]).intersection(indices[1]): + raise ValueError( + 'seed 
and target indices must not intersect when ' + 'computing Granger causality') indices_use = check_indices(indices) source_idx = indices_use[0] target_idx = indices_use[1] - n_pairs = len(source_idx) + n_pairs = len(source_idx) if not multivariate_con else 1 + + # unique signals for which we actually need to compute the CSD of + signals_use = np.unique(np.r_[indices_use[0], indices_use[1]]) + + # check rank input and compute data ranks if necessary + if multivariate_con: + rank = _check_rank_input(rank, data, sfreq, indices_use) + else: + rank = None + gc_n_lags = None # check freqs if isinstance(freqs, (int, float)): @@ -354,26 +477,38 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, out_freqs = freqs conn = dict() + conn_patterns = dict() for m in method: conn[m] = np.zeros((n_epochs, n_pairs, n_freqs)) + conn_patterns[m] = np.full((n_epochs, 2, len(source_idx), n_freqs), + np.nan) logger.info('Connectivity computation...') # parameters to pass to the connectivity function call_params = dict( method=method, kernel=kernel, foi_idx=foi_idx, - source_idx=source_idx, target_idx=target_idx, + source_idx=source_idx, target_idx=target_idx, signals_use=signals_use, mode=mode, sfreq=sfreq, freqs=freqs, faverage=faverage, - n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, - decim=decim, padding=padding, kw_cwt={}, kw_mt={}, n_jobs=n_jobs, - verbose=verbose) + n_cycles=n_cycles, mt_bandwidth=mt_bandwidth, gc_n_lags=gc_n_lags, + rank=rank, decim=decim, padding=padding, kw_cwt={}, kw_mt={}, + n_jobs=n_jobs, verbose=verbose, multivariate_con=multivariate_con) for epoch_idx in np.arange(n_epochs): logger.info(f' Processing epoch {epoch_idx+1} / {n_epochs} ...') - conn_tr = _spectral_connectivity(data[epoch_idx], **call_params) + scores, patterns = _spectral_connectivity(data[epoch_idx], + **call_params) for m in method: - conn[m][epoch_idx] = np.stack(conn_tr[m], axis=0) + conn[m][epoch_idx] = np.stack(scores[m], axis=0) + if multivariate_con and 
patterns[m] is not None: + conn_patterns[m][epoch_idx] = np.stack(patterns[m], axis=0) + for m in method: + if np.isnan(conn_patterns[m]).all(): + conn_patterns[m] = None + else: + # epochs x 2 x n_channels x n_freqs + conn_patterns[m] = conn_patterns[m].transpose((1, 0, 2, 3)) - if indices is None: + if indices is None and not multivariate_con: conn_flat = conn conn = dict() for m in method: @@ -385,18 +520,28 @@ def spectral_connectivity_time(data, freqs, method='coh', average=False, conn_flat[m].shape[2:]) conn[m] = this_conn - # create a Connectivity container - if average: - out = [SpectralConnectivity( - conn[m].mean(axis=0), freqs=out_freqs, n_nodes=n_signals, - names=names, indices=indices, method=method, spec_method=mode, - events=events, event_id=event_id, metadata=metadata) - for m in method] - else: - out = [EpochSpectralConnectivity( - conn[m], freqs=out_freqs, n_nodes=n_signals, names=names, - indices=indices, method=method, spec_method=mode, events=events, - event_id=event_id, metadata=metadata) for m in method] + if multivariate_con: + # UNTIL RAGGED ARRAYS SUPPORTED + indices = tuple( + [[np.array(indices_use[0])], [np.array(indices_use[1])]]) + + # create the connectivity containers + out = [] + for m in method: + store_params = { + 'data': conn[m], 'patterns': conn_patterns[m], 'freqs': out_freqs, + 'n_nodes': n_signals, 'names': names, 'indices': indices, + 'method': method, 'spec_method': mode, 'events': events, + 'event_id': event_id, 'metadata': metadata, 'rank': rank, + 'n_lags': gc_n_lags if m in _gc_methods else None} + if average: + store_params['data'] = np.mean(store_params['data'], axis=0) + if conn_patterns[m] is not None: + store_params['patterns'] = np.mean(store_params['patterns'], + axis=1) + out.append(SpectralConnectivity(**store_params)) + else: + out.append(EpochSpectralConnectivity(**store_params)) logger.info('[Connectivity computation done]') @@ -408,10 +553,10 @@ def spectral_connectivity_time(data, freqs, method='coh', 
average=False, def _spectral_connectivity(data, method, kernel, foi_idx, - source_idx, target_idx, + source_idx, target_idx, signals_use, mode, sfreq, freqs, faverage, n_cycles, - mt_bandwidth, decim, padding, kw_cwt, kw_mt, - n_jobs, verbose): + mt_bandwidth, gc_n_lags, rank, decim, padding, + kw_cwt, kw_mt, n_jobs, verbose, multivariate_con): """Estimate time-resolved connectivity for one epoch. Parameters @@ -428,6 +573,8 @@ def _spectral_connectivity(data, method, kernel, foi_idx, Defines the signal pairs of interest together with ``target_idx``. target_idx : array_like, shape (n_pairs,) Defines the signal pairs of interest together with ``source_idx``. + signals_use : list of int + The unique signals on which connectivity is to be computed. mode : str Time-frequency transformation method. sfreq : float @@ -443,19 +590,34 @@ def _spectral_connectivity(data, method, kernel, foi_idx, frequency. mt_bandwidth : float | None Multitaper time-bandwidth. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. + rank : tuple of array + Ranks to project the seed and target data to. decim : int Decimation factor after time-frequency decomposition. padding : float Amount of time to consider as padding at the beginning and end of each epoch in seconds. + multivariate_con : bool + Whether or not multivariate connectivity is to be computed. Returns ------- - this_conn : list of array - List of connectivity estimates corresponding to the metrics in - ``method``. Each element is an array of shape (n_pairs, n_freqs) or - (n_pairs, n_fbands) if ``faverage`` is `True`. + scores : dict + Dictionary containing the connectivity estimates corresponding to the + metrics in ``method``. Each element is an array of shape (n_pairs, + n_freqs) or (n_pairs, n_fbands) if ``faverage`` is `True`. 
+ + patterns : dict + Dictionary containing the connectivity patterns (for reconstructing the + connectivity components in source-space) corresponding to the metrics + in ``method``, if multivariate methods are called, else an empty + dictionary. Each element is an array of shape (2, n_channels, n_freqs) + or (2, n_channels, 1) if ``faverage`` is `True`, where 2 corresponds to + the seed and target signals (respectively). """ n_pairs = len(source_idx) data = np.expand_dims(data, axis=0) @@ -500,13 +662,20 @@ def _spectral_connectivity(data, method, kernel, foi_idx, else None # compute for each connectivity method - this_conn = {} + scores = {} + patterns = {} conn = _parallel_con(out, method, kernel, foi_idx, source_idx, target_idx, - n_jobs, verbose, n_pairs, faverage, weights) + signals_use, gc_n_lags, rank, n_jobs, verbose, + n_pairs, faverage, weights, multivariate_con) for i, m in enumerate(method): - this_conn[m] = [out[i] for out in conn] + if multivariate_con: + scores[m] = conn[0][i] + patterns[m] = conn[1][i][:, 0] if conn[1][i] is not None else None + else: + scores[m] = [out[i] for out in conn] + patterns[m] = None - return this_conn + return scores, patterns ############################################################################### @@ -515,8 +684,9 @@ def _spectral_connectivity(data, method, kernel, foi_idx, ############################################################################### ############################################################################### -def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, - verbose, total, faverage, weights): +def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, + signals_use, gc_n_lags, rank, n_jobs, verbose, total, + faverage, weights, multivariate_con): """Compute spectral connectivity in parallel. 
Parameters @@ -533,6 +703,13 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, Defines the signal pairs of interest together with ``target_idx``. target_idx : array_like, shape (n_pairs,) Defines the signal pairs of interest together with ``source_idx``. + signals_use : list of int + The unique signals on which connectivity is to be computed. + gc_n_lags : int + Number of lags to use for the vector autoregressive model when + computing Granger causality. + rank : tuple of array + Ranks to project the seed and target data to. n_jobs : int Number of parallel jobs. total : int @@ -541,12 +718,17 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, Average over frequency bands. weights : array_like, shape (n_tapers, n_freqs, n_times) Multitaper weights. + multivariate_con : bool + Whether or not multivariate connectivity is being computed. Returns ------- - out : array_like, shape (n_pairs, n_methods, n_freqs_out) + out : tuple of list of array Connectivity estimates for each signal pair, method, and frequency or - frequency band. + frequency band. If bivariate methods are called, the output is a tuple + of a list of arrays containing the connectivity scores. If multivariate + methods are called, the output is a tuple of lists containing arrays + for the connectivity scores and patterns, respectively. 
""" if 'coh' in method: # psd @@ -564,18 +746,22 @@ def _parallel_con(w, method, kernel, foi_idx, source_idx, target_idx, n_jobs, else: psd = None - # only show progress if verbosity level is DEBUG - if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10: - total = None + if not multivariate_con: + # only show progress if verbosity level is DEBUG + if verbose != 'DEBUG' and verbose != 'debug' and verbose != 10: + total = None + + # define the function to compute in parallel + parallel, my_pairwise_con, n_jobs = parallel_func( + _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) - # define the function to compute in parallel - parallel, my_pairwise_con, n_jobs = parallel_func( - _pairwise_con, n_jobs=n_jobs, verbose=verbose, total=total) + return tuple(parallel( + my_pairwise_con(w, psd, s, t, method, kernel, foi_idx, faverage, + weights) for s, t in zip(source_idx, target_idx))) - return parallel( - my_pairwise_con(w, psd, s, t, method, kernel, - foi_idx, faverage, weights) - for s, t in zip(source_idx, target_idx)) + return _multivariate_con(w, source_idx, target_idx, signals_use, method, + kernel, foi_idx, faverage, weights, gc_n_lags, + rank, n_jobs) def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, @@ -639,6 +825,96 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, return out +def _multivariate_con(w, source_idx, target_idx, signals_use, method, kernel, + foi_idx, faverage, weights, gc_n_lags, rank, n_jobs): + """Compute spectral connectivity metrics between multiple signals. + + Parameters + ---------- + w : array_like, shape (n_chans, n_tapers, n_freqs, n_times) + Time-frequency data. + x : int + Channel index. + y : int + Channel index. + method : str + Connectivity method. + kernel : array_like, shape (n_sm_fres, n_sm_times) + Smoothing kernel. + foi_idx : array_like, shape (n_foi, 2) + Upper and lower bound indices of frequency bands. + faverage : bool + Average over frequency bands. 
+ weights : array_like, shape (n_tapers, n_freqs, n_times) | None + Multitaper weights. + + Returns + ------- + scores : list + List of connectivity scores between seed and target signals for each + connectivity method. Each element is an array with shape (n_freqs,) or + (n_fbands) depending on ``faverage``. + + patterns : list + List of connectivity patterns between seed and target signals for each + connectivity method. Each element is an array of length 2 corresponding + to the seed and target patterns, respectively, each with shape + (n_channels, n_freqs,) or (n_channels, n_fbands) depending on + ``faverage``. + """ + csd = [] + for x in signals_use: + for y in signals_use: + w_x, w_y = w[x], w[y] + if weights is not None: + s_xy = np.sum(weights * w_x * np.conj(weights * w_y), axis=0) + s_xy = s_xy * 2 / (weights * np.conj(weights)).real.sum(axis=0) + else: + s_xy = w_x * np.conj(w_y) + s_xy = np.squeeze(s_xy, axis=0) + csd.append(_smooth_spectra(s_xy, kernel).mean(axis=-1)) + csd = np.array(csd) + + # initialise connectivity estimators and add CSD information + conn_class = {'mic': _MICEst, 'mim': _MIMEst, 'gc': _GCEst, + 'gc_tr': _GCTREst} + conn = [] + for m in method: + # N_CONS = 1 UNTIL RAGGED ARRAYS SUPPORTED + call_params = {'n_signals': len(signals_use), 'n_cons': 1, + 'n_freqs': csd.shape[1], 'n_times': 0, + 'n_jobs': n_jobs} + if m in _gc_methods: + call_params['n_lags'] = gc_n_lags + con_est = conn_class[m](**call_params) + for con_i, con_csd in enumerate(csd): + con_est.accumulate(con_i, con_csd) + conn.append(con_est) + + # compute connectivity + scores = [] + patterns = [] + for con_est in conn: + con_est.compute_con(np.array([source_idx, target_idx]), rank) + scores.append(con_est.con_scores[..., np.newaxis]) + patterns.append(con_est.patterns) + if patterns[-1] is not None: + patterns[-1] = patterns[-1][..., np.newaxis] + + for i, _ in enumerate(scores): + # mean inside frequency sliding window (if needed) + if isinstance(foi_idx, np.ndarray) 
and faverage: + scores[i] = _foi_average(scores[i], foi_idx) + if patterns[i] is not None: + patterns[i] = _foi_average(patterns[i], foi_idx) + # squeeze time dimension + scores[i] = scores[i].squeeze(axis=-1) + if patterns[i] is not None: + patterns[i] = patterns[i].squeeze(axis=-1) + + return scores, patterns + + def _plv(s_xy): """Compute phase-locking value given the cross power spectral density.