Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Phase-slope index using spectral_connectivity_time instead of spectral_connectivity_epochs #210

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
196 changes: 194 additions & 2 deletions mne_connectivity/effective.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
import numpy as np
from mne.utils import logger, verbose

from .base import SpectralConnectivity, SpectroTemporalConnectivity
from .spectral import spectral_connectivity_epochs
from .base import (
EpochSpectralConnectivity,
SpectralConnectivity,
SpectroTemporalConnectivity,
)
seqasim marked this conversation as resolved.
Show resolved Hide resolved
from .spectral import spectral_connectivity_epochs, spectral_connectivity_time
from .utils import fill_doc


Expand Down Expand Up @@ -243,3 +247,191 @@ def phase_slope_index(
)

return conn


@verbose
@fill_doc
def phase_slope_index_time(data,
names=None,
indices=None,
sfreq=2 * np.pi,
mode="multitaper",
fmin=None,
fmax=np.inf,
mt_bandwidth=None,
freqs=None,
n_cycles=7,
padding=0,
n_jobs=1,
):
"""Compute the Phase Slope Index (PSI) connectivity measure across time rather than epochs.
seqasim marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would think the default should be "cwt_morlet", like in spectral_connectivity_time().

seqasim marked this conversation as resolved.
Show resolved Hide resolved

The PSI is an effective connectivity measure, i.e., a measure which can
give an indication of the direction of the information flow (causality).
For two time series, and one computes the PSI between the first and the
second time series as follows

indices = (np.array([0]), np.array([1]))
psi = phase_slope_index(data, indices=indices, ...)

A positive value means that time series 0 is ahead of time series 1 and
a negative value means the opposite.

The PSI is computed from the coherency (see spectral_connectivity_epochs),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just make sure we're referencing the right function, and while doing that, let's make it a link in the docs (though I see it isn't already in phase_slope_index).

Suggested change
The PSI is computed from the coherency (see spectral_connectivity_epochs),
The PSI is computed from the coherency (see :class:`spectral_connectivity_time`),

details can be found in :footcite:`NolteEtAl2008`.

This function computes PSI over time from epoched data.
The data may consist of a single epoch.
seqasim marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
data : array-like, shape=(n_epochs, n_signals, n_times)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also include Epochs (again, I realise you took this from the existing function which doesn't have this, so should change there too).

Can also be a list/generator of array, shape =(n_signals, n_times);
list/generator of SourceEstimate; or Epochs.
The data from which to compute connectivity. Note that it is also
possible to combine multiple signals by providing a list of tuples,
e.g., data = [(arr_0, stc_0), (arr_1, stc_1), (arr_2, stc_2)],
corresponds to 3 epochs, and arr_* could be an array with the same
number of time points as stc_*.
%(names)s
indices : tuple of array | None
Two arrays with indices of connections for which to compute
connectivity. If None, all connections are computed.
sfreq : float
The sampling frequency.
mode : str
Spectrum estimation mode can be either: 'multitaper', 'fourier', or
'cwt_morlet'.
Comment on lines +303 to +305
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No Fourier mode for spec_conn_time, but can also make this a little more consistent with the description given there.

Suggested change
mode : str
Spectrum estimation mode can be either: 'multitaper', 'fourier', or
'cwt_morlet'.
mode : str
Time-frequency decomposition method. Can be either: 'multitaper', or
'cwt_morlet'.

fmin : float | tuple of float
The lower frequency of interest. Multiple bands are defined using
a tuple, e.g., (8., 20.) for two bands with 8Hz and 20Hz lower freq.
If None the frequency corresponding to an epoch length of 5 cycles
is used.
fmax : float | tuple of float
The upper frequency of interest. Multiple bands are dedined using
a tuple, e.g. (13., 30.) for two band with 13Hz and 30Hz upper freq.
mt_bandwidth : float | None
The bandwidth of the multitaper windowing function in Hz.
Only used in 'multitaper' mode.
freqs : array
Array of frequencies of interest. Only used in 'cwt_morlet' mode.
n_cycles : float | array of float
Number of cycles. Fixed number or one per frequency. Only used in
'cwt_morlet' mode.
n_jobs : int
How many epochs to process in parallel.
%(verbose)s

Returns
-------
conn : instance of Connectivity
seqasim marked this conversation as resolved.
Show resolved Hide resolved
Computed connectivity measure(s). ``EpochSpectralConnectivity``
container. The shape of each array is
(n_signals ** 2, n_bands, n_epochs)
when "indices" is None, or
(n_con, n_bands, n_epochs)
when "indices" is specified and "n_con = len(indices[0])".

See Also
--------
mne_connectivity.EpochSpectralConnectivity
Comment on lines +336 to +338
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it's good to reiterate that spectral_connectivity_time() is a relevant function. Though I realise you based this on the existing function docstring which doesn't link to spec_conn_epochs().

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this is a good idea.


References
----------
.. footbibliography::
"""
logger.info("Estimating phase slope index (PSI) across time")

