
[GSOC] Add a 'decoding' module to expand flexibility of multivariate methods #182

Closed
tsbinns opened this issue May 22, 2024 · 1 comment

Background

@larsoner @wmvanvliet @drammock

A key limitation of the existing implementation of multivariate connectivity methods in spectral_connectivity_epochs() and spectral_connectivity_time() is that it is incompatible with standard decoding paradigms, e.g. fitting filters to some training data and applying them to other data. Note that this applies only to those methods which use filters (CaCoh and MIC; not MIM or GC). Additionally, because filters would be fit once and then reused rather than recomputed every time, this approach would have a low computational cost, making real-time analyses (e.g. for BCIs) much more feasible.

Following the decoding module of MNE-Python, a decoding module could be created in MNE-Connectivity containing a class for each compatible connectivity method, i.e. a CaCoh class and a MIC class. These classes would support:

  • fitting filters to epoched data to optimise connectivity (fit() method), which would also store the corresponding patterns of connectivity.
  • transforming data based on the fitted filters (transform() method).
  • visualising the filters and patterns as topomaps (plot_filters() and plot_patterns() methods).

Both classes will have the same methods and parameters and will work very similarly internally, allowing a shared base class to handle most of the functionality. I envision the following structure (I will denote points I think are especially worth discussing with bold italics):


Class initialisation

Similar to the classes in MNE-Python's decoding module, parameters could include:

  • Info object containing key information about the epoched data that will be fit to/transformed.
  • indices specifying the channels to treat as seeds and targets.
  • fmin and fmax specifying the frequency band in which connectivity should be optimised.
    • as with the decoding.SSD class, I think fitting filters for only a single frequency band is best.
  • tmin and tmax specifying the time range to use when computing the cross-spectral density (CSD); defaults to the whole time period.
  • n_components specifying the number of filters to fit; defaults to as many components as the rank of the data being fit to (so the exact number can only be determined when the fit() method is called).
  • rank specifying any rank subspace projection to perform on the CSD for the seeds and targets; defaults to projecting to the rank of the data being fit to (so again the exact number must be determined later).
  • options for determining the CSD computation.
    • here there could be the option of using a Fourier or multitaper approach, but I am unsure about Morlet wavelets since I don't think it makes much sense in this context to have time-resolved filters (also not supported in the SSD class).

At this stage, the work would just be checking the validity of all input parameters, and much of the necessary checking code already exists within MNE-Connectivity and MNE-Python.
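As a rough illustration of these checks, the hypothetical helper below validates the frequency band, time range and seed/target indices. It takes sfreq and n_channels directly (in practice these would come from info['sfreq'] and info['nchan']) and assumes, for this sketch, that seeds and targets must not share channels.

```python
import numpy as np


def check_decoding_params(sfreq, n_channels, indices, fmin, fmax,
                          tmin=None, tmax=None):
    """Validate decoding parameters; returns seed and target index arrays."""
    seeds, targets = (np.asarray(idx, dtype=int) for idx in indices)
    if seeds.size == 0 or targets.size == 0:
        raise ValueError("indices must contain at least one seed and one target")
    for name, idx in (("seed", seeds), ("target", targets)):
        if np.unique(idx).size != idx.size:
            raise ValueError(f"{name} indices contain repeated channels")
    # Assumption: a channel cannot be both a seed and a target
    if np.intersect1d(seeds, targets).size > 0:
        raise ValueError("seed and target channels must not overlap")
    all_idx = np.concatenate((seeds, targets))
    if all_idx.min() < 0 or all_idx.max() >= n_channels:
        raise ValueError("indices refer to channels that are not in info")
    if not 0 <= fmin < fmax <= sfreq / 2:
        raise ValueError("need 0 <= fmin < fmax <= the Nyquist frequency")
    # tmin/tmax of None means 'use the whole time period'
    if tmin is not None and tmax is not None and tmin >= tmax:
        raise ValueError("tmin must be smaller than tmax")
    return seeds, targets
```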


fit method

Similar to the classes in MNE-Python's decoding module, parameters could include:

  • X the epochs array of data to fit filters to.
  • y, ignored and included only for scikit-learn compatibility; defaults to None.

Processing steps could include:

  • check validity of the data (i.e. is it a 3D array, does the number of channels match indices, etc.).
  • check any pre-specified rank against the data (if rank is left as the default None, the rank needs to be computed).
  • instantiate connectivity estimator class (e.g. _CaCohEst and _MICEst from epochs_multivariate.py).
  • compute CSD for frequencies between fmin and fmax.
    • big question is whether MNE-Python's dedicated classes are used, or the code that exists internally within MNE-Connectivity.
    • as it currently stands, the _epoch_spectral_connectivity() function in epochs.py, where CSD computation is handled, contains a lot of code that is superfluous for this purpose, so we would need to either refactor it to work for both the traditional and decoding approaches, or write a new function tailored to the decoding approach that repeats a lot of existing code.
    • perhaps, then, using the dedicated classes from MNE-Python would be simpler?
  • store CSD in the connectivity estimator class.
  • compute filters and patterns using the compute_con method of the estimator class.
    • patterns are already stored in this class, so we would just need to also store the filters (making sure they are converted from the rank subspace back to channel space if required, as we already do for patterns).
    • to save memory, we could add a flag (False by default) when instantiating the connectivity estimator classes that controls whether an array is allocated for the filters, and pass True only when using the decoding approach.
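To make the CSD step concrete, here is a minimal Fourier-based sketch in plain NumPy, assuming a single Hann taper over the whole (optionally cropped) epoch and time 0 at the first sample. It is not the MNE-Connectivity code path; MNE-Python's mne.time_frequency.csd_array_fourier() and csd_array_multitaper() are the dedicated alternatives raised above. The function name fourier_csd is hypothetical.

```python
import numpy as np


def fourier_csd(X, sfreq, fmin, fmax, tmin=None, tmax=None):
    """Hann-windowed Fourier CSD, averaged over epochs.

    X: (n_epochs, n_channels, n_times), assumed to start at time 0.
    Returns freqs within [fmin, fmax] and csd of shape (n_freqs, n_ch, n_ch).
    """
    X = np.asarray(X, dtype=float)
    times = np.arange(X.shape[-1]) / sfreq
    tmin = times[0] if tmin is None else tmin
    tmax = times[-1] if tmax is None else tmax
    X = X[..., (times >= tmin) & (times <= tmax)]  # crop to the time range
    n_times = X.shape[-1]
    spec = np.fft.rfft(X * np.hanning(n_times), axis=-1)
    freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
    in_band = (freqs >= fmin) & (freqs <= fmax)
    spec = spec[..., in_band]  # (n_epochs, n_channels, n_freqs)
    # csd[f, i, j] = mean over epochs of spec[i, f] * conj(spec[j, f])
    csd = np.einsum("eif,ejf->fij", spec, spec.conj()) / X.shape[0]
    return freqs[in_band], csd
```

The result would then be stored in the estimator class before computing filters.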

Technical point: what exactly do we fit filters for?

  • the CaCoh and MIC methods are able to optimise connectivity for individual frequency bins, but in this 'decoding' context, I do not think this is the best approach since fitting filters to all frequency bins within a given band is not very efficient.
  • instead, once the CSD has been computed for the frequency bins, this information can be combined into a single bin by summing over all bins in the band, leaving a single 'frequency' to which the filters are fit.
  • this is a perfectly reasonable approach if we consider the frequency band we are trying to optimise as a single entity capturing the information of interest, and follows the same philosophy as other eigendecomposition-based optimisation problems like SSD or CSP.
  • in addition to the low computational cost associated with only needing to optimise a 'single frequency', it also fits with the idea that this 'decoding' approach is used once you already have an idea about some frequency band of interest in your data e.g. following some pilot study (whereas the existing implementations in spectral_connectivity_epochs() and spectral_connectivity_time() can be used for a more exploratory analysis).

Ultimately I think treating the frequency band as a single entity and creating a single set of filters for this band collectively is the best approach and keeps things very consistent with the behaviour of SSD and CSP in MNE-Python's decoding module. Happy to discuss the technical details further if people have concerns about this.
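To illustrate fitting to the band as a single entity, the sketch below sums the CSD over all bins in the band and then fits MIC-style filters: the seeds and targets are whitened using the real part of their own CSDs, and an SVD of the imaginary part of the whitened seed-target cross-spectrum gives the filters. This is a simplified sketch: it does no rank projection, the Haufe-style patterns (real CSD times filters) are an assumption, and fit_mic_filters and _inv_sqrtm are hypothetical names, not the _MICEst internals.

```python
import numpy as np


def _inv_sqrtm(mat):
    """Inverse square root of a real symmetric positive-definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(mat)
    return (eigvecs / np.sqrt(eigvals)) @ eigvecs.T


def fit_mic_filters(csd, seeds, targets):
    """Fit MIC-style filters to a CSD treated as a single band.

    csd: complex array (n_freqs, n_channels, n_channels) for the bins in the band.
    Returns seed/target filters and patterns (columns = components) and the
    imaginary coherence of each component pair, in descending order.
    """
    band_csd = csd.sum(axis=0)  # collapse all bins into one 'frequency'
    c_ss = band_csd[np.ix_(seeds, seeds)].real
    c_tt = band_csd[np.ix_(targets, targets)].real
    c_st = band_csd[np.ix_(seeds, targets)]
    # Whiten seeds and targets using the real part of their own CSDs
    w_s = _inv_sqrtm(c_ss)
    w_t = _inv_sqrtm(c_tt)
    # SVD of the imaginary part of the whitened seed-target cross-spectrum
    u, s, vt = np.linalg.svd(w_s @ c_st.imag @ w_t, full_matrices=False)
    seed_filters = w_s @ u
    target_filters = w_t @ vt.T
    # Patterns: real CSD times filters (assumed Haufe-style convention)
    seed_patterns = c_ss @ seed_filters
    target_patterns = c_tt @ target_filters
    return seed_filters, target_filters, seed_patterns, target_patterns, s
```

Because the filters are real-valued, the imaginary coherence between the first seed and target components (with the CSD summed over the same bins) equals the first singular value, which makes a convenient check in unit tests.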


transform method

Similar to the classes in MNE-Python's decoding module, parameters could include:

  • X the epochs array of data to transform using the fitted filters.

The method would return epoched data with shape [epochs x components x times], which simply involves multiplying the data by the filters and returning up to n_components components, as specified at initialisation.
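A minimal sketch of the projection step, assuming filters are stored as an (n_picks, n_filters) array and that the helper is called once for the seed channels and once for the target channels; apply_filters is a hypothetical name.

```python
import numpy as np


def apply_filters(X, filters, picks, n_components=None):
    """Project the picked channels of epoched data onto fitted filters.

    X: (n_epochs, n_channels, n_times); filters: (n_picks, n_filters).
    Returns an array of shape (n_epochs, n_components, n_times).
    """
    X = np.asarray(X)
    if n_components is None:
        n_components = filters.shape[1]
    if n_components > filters.shape[1]:
        raise ValueError("n_components exceeds the number of fitted filters")
    # component[e, k, t] = sum_c filters[c, k] * X[e, picks[c], t]
    return np.einsum("ck,ect->ekt", filters[:, :n_components], X[:, picks])
```

For example, apply_filters(X, seed_filters, seed_indices) and apply_filters(X, target_filters, target_indices) would give the seed and target components separately.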


plot_filters and plot_patterns methods

The overall layout of these methods can generally follow that of the CSP class, with the difference that we need to handle plotting filters/patterns for both the seeds and targets.

My initial thought is to have separate plots returned for the seed components and target components.


Examples/tutorials

Examples for this new 'decoding' approach will be created demonstrating how the classes can be used to fit to/transform data. Other features could be:

  • showing how connectivity can be computed on the transformed data.
    • e.g. if you had transformed the data using filters fit with the CaCoh class, you could call spectral_connectivity_epochs() or spectral_connectivity_time() with the corresponding bivariate methods cohy or coh, or if you had fit filters with the MIC class you would use the imcoh method.
  • showing the speed increase of the 'decoding' approach vs. the existing implementations.
  • showing that the 'decoding' approach is a valid way of optimising connectivity.
    • e.g. you could fit filters to a set of epochs, apply the filters to that same set of epochs, and show that the connectivity computed from the transformed data is essentially equivalent to that from the existing implementations.
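Connectivity on the transformed data would normally be computed with spectral_connectivity_epochs() or spectral_connectivity_time(). To make the computation transparent, the hypothetical helper below computes coherency between one seed and one target component directly, with the CSD summed over the band as in the fitting step; the absolute value gives coh and the imaginary part gives imcoh. Since spectral_connectivity_epochs() computes connectivity per frequency bin, its values may differ from this band-summed version.

```python
import numpy as np


def band_coherency(seed_comp, target_comp, sfreq, fmin, fmax):
    """Coherency between two component time series, CSD summed over a band.

    seed_comp, target_comp: real arrays of shape (n_epochs, n_times).
    """
    freqs = np.fft.rfftfreq(seed_comp.shape[-1], d=1.0 / sfreq)
    band = (freqs >= fmin) & (freqs <= fmax)
    s = np.fft.rfft(seed_comp, axis=-1)[:, band]
    t = np.fft.rfft(target_comp, axis=-1)[:, band]
    # Epoch-averaged cross- and auto-spectra, summed over the band
    cross = np.mean(s * t.conj(), axis=0).sum()
    pow_s = np.mean(np.abs(s) ** 2, axis=0).sum()
    pow_t = np.mean(np.abs(t) ** 2, axis=0).sum()
    return cross / np.sqrt(pow_s * pow_t)
```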

Unit tests

Unit tests would be added for the new 'decoding' classes, covering many of the same points as the unit tests for the existing implementations.


Other important considerations

Currently, only the filters and patterns of the first component are used in the connectivity estimator classes. This will be addressed by another aim of the GSoC project, which will give users the option to return connectivity and filters/patterns for more than just the first component (see #183).


Timeline

This was estimated to be the largest part of the GSoC project, with time allocated until the start of July. However, additional time will also be needed once support for >1 component is added (as described above) to update the unit tests and examples.
