
XFit-style dipole fitting GUI

Open wmvanvliet opened this issue 10 months ago • 5 comments

This adds a GUI to perform guided dipole modeling in the spirit of MEGIN's XFit program. This PR contains the base functionality needed to make the GUI useful enough to include in the next release of MNE-Python. The plan is to keep adding features in future PRs, along with some sorely needed speed improvements.

See #11977 for a list of currently supported features. This PR depends on #12071.

[Screenshot: the dipole fitting GUI]

Minimal example:

import mne

data_path = mne.datasets.sample.data_path()
evoked = mne.read_evokeds(
    f"{data_path}/MEG/sample/sample_audvis-ave.fif", condition="Left Auditory"
)
evoked.apply_baseline((-0.2, 0))
trans = mne.read_trans(f"{data_path}/MEG/sample/sample_audvis_raw-trans.fif")
cov = mne.read_cov(f"{data_path}/MEG/sample/sample_audvis-cov.fif")
bem = mne.read_bem_solution(f"{data_path}/subjects/sample/bem/sample-5120-5120-5120-bem-sol.fif")
subject = "sample"
subjects_dir = data_path / "subjects"

evoked.pick("grad")

# Quick and dirty. No MRI whatsoever.
g = mne.gui.DipoleFitUI(evoked)

# Slow and proper. With proper BEM model.
# g = mne.gui.DipoleFitUI(evoked, trans=trans, cov=cov, bem=bem, subject=subject, subjects_dir=subjects_dir)

Todo:

  • [ ] Unit tests
  • [ ] Proper documentation page
  • [ ] Towncrier

wmvanvliet avatar Jan 21 '25 17:01 wmvanvliet

Great to see progress here, let me know when it would be useful to get usability and/or code feedback!

larsoner avatar Jan 21 '25 18:01 larsoner

The actual functionality is ready for testing. I'd love some feedback on usability, ideas for improvements and overall thoughts.

wmvanvliet avatar Jan 21 '25 20:01 wmvanvliet

Would it be worth discussing if the various GUIs should be moved to a separate package? I feel like things like dipole fitting don't have to live in the main MNE-Python package. Thoughts?

cbrnr avatar Jan 22 '25 15:01 cbrnr

Interactive dipole fitting is perhaps the oldest source localization method and still quite widely used, so to me it's within scope to put it in MNE-Python itself. To me it's something we've always been missing / a deficiency of our software. I don't see too many advantages to splitting this one off unless @wmvanvliet plans to need to iterate and release faster than MNE-Python itself.

larsoner avatar Jan 22 '25 16:01 larsoner

Would it be worth discussing if the various GUIs should be moved to a separate package?

It's certainly worth discussing. As @larsoner says, well-established source estimation techniques should be within scope of MNE-Python. It may at some point still be wise to split it off into its own package when we want to become more ambitious with its interface; the precedent for this is mne-qt-browser. As long as the functionality provided by this GUI works within the Figure3D class, all is good, and I think this will be the case for the foreseeable future. Once we want access to the full Qt functionality, we should probably split it off.

wmvanvliet avatar Jan 22 '25 18:01 wmvanvliet

@larsoner and everyone: could you help brainstorm a bit?

The functionality is now complete for a first release. Here is a test script I've been using to set up the GUI:

import mne
path = mne.datasets.sample.data_path()
meg_dir = path / "MEG" / "sample"
subjects_dir = path / "subjects"

bem = mne.read_bem_solution(subjects_dir / "sample" / "bem" / "sample-5120-bem-sol.fif")
cov = mne.read_cov(meg_dir / "sample_audvis-cov.fif")
trans = mne.read_trans(meg_dir / "sample_audvis_raw-trans.fif")
evoked = mne.read_evokeds(meg_dir / "sample_audvis-ave.fif", condition="Left Auditory")
evoked = evoked.pick("meg").apply_baseline((-0.2, 0))
evoked = evoked.interpolate_bads()
cov = cov.pick_channels(evoked.ch_names)
inv = mne.minimum_norm.read_inverse_operator(meg_dir / "sample_audvis-meg-oct-6-meg-inv.fif")
stc = mne.minimum_norm.apply_inverse(evoked, inv)

g = mne.gui.dipolefit(
    evoked, cov=cov, stc=stc, bem=bem, trans=trans,
    subject="sample", subjects_dir=subjects_dir, n_jobs=-1,
)

It works great; however, fitting dipoles is painfully slow. I've identified some bottlenecks:

  1. While optimizing the dipole position, every iteration performs a check to test whether the candidate dipole position is inside the mesh (https://github.com/mne-tools/mne-python/blob/ace9381898b3ad9da9590a3d0b99abedaa74e033/mne/dipole.py#L1271), which is a very expensive operation.
  2. After fitting the dipoles, we need the timecourses, which is done through MNE. Every time we change something about the dipoles, the forward model needs to be updated accordingly. Right now we just call mne.make_forward_dipole from scratch, which takes a long time, as it needs to perform all kinds of checks and set up everything.

Any ideas on how to improve speed? Xfit is lightning fast and I would like our implementation to at least approximate that.

wmvanvliet avatar Aug 18 '25 13:08 wmvanvliet

For the interior check, we should create a single instance of _CheckInside and reuse it. That class uses some neat shortcuts to check interior-ness quickly. By either 1) not using it or 2) reconstructing it every time (not sure offhand which of these _points_outside_surface does!) we are doing things slower than if we reused a single instance of _CheckInside. AFAIK Xfit doesn't do any check like this (?), so if we really want, we could add a checkbox like Check dipole interior to let people disable the check for faster fitting, which would also make it easier to compare to Xfit. But it seems buggy to allow points to go outside the surface, so we probably do want to warn people that it's not a good idea, leave it checked by default, etc...
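To illustrate the construct-once / query-many pattern described above, here is a minimal, hypothetical stand-in (this is NOT MNE's actual _CheckInside): it builds an expensive spatial structure once via scipy's Delaunay triangulation (valid only for convex surfaces), then answers many cheap inside/outside queries against it.

```python
import numpy as np
from scipy.spatial import Delaunay


class InsideChecker:
    """Hypothetical stand-in for _CheckInside: precompute once, query many times."""

    def __init__(self, surf_points):
        # Expensive spatial structure, built exactly once
        self._tri = Delaunay(surf_points)

    def __call__(self, rr):
        # find_simplex returns -1 for points outside the convex hull
        return self._tri.find_simplex(rr) >= 0


rng = np.random.default_rng(0)
surf = rng.normal(size=(200, 3))
surf /= np.linalg.norm(surf, axis=1, keepdims=True)  # points on the unit sphere
check = InsideChecker(surf)

# Many cheap queries reuse the same precomputed structure
assert check(np.zeros((1, 3)))[0]              # origin is inside
assert not check(np.array([[2.0, 0.0, 0.0]]))[0]  # far point is outside
```

Reconstructing the `Delaunay` object per query would reintroduce exactly the overhead being discussed; the fix in the PR is the same idea applied to _CheckInside.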

For make_forward_dipole, we should refactor it into two stages: a prep stage and a computation stage. Something like this is done in mne.simulation already for different head positions so hopefully the private functions are already there and it's not too much work to make the transition. We should see how close that gets us in terms of speed. Then we can see if there are additional steps we can split into ones that can be done once, and ones that must be recomputed each time. For example I suspect some computations having to do with the BEM surfaces and sensors / volume currents could be computed once and reused, leaving just the parts having to do with individual dipoles / primary currents (src points) to be computed. This would probably be an even bigger gain, and is not IIRC done by mne.simulation so would require a bit of code refactoring / private function use again.
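The prep/compute split can be sketched abstractly like this (a toy model, not MNE's forward code): stage 1 does everything that depends only on sensors/geometry, stage 2 does the cheap per-source-location work. All names here are illustrative.

```python
import numpy as np


class ForwardPrep:
    """Stage 1: everything independent of source locations, done once."""

    def __init__(self, sensor_pos):
        self.sensor_pos = sensor_pos
        # Stand-in for expensive sensor/BEM setup (e.g. building a whitener)
        gram = sensor_pos @ sensor_pos.T
        self.whitener = np.linalg.cholesky(gram + np.eye(len(sensor_pos)))

    def compute(self, rr):
        """Stage 2: cheap computation for candidate source positions rr."""
        # Toy 1/r^2 "gain" from each source location to each sensor
        d = np.linalg.norm(self.sensor_pos[:, None] - rr[None], axis=-1)
        return self.whitener @ (1.0 / d**2)


rng = np.random.default_rng(0)
sensors = rng.normal(size=(8, 3)) + 5.0  # keep sensors away from the sources
prep = ForwardPrep(sensors)              # slow part, runs once

# Fitting loop: many cheap calls with candidate dipole positions
g1 = prep.compute(np.array([[0.0, 0.0, 0.0]]))
g2 = prep.compute(np.array([[0.1, 0.0, 0.0]]))
assert g1.shape == g2.shape == (8, 1)
```

The gain from this split is largest when, as in the forward computation, the stage-1 work dwarfs the per-dipole stage-2 work.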

Just a final comment about the setup script above -- in the long term we probably want to do something like we do with mne coreg, i.e. have mne dipole ... or similar where you can pass --cov <path> and have it launch the GUI with the cov set, etc. I would consider this a higher priority than the optimization stuff above unless the slowness of usage is really a show-stopper for your use cases.
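A rough sketch of what such a CLI entry point might look like; the argument names below are illustrative and not MNE's actual CLI (the real implementation would hook into MNE's command framework and call the GUI):

```python
import argparse


def build_parser():
    # Hypothetical argument set, mirroring the keyword arguments of the GUI
    parser = argparse.ArgumentParser(
        prog="mne dipolefit", description="Launch the dipole fitting GUI"
    )
    parser.add_argument("evoked", help="path to an -ave.fif file")
    parser.add_argument("--cov", help="path to a noise covariance file")
    parser.add_argument("--bem", help="path to a BEM solution file")
    parser.add_argument("--trans", help="path to a head<->MRI transform file")
    parser.add_argument("--subject", help="FreeSurfer subject name")
    parser.add_argument("--subjects-dir", help="FreeSurfer subjects directory")
    return parser


args = build_parser().parse_args(
    ["sample_audvis-ave.fif", "--cov", "sample_audvis-cov.fif"]
)
assert args.cov == "sample_audvis-cov.fif"
```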

larsoner avatar Aug 18 '25 14:08 larsoner

mne dipolefit command line utility added

wmvanvliet avatar Aug 19 '25 10:08 wmvanvliet

I've added a tutorial that showcases the features the GUI currently supports: https://output.circle-artifacts.com/output/job/2e331c1c-0641-42d8-8107-e4451230b6cd/artifacts/0/html/auto_tutorials/inverse/21_interactive_dipole_fit.html#sphx-glr-auto-tutorials-inverse-21-interactive-dipole-fit-py

wmvanvliet avatar Aug 21 '25 12:08 wmvanvliet

@wmvanvliet awesome!

Do you want help with the optimization stuff?

I think we should improve the docs "scraper" so that when it's a dipole fitting GUI it captures the UI window. That way users on the tutorial page get to see the UI elements. I think it'll make it clearer than just scraping the PyVistaQt scene / 3D plot. It's not too hard to do it. You might be able to do it just by tweaking _GUIScraper a bit -- though looking through our tutorials I'm not 100% sure that's working since the mne coreg one does not show UI elements. Basically we'd just want to screenshot the Qt window, and Qt has a method for doing it, it's pretty simple at the end of the day IIRC...

larsoner avatar Aug 21 '25 17:08 larsoner

Do you want help with the optimization stuff?

I'd love some help with this! I've made a start in this branch: https://github.com/wmvanvliet/mne-python/tree/xfit-optim The idea there was to create a DipoleFitter class that pre-computes as much as it can and then has a fit method that fits a dipole using the chosen sensors. It's not working correctly, though, and I wonder if this is the right approach after all.

I haven't even looked yet at the forward operator stuff, which I think may actually have a bigger impact.

wmvanvliet avatar Aug 21 '25 17:08 wmvanvliet

Okay just to do a little checking to see where we should try to optimize, I wrote a ~70-line gist that computes a 366-channel M/EEG forward solution for 10 different source locations, either all at once (using a single src or single dip to contain all 10) or in a loop (using 10 srcs or 10 dips). The results are:

Single _CheckInside: 0.1s
Loop _CheckInside:   0.6s

Single make_forward_solution: 9.7s
Loop make_forward_solution:   95.6s

Single make_forward_dipole: 9.3s
Loop make_forward_dipole:   95.9s

So while _CheckInside does have some overhead, it's small compared to the total forward computation time.

The difference between computing the forward for all points at once vs. in a loop is huge: a ~10x speedup for the 10 locations, suggesting that almost all the time is spent on computations that are independent of the actual dipoles whose forwards are being calculated. So I expect that splitting up the fwd calculation into separate steps should yield big speedups.

I would try something like a cf = ComputeForward(info, bem, trans) class with a __call__ so you can do sol_data = cf(rr). This should be slow for instantiation (setting up everything that can be set up independently of the source locations), but then quite fast for __call__ (which uses the source locations). I don't think it should take too much work to split the code like this. I'll try to look next week!

larsoner avatar Aug 22 '25 20:08 larsoner

Okay a little proof-of-concept:

_ForwardModeler
from copy import deepcopy
import numpy as np
import mne
import time

from mne.bem import _bem_find_surface
from mne.forward._make_forward import _prepare_for_forward, _to_forward_dict, _merge_fwds, _FWD_ORDER
from mne.forward._compute_forward import _compute_forwards_meeg, _prep_field_computation
from mne.source_space._source_space import _ensure_src, _filter_source_spaces
from mne.surface import transform_surface_to
from mne.transforms import _get_trans
from mne.utils import verbose

data_path = mne.datasets.sample.data_path()
subjects_dir = data_path / 'subjects'
n_dips = 10
radius = 0.05
rr = np.random.default_rng(0).normal(size=(n_dips, 3))
rr /= np.linalg.norm(rr, axis=1)[:, np.newaxis] / radius * 0.9
nn = np.cross([1, 0, 0], rr)
nn /= np.linalg.norm(nn, axis=1)[:, np.newaxis]
assert nn.shape == rr.shape
bem = mne.read_bem_solution(data_path / "subjects" / "sample" / "bem" / "sample-5120-5120-5120-bem-sol.fif")
assert bem["surfs"][-1]["id"] == mne.io.constants.FIFF.FIFFV_BEM_SURF_ID_BRAIN
rr += bem["surfs"][-1]["rr"].mean(axis=0)  # center the points around the BEM center
evoked = mne.read_evokeds(data_path / "MEG" / "sample" / "sample_audvis-ave.fif", condition=0)
evoked.crop(tmin=0.1, tmax=0.1)

# %%
# make_forward_solution
trans = mne.read_trans(data_path / "MEG" / "sample" / "sample_audvis_raw-trans.fif")
t0 = time.time()
src = mne.setup_volume_source_space("sample", pos=dict(rr=rr, nn=nn))
fwd = mne.make_forward_solution(evoked.info, trans, src, bem)
print(f"Single make_forward_solution: {time.time() - t0:0.1f}s")

class _ForwardModeler:
    @verbose
    def __init__(self, info, trans, bem, *, meg=True, eeg=True, mindist=0.0, n_jobs=1, bem_extra=0, info_extra=0, ignore_ref=False, verbose=None):
        self.mri_head_t, _ = _get_trans(trans)
        self.mindist = mindist
        self.n_jobs = n_jobs
        # TODO: Make `src` optional in _prepare_for_forward
        src = mne.setup_volume_source_space("", pos=dict(rr=np.zeros((1, 3)), nn=np.array([[0, 0, 1]])), verbose="error")
        self.sensors, _, _, _, bem = _prepare_for_forward(
            src,
            self.mri_head_t,
            info,
            bem,
            mindist,
            n_jobs,
            bem_extra="",
            trans="",
            info_extra="",
            meg=meg,
            eeg=eeg,
            ignore_ref=ignore_ref,
        )
        self.bem = bem  # keep the prepped BEM so compute() doesn't rely on a global
        # TODO: Remove rr from the args here, it's unused
        self.fwd_data = _prep_field_computation([], sensors=self.sensors, bem=bem, n_jobs=n_jobs)

    def compute(self, src):
        src = _ensure_src(src).copy()
        for s in src:
            transform_surface_to(s, "head", self.mri_head_t)

        # TODO: We have a bug that we don't filter for spherical BEMs!
        if not self.bem["is_sphere"]:
            inner_skull = _bem_find_surface(self.bem, "inner_skull")
            # TODO: Allow reusing _CheckInside object for speed as well
            _filter_source_spaces(
                inner_skull, self.mindist, self.mri_head_t, src, self.n_jobs
            )
        rr = np.concatenate([s["rr"][s["vertno"]] for s in src])
        if len(rr) < 1:
            raise RuntimeError(
                "No points left in source space after excluding "
                "points close to inner skull."
            )

        sensors = deepcopy(self.sensors)
        fwd_data = deepcopy(self.fwd_data)
        fwds = _compute_forwards_meeg(rr, sensors=sensors, fwd_data=fwd_data, n_jobs=self.n_jobs)
        fwds = {
            key: _to_forward_dict(fwds[key], sensors[key]["ch_names"])
            for key in _FWD_ORDER
            if key in fwds
        }
        fwd = _merge_fwds(fwds, verbose=False)
        del fwds
        return fwd


t0 = time.time()
fwds = list()
fm = _ForwardModeler(evoked.info, trans, bem)
for ri in range(n_dips):
    this_src = mne.setup_volume_source_space("sample", pos=dict(rr=rr[[ri]], nn=nn[[ri]]))
    # fwds.append(mne.make_forward_solution(evoked.info, trans, this_src, bem))
    fwds.append(fm.compute(this_src))
print(f"Loop make_forward_solution:   {time.time() - t0:0.1f}s")
print()
check = np.concatenate([f["sol"]["data"] for f in fwds], axis=1)
np.testing.assert_allclose(check, fwd["sol"]["data"])

Gives:

Single make_forward_solution: 9.4s
Loop make_forward_solution:   10.2s

and the remaining time loss is mostly from _CheckInside recomputation, which is easy enough to fix, too.

@wmvanvliet agree this is worthwhile? If so I'll open a hopefully quick PR to add this private class and reorganize the forward code a tiny bit so you can then rebase (or merge with main) this PR to make things much faster...

larsoner avatar Aug 28 '25 17:08 larsoner

@larsoner I've integrated your code into the PR and also added _CheckInside to the dipole fitting algorithm. This makes a huge difference. The dipole fitting is fast enough now to be usable. Hooray!

wmvanvliet avatar Aug 28 '25 21:08 wmvanvliet

For some reason, on some of the versions in the CI, the fitted dipole is in a slightly different location. I have no idea how to debug this.

wmvanvliet avatar Sep 19 '25 08:09 wmvanvliet

Usually that suggests numerical issues are cropping up... a bad/incorrect rank is one common culprit.
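A small NumPy illustration of why an incorrect rank can produce platform-dependent fits: a nearly rank-deficient covariance has singular values sitting right at the tolerance cutoff, so different BLAS builds can round them to either side, and spurious dimensions blow up when inverted for whitening. (This is a generic sketch, not MNE's rank-estimation code.)

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 3))
cov = A @ A.T                # 50x50 covariance of true rank 3
cov += 1e-14 * np.eye(50)    # tiny jitter, as from floating-point round-off

# With the default tolerance the jitter directions are discarded...
assert np.linalg.matrix_rank(cov) == 3
# ...but a too-small tolerance "finds" spurious extra dimensions, whose
# near-zero eigenvalues explode when inverted during whitening.
assert np.linalg.matrix_rank(cov, tol=1e-16) > 3
```

When the true rank sits this close to the cutoff, slightly different round-off across CI platforms can flip the estimated rank, shifting the whitened data and hence the fitted dipole location.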

larsoner avatar Sep 19 '25 17:09 larsoner

@wmvanvliet I've been meaning to set this up for a while... you could set up a remote SSH session to the GHA runners following the "ssh" step here from Spyder:

https://github.com/spyder-ide/spyder/blob/595113cf37b4c6d07a60e1b7a78bb5ae37ec03cd/.github/workflows/test-linux.yml#L40

That way you can at least test what's going on remotely

larsoner avatar Sep 19 '25 17:09 larsoner

@wmvanvliet looking at this PR we are at +1,653 −69, with 68.92% diff coverage and red CIs. I worry a bit about reviewing + testing getting harder and harder as more features are added... would now be a good spot to get the PR into a mergeable state, and iterate in smaller PRs after that? (Maybe a few WIP features need to be removed for now, which would be okay.)

larsoner avatar Nov 05 '25 18:11 larsoner

Absolutely, this is what I was planning to do. The current feature set is the minimum needed to be useful. It needs more unit test coverage, though.

wmvanvliet avatar Nov 05 '25 20:11 wmvanvliet