Skip to content

Commit

Permalink
[GSOC] Add support for multiple components of multivariate connectivi…
Browse files Browse the repository at this point in the history
…ty (#213)

Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
  • Loading branch information
tsbinns and larsoner authored Jul 24, 2024
1 parent 43542ef commit 2eb8f02
Show file tree
Hide file tree
Showing 12 changed files with 1,001 additions and 501 deletions.
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"epochs",
"freqs",
"times",
"components",
"arrays",
"lists",
"func",
Expand Down
4 changes: 3 additions & 1 deletion examples/decoding/cohy_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@
mode="multitaper",
fmin=FMIN,
fmax=FMAX,
rank=(3, 3),
rank=(3, 3), # project to rank subspace to avoid overfitting to noise
n_components=1, # the data contains only one simulated component of connectivity
)

########################################################################################
Expand Down Expand Up @@ -371,6 +372,7 @@
fmin=FMIN,
fmax=FMAX,
rank=(3, 3),
n_components=1,
)

# Time fitting of filters
Expand Down
10 changes: 6 additions & 4 deletions examples/decoding/cohy_decomposition_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@
# bound between :math:`[-1, 1]`.
#
# Plotting the patterns for 20-30 Hz connectivity below, we find the strongest
# connectivity between the left and right hemispheres comes from centromedial left and
# frontolateral right sensors, based on the areas with the largest absolute values. As
# these patterns come from decomposition on sensor-space data, we make no assumptions
# about the underlying brain regions involved in this connectivity.
# connectivity ('MIC0', i.e. 1st component) between the left and right hemispheres comes
# from centromedial left and frontolateral right sensors, based on the areas with the
# largest absolute values. Patterns for the weaker connectivity components ('MIC1' &
# 'MIC2' are also shown). As these patterns come from decomposition on sensor-space
# data, we make no assumptions about the underlying brain regions involved in this
# connectivity.

# %%