# estimate the coherency

cohy = spectral_connectivity_time(
data,
freqs=freqs,
method="cohy",
average=False,
indices=indices,
sfreq=sfreq,
fmin=fmin,
fmax=fmax,
fskip=0,
faverage=False,
sm_times=0,
sm_freqs=1,
sm_kernel="hanning",
padding=padding,
Comment on lines +360 to +363
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we should expose more of these options to the user, e.g. why just padding and not the smoothing stuff?

mode=mode,
mt_bandwidth=mt_bandwidth,
n_cycles=n_cycles,
decim=1,
n_jobs=n_jobs,
verbose=None,
)

freqs_ = np.array(cohy.freqs)
names = cohy.names
n_tapers = cohy.attrs.get("n_tapers")
n_nodes = cohy.n_nodes
metadata = cohy.metadata
events = cohy.events
event_id = cohy.event_id

logger.info(f"Computing PSI from estimated Coherency: {cohy}")
# compute PSI in the requested bands
if fmin is None:
fmin = -np.inf # set it to -inf, so we can adjust it later

bands = list(zip(np.asarray((fmin,)).ravel(), np.asarray((fmax,)).ravel()))
n_bands = len(bands)

freq_dim = -2

# allocate space for output
out_shape = list(cohy.shape)
out_shape[freq_dim] = n_bands
psi = np.zeros(out_shape, dtype=np.float64)

# allocate accumulator
acc_shape = copy.copy(out_shape)
acc_shape.pop(freq_dim)
acc = np.empty(acc_shape, dtype=np.complex128)

# create list for frequencies used and frequency bands
# of resulting connectivity data
freqs = list()
freq_bands = list()
idx_fi = [slice(None)] * len(out_shape)
idx_fj = [slice(None)] * len(out_shape)
for band_idx, band in enumerate(bands):
freq_idx = np.where((freqs_ > band[0]) & (freqs_ < band[1]))[0]
freqs.append(freqs_[freq_idx])
freq_bands.append(np.mean(freqs_[freq_idx]))

acc.fill(0.0)
for fi, fj in zip(freq_idx, freq_idx[1:]):
idx_fi[freq_dim] = fi
idx_fj[freq_dim] = fj
acc += (
np.conj(cohy.get_data()[tuple(idx_fi)]) * cohy.get_data()[tuple(idx_fj)]
)

idx_fi[freq_dim] = band_idx
psi[tuple(idx_fi)] = np.imag(acc)
logger.info("[PSI Estimation Done]")

# create a connectivity container
conn = EpochSpectralConnectivity(
data=psi,
names=names,
freqs=freq_bands,
n_nodes=n_nodes,
method="phase-slope-index",
spec_method=mode,
indices=indices,
freqs_computed=freqs,
n_tapers=n_tapers,
metadata=metadata,
events=events,
event_id=event_id,
)

