mne-connectivity
[GSOC] implement example of state-space model for connectivity
PR Description
Google Summer of Code (2022) project
Closes #99
WIP: a Linear Dynamical System (a state-space model that uses the EM algorithm to estimate autoregressive coefficients) to infer functional connectivity, interpreting the autoregressive coefficients as connectivity strength. The model takes M/EEG data as input and outputs time-varying autoregressive coefficients for source-space labels.
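To make "autoregressive coefficients as connectivity strength" concrete: in a first-order vector autoregressive model x_t = A x_{t-1} + noise, entry A[i, j] weights how much region j's past activity predicts region i's present activity, and the diagonal holds each region's own autoregressive term. The sketch below fits a single constant A by ordinary least squares on simulated data. It is not the EM-fitted, time-varying estimator in this PR, and the A[i, j] (j → i) index convention is an assumption of this sketch, not necessarily the one used in the PR's plots.

```python
import numpy as np


def fit_var1_least_squares(x):
    """Estimate A in x[:, t] = A @ x[:, t - 1] + noise by ordinary least squares.

    x : array, shape (n_regions, n_times)
    """
    past = x[:, :-1]     # x_{t-1}
    present = x[:, 1:]   # x_t
    # present ≈ A @ past  is equivalent to  past.T @ A.T ≈ present.T
    A_T, *_ = np.linalg.lstsq(past.T, present.T, rcond=None)
    return A_T.T


# Simulate two regions where region 0 drives region 1 (A_true[1, 0] != 0)
rng = np.random.default_rng(0)
A_true = np.array([[0.9, 0.0],
                   [0.5, 0.8]])
n_times = 5000
x = np.zeros((2, n_times))
for t in range(1, n_times):
    x[:, t] = A_true @ x[:, t - 1] + rng.standard_normal(2)

A_hat = fit_var1_least_squares(x)
print(np.round(A_hat, 2))  # close to A_true; A_hat[0, 1] stays near 0
```

Least squares on a fixed A is the simplest estimator; a state-space formulation additionally separates measurement noise from the latent dynamics and lets A vary over time.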
Completed during GSoC
- [x] A user-friendly API that allows the user to work easily with MEG and/or EEG data, following MNE-Python's standards and conventions for usability as much as possible:
- Most of the code complexity is hidden from the user in the backend for the simplest interface:
```python
import mne

data_path = mne.datasets.sample.data_path()
sample_folder = data_path / 'MEG' / 'sample'
raw_fname = sample_folder / 'sample_audvis_raw.fif'  # assumed; the raw file is not specified
raw = mne.io.read_raw_fif(raw_fname).crop(tmax=60)
events = mne.find_events(raw, ...)
event_dict = {...}
epochs = mne.Epochs(raw, events, tmin=-0.2, tmax=0.7, event_id=event_dict,
                    preload=True).pick_types(meg=True, eeg=True)

fwd_fname = sample_folder / '....'
fwd = mne.read_forward_solution(fwd_fname)
cov_fname = sample_folder / 'sample_audvis-cov.fif'
cov = mne.read_cov(cov_fname)

# Regions of interest (ROIs) shipped with the sample dataset
label_names = ['Aud-lh', 'Aud-rh', 'Vis-lh', 'Vis-rh']
labels = [mne.read_label(sample_folder / 'labels' / f'{label}.label',
                         subject='sample') for label in label_names]

model = LDS(lam0=0, lam1=100)
# `condition` is not defined in this snippet
model.add_subject('sample', condition, epochs, labels, fwd, cov)
model.fit()
At = model.A  # time-varying AR coefficients
assert At.shape == (len(labels), len(labels), len(epochs.times))
```
- [x] Preprocess the data into a format meant to increase the SNR
- [x] Downsample the data for faster processing.
- [x] Utilizes forward and covariance matrices in the API, as well as the labels for the regions of interest (ROIs) of the dataset
- [x] PCA is used to reduce the dimensionality of the data
- [x] `model.fit` uses the Expectation Maximization (EM) algorithm to fit the autoregressive coefficients of the state-space model, mapping the sensor data to each ROI and computing the connectivity strength between ROI pairs
- [x] Plotting of the time-varying coefficients in a matrix format to observe the strength of connection between each pair of ROIs
- [x] An example script showing how to use the function and interpret its outputs using the MNE-Python `sample` dataset
- [x] Basic unit tests have been written and partially incorporated
Check out this link to see my weekly progress. All of the code in this PR is new to MNE-Python's repositories.
Todo after GSoC
- [ ] Finish unit tests
- [ ] Finish replacing redundant functionality with MNE-Python equivalents (e.g., `scale_data` with `compute_whitener` and dot products)
- [ ] The current implementation uses `autograd` as a dependency, which is no longer actively developed (but still maintained). The code should be updated to incorporate JAX, which includes the features of `autograd` and is being actively developed (`autograd_linalg` should be replaced by `scipy.linalg`).
Merge checklist
Maintainer, please confirm the following before merging:
- [ ] All comments resolved
- [ ] This is not your own PR
- [ ] All CIs are happy
- [ ] PR title starts with [MRG]
- [ ] whats_new.rst is updated
- [ ] PR description includes phrase "closes <#issue-number>"
Hello! Just wanted to say hi to include myself in the loop 😄. Could someone quickly explain what the aim of this PR is? Is it to add VAR-model-based connectivity estimates? "State-space" sounds like it's implemented as a Kalman filter. At the risk of sounding repetitive, could we look at https://github.com/scot-dev/scot to see if anything could be reused here? I've implemented least-squares VAR estimation (optionally with regularization) to compute several popular (directed) connectivity measures (for a list see https://scot-dev.github.io/scot-doc/api/scot/scot.html#module-scot.connectivity).
Hi @cbrnr, you are correct that this PR aims to implement a Kalman filter using an AR model to measure connectivity. Reviewing SCoT is still on my to-do list, thanks for the reminder.
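For readers new to the state-space framing, here is a minimal sketch of the forward (filtering) pass of a Kalman filter for a linear-Gaussian model x_t = A x_{t-1} + w_t, y_t = C x_t + v_t with known parameters. In EM for such models, the E-step typically runs this filter followed by a backward Rauch-Tung-Striebel smoother, and the M-step re-estimates A and the noise covariances; both are omitted here, as are the PR's time-varying A_t and its regularization (lam0, lam1). Treating C as a stand-in for a mapping from ROI sources to sensors is an assumption of this toy example.

```python
import numpy as np


def kalman_filter(y, A, C, Q, R, x0, P0):
    """Forward Kalman filter for x_t = A x_{t-1} + w_t,  y_t = C x_t + v_t.

    w_t ~ N(0, Q), v_t ~ N(0, R); x0, P0 describe the state just before t = 0.
    y : (n_times, n_sensors)
    Returns filtered means (n_times, n_states) and covariances
    (n_times, n_states, n_states).
    """
    n_times, n_states = y.shape[0], A.shape[0]
    x_filt = np.zeros((n_times, n_states))
    P_filt = np.zeros((n_times, n_states, n_states))
    x_prev, P_prev = x0, P0
    for t in range(n_times):
        # Predict: propagate the previous estimate through the AR dynamics
        x_pred = A @ x_prev
        P_pred = A @ P_prev @ A.T + Q
        # Update: correct the prediction with the sensor measurement y_t
        S = C @ P_pred @ C.T + R                # innovation covariance
        K = np.linalg.solve(S, C @ P_pred).T    # Kalman gain P_pred C^T S^{-1}
        x_prev = x_pred + K @ (y[t] - C @ x_pred)
        P_prev = P_pred - K @ C @ P_pred
        x_filt[t], P_filt[t] = x_prev, P_prev
    return x_filt, P_filt


# Toy data: 2 latent "ROI" sources seen through a random 4-sensor mixing matrix
rng = np.random.default_rng(0)
A = np.array([[0.9, 0.0], [0.5, 0.8]])
C = rng.standard_normal((4, 2))
Q, R = 0.1 * np.eye(2), np.eye(4)
n_times = 2000
x_true, y = np.zeros((n_times, 2)), np.zeros((n_times, 4))
x = np.zeros(2)
for t in range(n_times):
    x = A @ x + rng.multivariate_normal(np.zeros(2), Q)
    x_true[t] = x
    y[t] = C @ x + rng.multivariate_normal(np.zeros(4), R)

x_filt, P_filt = kalman_filter(y, A, C, Q, R, x0=np.zeros(2), P0=np.eye(2))
x_naive = y @ np.linalg.pinv(C).T  # per-sample inversion that ignores the dynamics
print(np.mean((x_filt - x_true) ** 2), np.mean((x_naive - x_true) ** 2))
```

Because the filter combines each new measurement with a prediction from the AR dynamics, its error is lower than inverting each time sample independently.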
From a quick chat with Jordan, here is what we fleshed out a bit based on my suggestion for the public API:
Internal implementation sketch and public API
```python
import mne
import numpy as np


# Internal code
class MEGLDS:
    def __init__(self, ...):
        ...
        self._subject_data = dict()

    def add_subject(self, subject, forward, cov, ...):
        # Store per-subject matrices derived from the forward model and covariance
        self._subject_data[subject] = dict()
        self._subject_data[subject]['G'] = something(forward, ...)
        self._subject_data[subject]['C'] = something_else(cov, ...)

    def fit(self):
        # Stack the per-subject matrices so all subjects are fit jointly
        Gs = np.array([val['G'] for val in self._subject_data.values()])
        ...


# User API should only have:
# - subjects_dir, then per-subject:
#   - subject
#   - Forward
#   - Covariance
#   - Epochs
#   - list of Label
data_path = mne.datasets.sample.data_path()
subjects_dir = data_path / 'subjects'
model = MEGLDS(lambda0, lambda1, ...)
forward = mne.read_forward_solution(...)
cov = mne.read_cov(...)
labels = list()
label_names = ('Aud.lh', 'Aud.rh', 'Visual.lh', 'Visual.rh')
for name in label_names:
    labels.append(mne.read_label(subjects_dir / 'sample' / 'labels' / name))
model.add_subject('sample', subjects_dir=subjects_dir, labels=labels,
                  forward=forward, cov=cov)
...
model.fit()
At = model.A
assert At.shape == (len(labels), len(labels), len(epochs.times))
```
EDIT: I resolved the conversations above where I talk about this, since I think it's general enough to discuss in this main thread rather than inline.
examples/state_space_connectivity.py is functioning properly on my machine with the output depicted below. I think for a single subject, these results look good. Along the diagonal, values are close to 1, as expected for a computation similar to an autocorrelation. For the condition auditory/left there seems to be a connection from Aud-lh to Aud-rh as seen by the non-zero values in graph [0,1]. I expect measurements to be less noisy when running for a large number of subjects. Please run CI checks.
x-axis: time (seconds); y-axis: connectivity strength (autoregressive coefficients)

I am at a research conference for the next 7 days. When I return I have the following to-do:
- work to replace `scale_sensor_data` with `mne.make_ad_hoc_cov()` and `compute_whitener()`
- test for different conditions; see how results change
- use different dataset with multiple subjects for sanity check
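For context on the first item: `mne.make_ad_hoc_cov` builds a diagonal noise covariance, and `mne.cov.compute_whitener` derives a whitening matrix from a covariance. The numpy sketch below only illustrates the core idea (a matrix W with W @ cov @ W.T equal to the identity on the covariance's non-null subspace); MNE's function additionally handles rank estimation and channel-type scalings, so this is not a drop-in replacement for it or for `scale_sensor_data`. The helper name `whitener_from_cov` is hypothetical.

```python
import numpy as np


def whitener_from_cov(cov, rank_tol=1e-10):
    """Hypothetical helper: W such that W @ cov @ W.T is the identity
    on cov's non-null subspace.

    cov : (n_channels, n_channels) symmetric positive semi-definite matrix
    returns W : (rank, n_channels)
    """
    eigvals, eigvecs = np.linalg.eigh(cov)
    keep = eigvals > rank_tol * eigvals.max()  # drop numerically null directions
    # Scale each retained eigenvector by 1 / sqrt(eigenvalue)
    return (eigvecs[:, keep] / np.sqrt(eigvals[keep])).T


rng = np.random.default_rng(0)
# Correlated channels on very different scales
mixing = rng.standard_normal((3, 3)) * np.array([1.0, 10.0, 100.0])[:, None]
data = mixing @ rng.standard_normal((3, 10000))
W = whitener_from_cov(np.cov(data))
print(np.round(np.cov(W @ data), 6))  # identity, up to rounding

# A linearly dependent 4th channel makes the covariance rank 3, so W drops one row
data_def = np.vstack([data, data[0] + data[1]])
W_def = whitener_from_cov(np.cov(data_def))
print(W_def.shape)  # (3, 4)
```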
Looking forward to your feedback!
I am currently working to incorporate an old dataset to see if these scripts produce the expected results.
Hey @adam2392 one of the CI errors is due to autograd: ModuleNotFoundError: No module named 'autograd'. I'd really like to move to the next step of the proposal and work to integrate jax (to replace autograd) later on in order to keep pace with the summer milestones. Is it alright if I install autograd in the dependencies for mnedev (assuming that will fix the error) for now and work to integrate jax later on?
Yeah sure! You can include it as an optional Dev dependency. Do you know where to add? Please also add a comment if you can that we replace it later on.
Perhaps we can start another sep issue to track migration to jax later on?
@adam2392 Cool! I do not know where to add autograd as an optional Dev dependency. Should it be installed in mnedev?
Where should I leave the comment that it can be replaced with jax later on? Just start a new issue on Github with the suggestion?
You can add it here: https://github.com/mne-tools/mne-connectivity/blob/main/requirements_testing.txt and then just add a comment e.g.
```
...
# TODO: we will replace this with Jax
autograd
```
Then when you import autograd, you should maybe import it within the functions you want to use and not at the top of the Python file. That way there is no error when someone tries to run import mne_connectivity if they don't have autograd/jax installed. For example, this function inside MNE-Python needs pyqt, but imports it within the function so mne still works if the user only has numpy/scipy installed.
Where should I leave the comment that it can be replaced with jax later on? Just start a new issue on Github with the suggestion?
Yeah I'll create the GH issue.
Ok I think that's complete. Thanks!
Update: I got the code to work for an old dataset; however, the results are not what I expected. I am currently working to get the sample data (much smaller than the old dataset) to work with my original model. This output will be my new ground truth, and I can work iteratively to make sure that each edit I make to the original model to conform it to the MNE-Python API standards produces the same output. Lesson learned: have a simple ground truth to work with from the beginning :)
Here is the output of my original model using the sample dataset for the auditory/left condition for the 12 ROIs from a different dataset. My next step is to get this same output in the mnedev environment. Then I will work to use only the 4 ROIs commonly used with the sample dataset. Then piece by piece I will recreate the API.

Excellent!
A good next step is to make sure all random seeds can be set such that if you run this again you get the exact same output (to numerical precision at least). Then you can save the At result to disk now and compare to it each time you replace some piece of code
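A minimal sketch of that regression workflow, assuming the pipeline can accept a seed: `run_model` below is a hypothetical stand-in for the full bootstrap/PCA/EM pipeline, and the temporary directory replaces what would in practice be a reference file kept alongside the tests. If legacy code draws from the global `np.random` state, `np.random.seed(...)` must also be set; passing one explicit `Generator` through the code is easier to audit.

```python
import tempfile
from pathlib import Path

import numpy as np


def run_model(seed):
    """Hypothetical stand-in for the full pipeline (bootstrap, PCA, EM fit)."""
    rng = np.random.default_rng(seed)  # pass this generator wherever randomness is used
    return rng.standard_normal((4, 4, 10))  # pretend At: (n_labels, n_labels, n_times)


ref_path = Path(tempfile.mkdtemp()) / "At_reference.npy"

# 1) Once: save the current output as the reference ("ground truth")
np.save(ref_path, run_model(seed=42))

# 2) After each refactoring step: rerun with the same seed and compare
At_new = run_model(seed=42)
np.testing.assert_allclose(At_new, np.load(ref_path), rtol=1e-7, atol=1e-12)
```

A tolerance-based comparison rather than exact equality leaves room for harmless floating-point reordering (for example, a different BLAS) while still catching real behavioral changes.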
I have changed the PCA method from being based on the rank of the matrix to a method based on explaining 99% of the variance. This method allows the fitting of the model to run much faster as it produces 147 principal components vs 360 components produced from the rank method. The model output is noticeably but not extremely different. My next step is to perform the processing steps using the 4 labels provided in the sample dataset, which should reduce processing time even further. All processing was completed within the mnedev conda environment.
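Selecting components by cumulative explained variance rather than by matrix rank can be sketched as follows. Sensor noise makes data numerically full rank, so a rank-based cut keeps many near-noise components, while a 99% variance cut keeps only the dominant ones. Which matrix the PR decomposes and how it is centered are assumptions here, and `n_components_for_variance` is a hypothetical helper; scikit-learn's `PCA(n_components=0.99, svd_solver="full")` applies a very similar selection rule.

```python
import numpy as np


def n_components_for_variance(X, threshold=0.99):
    """Smallest number of principal components explaining `threshold` of the variance.

    X : array, shape (n_samples, n_features)
    """
    Xc = X - X.mean(axis=0)
    s = np.linalg.svd(Xc, compute_uv=False)  # singular values, descending
    explained = s ** 2 / np.sum(s ** 2)      # variance ratio per component
    return int(np.searchsorted(np.cumsum(explained), threshold) + 1)


# 3 strong latent factors mixed into 20 channels, plus a little sensor noise
rng = np.random.default_rng(0)
latent = rng.standard_normal((1000, 3)) * np.array([10.0, 5.0, 2.0])
mixing = rng.standard_normal((3, 20))
X = latent @ mixing + 0.01 * rng.standard_normal((1000, 20))

print(np.linalg.matrix_rank(X - X.mean(axis=0)))  # 20: noise makes it full rank
print(n_components_for_variance(X, 0.99))         # far fewer than 20
```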

Processing completed with 4 labels from sample.

Bootstrapping of epochs and PCA of the epochs.get_data() and forward matrices are completed in the API; model fitting is completed in the command-line model. The output from the API (LDS) is not identical to that of the command-line model (MEGLDS), but extremely similar.
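One common form of epoch bootstrapping, sketched here for readers: average random resamples (with replacement) of epochs so that each bootstrap average has a higher SNR than a single epoch. Whether the PR uses this exact scheme (number of resamples, epochs per average, with or without replacement) is an assumption, and `bootstrap_epoch_averages` is a hypothetical helper operating on an array shaped like the output of `epochs.get_data()`.

```python
import numpy as np


def bootstrap_epoch_averages(epoch_data, n_boot, n_per_avg, seed=0):
    """Hypothetical helper: average random resamples of epochs to raise SNR.

    epoch_data : (n_epochs, n_channels, n_times), e.g. from epochs.get_data()
    returns    : (n_boot, n_channels, n_times)
    """
    rng = np.random.default_rng(seed)
    n_epochs = epoch_data.shape[0]
    # Each row of idx picks n_per_avg epochs, sampled with replacement
    idx = rng.integers(0, n_epochs, size=(n_boot, n_per_avg))
    return epoch_data[idx].mean(axis=1)


rng = np.random.default_rng(1)
signal = np.sin(np.linspace(0, 2 * np.pi, 50))           # shared evoked response
epoch_data = signal + rng.standard_normal((100, 3, 50))  # 100 noisy epochs, 3 channels
boot = bootstrap_epoch_averages(epoch_data, n_boot=20, n_per_avg=25)
print(boot.shape)  # (20, 3, 50)
```

Averaging n epochs shrinks the noise standard deviation by roughly a factor of sqrt(n), at the cost of fewer (and partially overlapping) samples for model fitting.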

Nice!
It would be good to know what the differences are that make them not identical, but really if this version of the API works on our UW data as well (maybe even just for one subject?) then I'd say you could use this as the "ground truth" for correctness of additional changes!
@larsoner Can you look at megssm/mne_util.py L114? Am I using the scaler correctly? The results do not agree with the original _scale_sensor_data (L140). Thanks.
I'll get to the CI checks first thing Monday!
Hi @jadrew43 and @larsoner any help needed to review code / look at prelim results here? Feel free to lmk where I can help!
IIRC code is still WIP / needs to be systematically converted to MNE conventions, but some bugs have been found along the way which is good!