Expand Down
111 changes: 66 additions & 45 deletions mne_connectivity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ def __repr__(self) -> str:
if "times" in self.dims:
r += f"time : [{self.times[0]}, {self.times[-1]}], " # type: ignore
r += f", nave : {self.n_epochs_used}"
r += f", nodes, n_estimated : {self.n_nodes}, " f"{self.n_estimated_nodes}"
r += f", nodes, n_estimated : {self.n_nodes}, {self.n_estimated_nodes}"
if "components" in self.dims:
r += f", n_components : {len(self.coords['components'])}, "
r += f", ~{sizeof_fmt(self._size)}"
r += ">"
return r
Expand Down Expand Up @@ -488,6 +490,9 @@ def _prepare_xarray(
if self.is_epoched:
coords["epochs"] = list(map(str, range(data.shape[0])))
coords["node_in -> node_out"] = n_estimated_list
if "components" in kwargs:
coords["components"] = kwargs.pop("components")
dims.append("components")
if "freqs" in kwargs:
coords["freqs"] = kwargs.pop("freqs")
dims.append("freqs")
Expand Down Expand Up @@ -531,17 +536,17 @@ def _check_data_consistency(self, data, indices, n_nodes):
raise TypeError("Connectivity data must be passed in as a numpy array.")

if self.is_epoched:
if data.ndim < 2 or data.ndim > 4:
if data.ndim < 2 or data.ndim > 5:
raise RuntimeError(
"Data using an epoched data structure should have at least 2 "
f"dimensions and at most 4 dimensions. Your data was {data.shape} "
f"dimensions and at most 5 dimensions. Your data was {data.shape} "
"shape."
)
else:
if data.ndim > 3:
if data.ndim > 4:
raise RuntimeError(
"Data not using an epoched data structure should have at least 1 "
f"dimensions and at most 3 dimensions. Your data was {data.shape} "
f"dimensions and at most 4 dimensions. Your data was {data.shape} "
"shape."
)

Expand Down Expand Up @@ -709,11 +714,15 @@ def get_data(self, output="compact"):
if output == "raveled":
data = self._data
else:
if self.method in ["cacoh", "mic", "mim", "gc", "gc_tr"]:
# multivariate results cannot be returned in a dense form as a
# single set of results would correspond to multiple entries in
# the matrix, and there could also be cases where multiple
# results correspond to the same entries in the matrix.
if (
isinstance(self.indices, tuple)
and not isinstance(self.indices[0], int)
and not isinstance(self.indices[1], int)
): # i.e. check if multivariate results based on nested indices
# multivariate results cannot be returned in a dense form as a single
# set of results would correspond to multiple entries in the matrix, and
# there could also be cases where multiple results correspond to the
# same entries in the matrix.
raise ValueError(
"cannot return multivariate connectivity data in a dense form"
)
Expand All @@ -728,6 +737,8 @@ def get_data(self, output="compact"):
# and thus appends the connectivity matrices side by side, so the
# shape is N x N * lags
new_shape.extend([self.n_nodes, self.n_nodes])
if "components" in self.dims:
new_shape.append(len(self.coords["components"]))
if "freqs" in self.dims:
new_shape.append(len(self.coords["freqs"]))
if "times" in self.dims:
Expand Down Expand Up @@ -870,9 +881,10 @@ def save(self, fname):
class SpectralConnectivity(BaseConnectivity, SpectralMixin):
"""Spectral connectivity class.
This class stores connectivity data that varies over
frequencies. The underlying data is an array of shape
(n_connections, n_freqs), or (n_nodes, n_nodes, n_freqs).
This class stores connectivity data that varies over frequencies. The underlying
data is an array of shape (n_connections, [n_components], n_freqs), or (n_nodes,
n_nodes, [n_components], n_freqs). ``n_components`` is an optional dimension for
multivariate methods where each connection has multiple components of connectivity.
Parameters
----------
Expand Down Expand Up @@ -924,11 +936,12 @@ def __init__(
class TemporalConnectivity(BaseConnectivity, TimeMixin):
"""Temporal connectivity class.
This is an array of shape (n_connections, n_times),
or (n_nodes, n_nodes, n_times). This describes how connectivity
varies over time. It describes sample-by-sample time-varying
connectivity (usually on the order of milliseconds). Here
time (t=0) is the same for all connectivity measures.
This is an array of shape (n_connections, [n_components], n_times), or (n_nodes,
n_nodes, [n_components], n_times). This describes how connectivity varies over
time. It describes sample-by-sample time-varying connectivity (usually on the order
of milliseconds). Here time (t=0) is the same for all connectivity measures.
``n_components`` is an optional dimension for multivariate methods where each
connection has multiple components of connectivity.
Parameters
----------
Expand All @@ -943,12 +956,11 @@ class TemporalConnectivity(BaseConnectivity, TimeMixin):
Notes
-----
`mne_connectivity.EpochConnectivity` is a similar connectivity
class to this one. However, that describes one connectivity snapshot
for each epoch. These epochs might be chunks of time that have
different meaning for time ``t=0``. Epochs can mean separate trials,
where the beginning of the trial implies t=0. These Epochs may
also be discontiguous.
`mne_connectivity.EpochConnectivity` is a similar connectivity class to this one.
However, that describes one connectivity snapshot for each epoch. These epochs might
be chunks of time that have different meaning for time ``t=0``. Epochs can mean
separate trials, where the beginning of the trial implies t=0. These Epochs may also
be discontiguous.
"""

expected_n_dim = 2
Expand Down Expand Up @@ -980,13 +992,14 @@ def __init__(
class SpectroTemporalConnectivity(BaseConnectivity, SpectralMixin, TimeMixin):
"""Spectrotemporal connectivity class.
This class stores connectivity data that varies over both frequency
and time. The temporal part describes sample-by-sample time-varying
connectivity (usually on the order of milliseconds). Note the
difference relative to Epochs.
This class stores connectivity data that varies over both frequency and time. The
temporal part describes sample-by-sample time-varying connectivity (usually on the
order of milliseconds). Note the difference relative to Epochs.
The underlying data is an array of shape (n_connections, n_freqs,
n_times), or (n_nodes, n_nodes, n_freqs, n_times).
The underlying data is an array of shape (n_connections, [n_components], n_freqs,
n_times), or (n_nodes, n_nodes, [n_components], n_freqs, n_times). ``n_components``
is an optional dimension for multivariate methods where each connection has multiple
components of connectivity.
Parameters
----------
Expand Down Expand Up @@ -1038,9 +1051,11 @@ def __init__(
class EpochSpectralConnectivity(SpectralConnectivity):
"""Spectral connectivity class over Epochs.
This is an array of shape (n_epochs, n_connections, n_freqs),
or (n_epochs, n_nodes, n_nodes, n_freqs). This describes how
connectivity varies over frequencies for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components], n_freqs), or
(n_epochs, n_nodes, n_nodes, [n_components], n_freqs). This describes how
connectivity varies over frequencies for different epochs. ``n_components`` is an
optional dimension for multivariate methods where each connection has multiple
components of connectivity.
Parameters
----------
Expand Down Expand Up @@ -1088,9 +1103,11 @@ def __init__(
class EpochTemporalConnectivity(TemporalConnectivity):
"""Temporal connectivity class over Epochs.
This is an array of shape (n_epochs, n_connections, n_times),
or (n_epochs, n_nodes, n_nodes, n_times). This describes how
connectivity varies over time for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components], n_times), or
(n_epochs, n_nodes, n_nodes, [n_components], n_times). This describes how
connectivity varies over time for different epochs. ``n_components`` is an optional
dimension for multivariate methods where each connection has multiple components of
connectivity.
Parameters
----------
Expand Down Expand Up @@ -1129,9 +1146,11 @@ def __init__(
class EpochSpectroTemporalConnectivity(SpectroTemporalConnectivity):
"""Spectrotemporal connectivity class over Epochs.
This is an array of shape (n_epochs, n_connections, n_freqs, n_times),
or (n_epochs, n_nodes, n_nodes, n_freqs, n_times). This describes how
connectivity varies over frequencies and time for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components], n_freqs,
n_times), or (n_epochs, n_nodes, n_nodes, [n_components], n_freqs, n_times). This
describes how connectivity varies over frequencies and time for different epochs.
``n_components`` is an optional dimension for multivariate methods where each
connection has multiple components of connectivity.
Parameters
----------
Expand Down Expand Up @@ -1178,9 +1197,10 @@ def __init__(
class Connectivity(BaseConnectivity):
"""Connectivity class without frequency or time component.
This is an array of shape (n_connections,),
or (n_nodes, n_nodes). This describes a connectivity matrix/graph
that does not vary over time, frequency, or epochs.
This is an array of shape (n_connections, [n_components]), or (n_nodes, n_nodes,
[n_components]). This describes a connectivity matrix/graph that does not vary
over time, frequency, or epochs. ``n_components`` is an optional dimension for
multivariate methods where each connection has multiple components of connectivity.
Parameters
----------
Expand Down Expand Up @@ -1222,9 +1242,10 @@ def __init__(
class EpochConnectivity(BaseConnectivity):
"""Epoch connectivity class.
This is an array of shape (n_epochs, n_connections),
or (n_epochs, n_nodes, n_nodes). This describes how
connectivity varies for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components]), or (n_epochs,
n_nodes, n_nodes, [n_components]). This describes how connectivity varies for
different epochs. ``n_components`` is an optional dimension for multivariate methods
where each connection has multiple components of connectivity.
Parameters
----------
Expand Down
Loading

0 comments on commit 2eb8f02

Please sign in to comment.