return conn
33 changes: 29 additions & 4 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def spectral_connectivity_time(
conn_patterns = dict()
for m in method:
# CaCoh complex-valued, all other methods real-valued
seqasim marked this conversation as resolved.
Show resolved Hide resolved
if m == "cacoh":
if m in ["cacoh", "cohy"]:
con_scores_dtype = np.complex128
else:
con_scores_dtype = np.float64
Expand Down Expand Up @@ -904,7 +904,7 @@ def _parallel_con(
methods are called, the output is a tuple of lists containing arrays
for the connectivity scores and patterns, respectively.
"""
if "coh" in method:
if ("coh" in method) or ("cohy" in method):
# psd
if weights is not None:
psd = weights * w
Expand Down Expand Up @@ -995,9 +995,9 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights):
s_xy = np.squeeze(s_xy, axis=0)
s_xy = _smooth_spectra(s_xy, kernel)
out = []
conn_func = {"plv": _plv, "ciplv": _ciplv, "pli": _pli, "wpli": _wpli, "coh": _coh}
conn_func = {"plv": _plv, "ciplv": _ciplv, "pli": _pli, "wpli": _wpli, "coh": _coh, "cohy": _cohy}
for m in method:
if m == "coh":
if m in ["coh", "cohy"]:
s_xx = psd[x]
s_yy = psd[y]
out.append(conn_func[m](s_xx, s_yy, s_xy))
Expand Down Expand Up @@ -1234,6 +1234,31 @@ def _coh(s_xx, s_yy, s_xy):
coh = con_num / con_den
return coh

def _cohy(s_xx, s_yy, s_xy):
"""Compute coherencey given the cross spectral density and PSD.

Parameters
----------
s_xx : array-like, shape (n_freqs, n_times)
The PSD of channel 'x'.
s_yy : array-like, shape (n_freqs, n_times)
The PSD of channel 'y'.
s_xy : array-like, shape (n_freqs, n_times)
The cross PSD between channel 'x' and channel 'y' across
frequency and time points.

Returns
-------
cohy : array-like, shape (n_freqs, n_times)
The estimated COHY.
"""
con_num = s_xy.mean(axis=-1, keepdims=True)
con_den = np.sqrt(
s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True)
)
cohy = con_num / con_den
return cohy
seqasim marked this conversation as resolved.
Show resolved Hide resolved


def _compute_csd(x, y, weights):
"""Compute cross spectral density between signals x and y."""
Expand Down
39 changes: 38 additions & 1 deletion mne_connectivity/tests/test_effective.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from numpy.testing import assert_array_almost_equal

from mne_connectivity.effective import phase_slope_index
from mne_connectivity.effective import phase_slope_index, phase_slope_index_time


def test_psi():
Expand Down Expand Up @@ -39,3 +39,40 @@ def test_psi():

assert np.all(conn_cwt.get_data() > 0)
assert conn_cwt.shape[-1] == n_times


def test_psi_time():
"""Test Phase Slope Index (PSI) estimation across time."""
Comment on lines +44 to +45
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a quick glance the test looks good to me, let's see if it passes once the verbose param is added to the new function.

sfreq = 50.0
n_signals = 3
n_epochs = 10
n_times = 500
rng = np.random.RandomState(42)
data = rng.randn(n_epochs, n_signals, n_times)

# simulate time shifts
for i in range(n_epochs):
data[i, 1, 10:] = data[i, 0, :-10] # signal 0 is ahead
data[i, 2, :-10] = data[i, 0, 10:] # signal 2 is ahead

conn = phase_slope_index_time(data, mode="fourier", sfreq=sfreq)

assert conn.get_data(output="dense")[1, 0, 0] < 0
assert conn.get_data(output="dense")[2, 0, 0] > 0

# only compute for a subset of the indices
indices = (np.array([0]), np.array([1]))
conn_2 = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, indices=indices)

# the measure is symmetric (sign flip)
assert_array_almost_equal(
conn_2.get_data()[0, 0], -conn.get_data(output="dense")[1, 0, 0]
)

freqs = np.arange(5.0, 20, 0.5)
conn_cwt = phase_slope_index_time(
data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices
)

assert np.all(conn_cwt.get_data() > 0)
assert conn_cwt.shape[-1] == n_epochs
Loading