XFit-style dipole fitting GUI
This PR adds a GUI for guided dipole modeling in the spirit of MEGIN's XFit program. It contains the base functionality needed to make the GUI useful enough to include in the next release of MNE-Python. The plan is to keep adding features in future PRs, along with some sorely needed speed improvements.
See #11977 for a list of currently supported features. This PR depends on #12071.
Minimal example:
import mne
data_path = mne.datasets.sample.data_path()
evoked = mne.read_evokeds(
f"{data_path}/MEG/sample/sample_audvis-ave.fif", condition="Left Auditory"
)
evoked.apply_baseline((-0.2, 0))
trans = mne.read_trans(f"{data_path}/MEG/sample/sample_audvis_raw-trans.fif")
cov = mne.read_cov(f"{data_path}/MEG/sample/sample_audvis-cov.fif")
bem = mne.read_bem_solution(f"{data_path}/subjects/sample/bem/sample-5120-5120-5120-bem-sol.fif")
subject = "sample"
subjects_dir = data_path / "subjects"
evoked.pick("grad")
# Quick and dirty. No MRI whatsoever.
g = mne.gui.DipoleFitUI(evoked)
# Slow and proper. With proper BEM model.
# g = mne.gui.DipoleFitUI(evoked, trans=trans, cov=cov, bem=bem, subject=subject, subjects_dir=subjects_dir)
Todo:
- [ ] Unit tests
- [ ] Proper documentation page
- [ ] Towncrier
Great to see progress here, let me know when it would be useful to get usability and/or code feedback!
The actual functionality is ready for testing. I'd love some feedback on usability, ideas for improvements, and overall thoughts.
Would it be worth discussing if the various GUIs should be moved to a separate package? I feel like things like dipole fitting don't have to live in the main MNE-Python package. Thoughts?
Interactive dipole fitting is perhaps the oldest source localization method and still quite widely used, so to me it's within scope to put it in MNE-Python itself. To me it's something we've always been missing / a deficiency of our software. I don't see too many advantages to splitting this one off unless @wmvanvliet needs to iterate and release faster than MNE-Python itself.
Would it be worth discussing if the various GUIs should be moved to a separate package?
It's certainly worth discussing. As @larsoner says, well-established source estimation techniques should be within scope for MNE-Python. It may at some point still be wise to split it off into its own package if we want to become more ambitious with its interface; the precedent for this is mne-qt-browser. As long as the functionality provided by this GUI works within the Figure3d class, all is good, and I think this will be the case for the foreseeable future. When we want access to the full Qt functionality, we should probably split it off.
@larsoner and everyone: could you help brainstorm a bit?
The functionality is now complete for a first release. Here is a test script I've been using to set up the GUI:
import mne
path = mne.datasets.sample.data_path()
meg_dir = path / "MEG" / "sample"
subjects_dir = path / "subjects"
bem = mne.read_bem_solution(subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif")
cov = mne.read_cov(meg_dir / "sample_audvis-cov.fif")
trans = mne.read_trans(meg_dir / "sample_audvis_raw-trans.fif")
evoked = mne.read_evokeds(meg_dir / "sample_audvis-ave.fif", condition="Left Auditory")
evoked = evoked.pick("meg").apply_baseline((-0.2, 0))
evoked = evoked.interpolate_bads()
cov = cov.pick_channels(evoked.ch_names)
inv = mne.minimum_norm.read_inverse_operator(meg_dir / "sample_audvis-meg-oct-6-meg-inv.fif")
stc = mne.minimum_norm.apply_inverse(evoked, inv)
g = mne.gui.dipolefit(evoked, cov=cov, stc=stc, bem=bem, trans=trans, subject="sample", subjects_dir=subjects_dir, n_jobs=-1)
It works great; however, fitting dipoles is painfully slow. I've identified some bottlenecks:
- While optimizing the dipole position, every iteration performs a check to test whether the candidate dipole position is inside the mesh (https://github.com/mne-tools/mne-python/blob/ace9381898b3ad9da9590a3d0b99abedaa74e033/mne/dipole.py#L1271), which is a very expensive operation.
- After fitting the dipoles, we need their timecourses, which is done through MNE. Every time we change something about the dipoles, the forward model needs to be updated accordingly. Right now we just call mne.make_forward_dipole from scratch, which takes a long time as it needs to perform all kinds of checks and set up everything.
Any ideas on how to improve speed? XFit is lightning fast, and I would like our implementation to at least approximate that.
For the interior check, we should create a single instance of _CheckInside and reuse it. That class uses some neat shortcuts to check interior-ness quickly. By either 1) not using it or 2) reconstructing it every time (not sure offhand which of these _points_outside_surface does!), we are doing things more slowly than if we reused a single instance. AFAIK XFit doesn't do any check like this (?), so if we really want, we could add a checkbox (Check dipole interior or whatever) to let people disable the check for faster fitting, which would also let you compare to XFit better. But it seems buggy to allow points to go outside the surface, so we probably do want to warn people that it's not a good idea, leave it checked by default, etc. A minimal sketch of the reuse pattern follows below.
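A minimal sketch of the reuse pattern, assuming the private _CheckInside from mne.surface is constructed from a surface dict and is callable on an (n, 3) array of points (signature and boolean return value assumed from the current private API):

import numpy as np
import mne
from mne.bem import _bem_find_surface
from mne.surface import _CheckInside

data_path = mne.datasets.sample.data_path()
bem = mne.read_bem_solution(
    data_path / "subjects" / "sample" / "bem" / "sample-5120-5120-5120-bem-sol.fif"
)
inner_skull = _bem_find_surface(bem, "inner_skull")
check_inside = _CheckInside(inner_skull)  # build the spatial lookup structures once
# ... then, inside the optimizer loop, each call is cheap:
candidates = inner_skull["rr"].mean(axis=0) + np.random.default_rng(0).normal(
    scale=0.01, size=(100, 3)
)
inside = check_inside(candidates)  # boolean mask (assumed), one entry per candidate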
For make_forward_dipole, we should refactor it into two stages: a prep stage and a computation stage. Something like this is already done in mne.simulation for different head positions, so hopefully the private functions already exist and the transition isn't too much work. We should see how close that gets us in terms of speed. Then we can see if there are additional steps that can be split into ones done once and ones recomputed each time. For example, I suspect some computations having to do with the BEM surfaces and sensors / volume currents could be computed once and reused, leaving just the parts having to do with individual dipoles / primary currents (src points) to be computed each time. This would probably be an even bigger gain, and IIRC it is not done by mne.simulation, so it would require a bit of code refactoring / private-function use again.
Just a final comment about the setup script above -- long-term we probably want to do something like we do with mne coreg, i.e. have mne dipole ... or similar where you can pass e.g. --cov <path> and have it launch a GUI with the cov set, etc. I would consider this a higher priority than the optimization stuff above unless the slowness is really a show-stopper for your use cases.
mne dipolefit command-line utility added
I've added a tutorial that showcases the features the GUI currently supports: https://output.circle-artifacts.com/output/job/2e331c1c-0641-42d8-8107-e4451230b6cd/artifacts/0/html/auto_tutorials/inverse/21_interactive_dipole_fit.html#sphx-glr-auto-tutorials-inverse-21-interactive-dipole-fit-py
@wmvanvliet awesome!
Do you want help with the optimization stuff?
I think we should improve the docs "scraper" so that when it's a dipole fitting GUI, it captures the whole UI window. That way users on the tutorial page get to see the UI elements, which I think will be clearer than just scraping the PyVistaQt scene / 3D plot. It shouldn't be too hard to do. You might be able to do it just by tweaking _GUIScraper a bit -- though looking through our tutorials I'm not 100% sure that's working, since the mne coreg one does not show UI elements. Basically we'd just want to screenshot the Qt window, and Qt has a method for doing that; it's pretty simple at the end of the day IIRC. A sketch of the screenshot step follows below.
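A minimal sketch of that screenshot step, assuming the scraper can get a handle on the GUI's top-level QWidget (here via activeWindow(), which can return None in headless runs):

from qtpy.QtWidgets import QApplication

app = QApplication.instance()  # the GUI must already be running
window = app.activeWindow() if app is not None else None
if window is not None:
    pixmap = window.grab()  # QWidget.grab() renders the widget into a QPixmap
    pixmap.save("dipole_fit_gui.png")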
Do you want help with the optimization stuff?
I'd love some help with this! I've made a start in this branch: https://github.com/wmvanvliet/mne-python/tree/xfit-optim
The idea there was to create a DipoleFitter class that pre-computes as much as it can and then has a fit method that fits a dipole using the chosen sensors. It's not working correctly though, and I wonder whether this is the right approach after all.
I haven't even looked yet at the forward operator stuff, which I think may actually have a bigger impact.
Okay, just to do a little checking on where we should try to optimize, I wrote a ~70-line gist that computes a 366-channel M/EEG forward solution for 10 different source locations, either all at once (using a single src or a single dip to contain all 10) or in a loop (using 10 srcs or 10 dips). The results are:
Single _CheckInside: 0.1s
Loop _CheckInside: 0.6s
Single make_forward_solution: 9.7s
Loop make_forward_solution: 95.6s
Single make_forward_dipole: 9.3s
Loop make_forward_dipole: 95.9s
So while _CheckInside does have some overhead, it's small compared to the total forward computation time.
The difference between computing the forward for all points at once vs. in a loop is huge: a ~10x speedup for the 10 locations. This suggests that almost all the time is spent on computations that are independent of the actual dipoles whose forwards are being calculated, so I expect that splitting the fwd calculation into separate steps should yield a lot of speedup.
I would try something like a cf = ComputeForward(info, bem, trans) class with a __call__ so you can do sol_data = cf(rr). This should be slow to instantiate (setting up everything that can be set up independently of the source locations) but quite fast to call (which uses the source locations). I don't think it should take too much work to split the code like this. I'll try to look next week!
Okay, a little proof of concept, _ForwardModeler:
from copy import deepcopy
import numpy as np
import mne
import time
from mne.bem import _bem_find_surface
from mne.forward._make_forward import _prepare_for_forward, _to_forward_dict, _merge_fwds, _FWD_ORDER
from mne.forward._compute_forward import _compute_forwards_meeg, _prep_field_computation
from mne.source_space._source_space import _ensure_src, _filter_source_spaces
from mne.surface import transform_surface_to
from mne.transforms import _get_trans
from mne.utils import verbose
data_path = mne.datasets.sample.data_path()
subjects_dir = data_path / 'subjects'
n_dips = 10
radius = 0.05
rr = np.random.default_rng(0).normal(size=(n_dips, 3))
rr /= np.linalg.norm(rr, axis=1)[:, np.newaxis] / radius * 0.9
nn = np.cross([1, 0, 0], rr)
nn /= np.linalg.norm(nn, axis=1)[:, np.newaxis]
assert nn.shape == rr.shape
bem = mne.read_bem_solution(data_path / "subjects" / "sample" / "bem" / "sample-5120-5120-5120-bem-sol.fif")
assert bem["surfs"][-1]["id"] == mne.io.constants.FIFF.FIFFV_BEM_SURF_ID_BRAIN
rr += bem["surfs"][-1]["rr"].mean(axis=0) # center the points around the BEM center
evoked = mne.read_evokeds(data_path / "MEG" / "sample" / "sample_audvis-ave.fif", condition=0)
evoked.crop(tmin=0.1, tmax=0.1)
# %%
# make_forward_solution
trans = mne.read_trans(data_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif")
t0 = time.time()
src = mne.setup_volume_source_space("sample", pos=dict(rr=rr, nn=nn))
fwd = mne.make_forward_solution(evoked.info, trans, src, bem)
print(f"Single make_forward_solution: {time.time() - t0:0.1f}s")
class _ForwardModeler:
    @verbose
    def __init__(self, info, trans, bem, *, meg=True, eeg=True, mindist=0.0,
                 n_jobs=1, ignore_ref=False, verbose=None):
        self.mri_head_t, _ = _get_trans(trans)
        self.mindist = mindist
        self.n_jobs = n_jobs
        # TODO: Make `src` optional in _prepare_for_forward
        src = mne.setup_volume_source_space(
            "", pos=dict(rr=np.zeros((1, 3)), nn=np.array([[0, 0, 1]])),
            verbose="error",
        )
        # Store the prepared BEM so compute() doesn't rely on a global
        self.sensors, _, _, _, self.bem = _prepare_for_forward(
            src,
            self.mri_head_t,
            info,
            bem,
            mindist,
            n_jobs,
            bem_extra="",
            trans="",
            info_extra="",
            meg=meg,
            eeg=eeg,
            ignore_ref=ignore_ref,
        )
        # TODO: Remove rr from the args here, it's unused
        self.fwd_data = _prep_field_computation(
            [], sensors=self.sensors, bem=self.bem, n_jobs=n_jobs
        )

    def compute(self, src):
        src = _ensure_src(src).copy()
        for s in src:
            transform_surface_to(s, "head", self.mri_head_t)
        # TODO: We have a bug that we don't filter for spherical BEMs!
        if not self.bem["is_sphere"]:
            inner_skull = _bem_find_surface(self.bem, "inner_skull")
            # TODO: Allow reusing a _CheckInside object for speed as well
            _filter_source_spaces(
                inner_skull, self.mindist, self.mri_head_t, src, self.n_jobs
            )
        rr = np.concatenate([s["rr"][s["vertno"]] for s in src])
        if len(rr) < 1:
            raise RuntimeError(
                "No points left in source space after excluding "
                "points close to inner skull."
            )
        sensors = deepcopy(self.sensors)
        fwd_data = deepcopy(self.fwd_data)
        fwds = _compute_forwards_meeg(
            rr, sensors=sensors, fwd_data=fwd_data, n_jobs=self.n_jobs
        )
        fwds = {
            key: _to_forward_dict(fwds[key], sensors[key]["ch_names"])
            for key in _FWD_ORDER
            if key in fwds
        }
        fwd = _merge_fwds(fwds, verbose=False)
        del fwds
        return fwd
t0 = time.time()
fwds = list()
fm = _ForwardModeler(evoked.info, trans, bem)
for ri in range(n_dips):
this_src = mne.setup_volume_source_space("sample", pos=dict(rr=rr[[ri]], nn=nn[[ri]]))
# fwds.append(mne.make_forward_solution(evoked.info, trans, this_src, bem))
fwds.append(fm.compute(this_src))
print(f"Loop make_forward_solution: {time.time() - t0:0.1f}s")
print()
check = np.concatenate([f["sol"]["data"] for f in fwds], axis=1)
np.testing.assert_allclose(check, fwd["sol"]["data"])
Gives:
Single make_forward_solution: 9.4s
Loop make_forward_solution: 10.2s
and the remaining time loss is mostly from _CheckInside recomputation, which is easy enough to fix, too.
@wmvanvliet agree this is worthwhile? If so I'll open a hopefully quick PR to add this private class and reorganize the forward code a tiny bit so you can then rebase (or merge with main) this PR to make things much faster...
@larsoner I've integrated your code into the PR and also added _CheckInside to the dipole fitting algorithm. This makes a huge difference. The dipole fitting is fast enough now to be usable. Hooray!
For some reason, on some of the versions in the CI, the fitted dipole is in a slightly different location. I have no idea how to debug this.
Usually that suggests numerical issues are cropping up... a bad/incorrect rank is one common culprit.
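One way to start narrowing that down (assuming the fit whitens the data with the supplied noise covariance) is to compare the rank estimated per channel type across the environments; a mismatch there would point at the whitening step:

import mne

data_path = mne.datasets.sample.data_path()
evoked = mne.read_evokeds(
    data_path / "MEG" / "sample" / "sample_audvis-ave.fif", condition="Left Auditory"
)
cov = mne.read_cov(data_path / "MEG" / "sample" / "sample_audvis-cov.fif")
# Print the rank estimated for each channel type; values that differ
# between CI environments would explain slightly shifted dipole fits.
print(mne.compute_rank(cov, info=evoked.info))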
@wmvanvliet I've been meaning to set this up for a while... you could set up remote SSH access to the GHA runners following the "ssh" steps Spyder uses here:
https://github.com/spyder-ide/spyder/blob/595113cf37b4c6d07a60e1b7a78bb5ae37ec03cd/.github/workflows/test-linux.yml#L40
That way you can at least investigate what's going on remotely.
@wmvanvliet looking at this PR, we are at +1,653 −69 with 68.92% of the diff covered and red CIs. I worry a bit about review + testing getting harder and harder as more features are added... would now be a good spot to get the PR into a mergeable state, and then iterate in smaller PRs after that? (Maybe a few WIP features need to be removed for now, which would be okay.)
Absolutely, this is what I was planning to do. The current feature set is the minimum needed to be useful. It needs more unit test coverage, though.