WIP: Add GED transformer
What does this implement/fix?
Adds a transformer for generalized eigenvalue decomposition (GED, or approximate joint diagonalization) of covariance matrices. It generalizes the Xdawn, CSP, SSD, and SPoC algorithms.
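For readers skimming the PR: the common core of these methods is a generalized eigenvalue problem for a pair of covariance matrices, whose eigenvectors give the spatial filters and whose pseudo-inverse gives the patterns. A minimal sketch of that idea only, on toy data (this is not the PR's actual `_GEDTransformer` API):

```python
import numpy as np
from scipy.linalg import eigh

# The methods differ mainly in how the two covariances are built
# (e.g. class-wise for CSP, band-filtered signal vs. noise for SSD, ...).
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 2000))   # toy data: n_channels x n_times
S = np.cov(X[:, :1000])               # "signal" covariance of interest
R = np.cov(X)                         # "reference"/"noise" covariance

evals, evecs = eigh(S, R)             # solve S @ w = lambda * R @ w
order = np.argsort(evals)[::-1]       # sort components by descending eigenvalue
filters = evecs[:, order].T           # rows: spatial filters (n_components x n_channels)
patterns = np.linalg.pinv(filters)    # columns: corresponding patterns
components = filters @ X              # component time courses
```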
Additional information
Steps:
- [x] test that it outputs identical filters and patterns as the child classes for all tests by adding temporary `assert_allclose` calls in the code
- [x] cover tests for `_GEDTransformer` and core functions
- [x] add _validate_params to _XdawnTransformer
- [x] add feature to perform GED in the principal subspace for Xdawn and SPoC
- [x] add option for CSP and SSD to select restr_type and provide info for CSP
- [x] add entry in mne's implementation details
- [x] move SSD and Xdawn pattern computation from np.linalg.pinv to mne's pinv
- [x] change SSD's multiplication order for dimension reduction for consistency
- [x] fix SSD's filters_ shape inconsistency
- [x] move mne.preprocessing._XdawnTransformer to decoding and make it public
- [x] rename _XdawnTransformer method_params to cov_method_params for consistency
- [ ] SSD performs spectral sorting of components each time `.transform()` is applied. This could be optimized by sorting `filters_`, `evals_`, and `patterns_` already in `fit`, which would also suit the current GED design better
- [ ] in SSD's `.transform()`, when `return_filtered=True`, subsetting with `self.picks_` is done twice, which looks like a bug
- [ ] remove `assert_allclose` calls from the code
- [ ] clean up newly redundant code from the child classes
- [ ] complete tests for child classes
- [ ] perhaps ssd should use mne.cov.compute_whitener() instead of its own whitener implementation. It won't be identical, but conceptually seems to do the same thing
Then it should be ready for merge!
Already have a failure, but fortunately I think it's just a tolerance issue:
```
mne/decoding/tests/test_csp.py:444: in test_spoc
    spoc.fit(X, y)
mne/decoding/csp.py:985: in fit
    np.testing.assert_allclose(old_filters, self.filters_)
E   AssertionError:
E   Not equal to tolerance rtol=1e-07, atol=0
E
E   Mismatched elements: 1 / 100 (1%)
E   Max absolute difference among violations: 9.04019248e-09
E   Max relative difference among violations: 1.11806536e-07
E   ACTUAL: array([[  2.037415,   1.424886,   2.718162,  -3.07798 ,  -3.862132,
E             1.412549,  -3.821452,   1.276637,   1.899782,  -2.389858],
E          [ 11.534231, -22.178034, -12.321628, -52.410096,  62.876084,...
E   DESIRED: array([[  2.037415,   1.424886,   2.718162,  -3.07798 ,  -3.862132,
E             1.412549,  -3.821452,   1.276637,   1.899782,  -2.389858],
E          [ 11.534231, -22.178034, -12.321628, -52.410096,  62.876084,...
```
I would just bump the rtol a bit here to 1e-6, and if you know the magnitudes are in the single/double digits then an atol=1e-7 would also be reasonable (could do both).
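Concretely, something like this for the temporary check (illustrative values from the suggestion above, with toy stand-ins for the two filter arrays; not final code):

```python
import numpy as np

# Stand-ins for the filters from the old code path and the new GED path:
rng = np.random.default_rng(0)
old_filters = rng.standard_normal((10, 10))
new_filters = old_filters + 1e-8 * rng.standard_normal((10, 10))  # tiny numerical jitter

# Looser rtol, plus a small atol so near-zero entries don't trip the check:
np.testing.assert_allclose(old_filters, new_filters, rtol=1e-6, atol=1e-7)
```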
Thanks! Interesting how it passed macos-13/mamba/3.12, but didn't pass macos-latest/mamba/3.12
It might be that the small difference in `filters_` will propagate and grow in `patterns_`, so rtol/atol won't be of much help for `patterns_`. But let's see
Different architectures: macos-13 is Intel x86_64 and macos-latest is ARM/M1. And Windows also failed, which could be due to the use of MKL there or something. I'm cautiously optimistic it's just floating-point error...
@larsoner, I think I covered tests for the core GEDTransformer cases. Could you check whether it's enough so that I can move on to the next step?
Thanks Eric!
Great that the `assert` statements are passing! Just a few comments below. Also, can you see if you can get closer to 100% coverage here? https://app.codecov.io/gh/mne-tools/mne-python/pull/13259
That's a cool tool, I like it! Will do
FYI I modified your top comment to have checkboxes (you can see how it's done if you go to edit it yourself) and a rough plan. Can you see if the plan is what you have in mind and update accordingly if needed? Then I can see where you (think you) are after your next push, and when you ask "okay to move on" I'll know what you want to do next 😄
Alright :)
Hi @larsoner! Is there anything else I should improve before moving on to the next steps?
Also, could you please confirm that I'm not confused about these two?
- SSD performs spectral sorting of components each time `.transform()` is applied. This could be optimized by sorting `filters_`, `evals_`, and `patterns_` already in `fit`, which would also suit the current GED design better.
- in SSD's `.transform()`, when `return_filtered=True`, subsetting with `self.picks_` is done twice, which looks like a bug.
> SSD performs spectral sorting of components each time `.transform()` is applied. This could be optimized by sorting `filters_`, `evals_`, and `patterns_` already in `fit`, which would also suit the current GED design better.
I think this should be okay. Is the sorting by variance explained or something? Before they're sorted, how are they ordered? If it's random, for example, it's fine to sort, but if it's not random and has some meaning, then maybe we shouldn't store them sorted some other way...
> in SSD's `.transform()`, when `return_filtered=True`, subsetting with `self.picks_` is done twice, which looks like a bug.
Yes it sounds like it probably is. Best thing you can do is prove to yourself (and us in review :smile: ) that it's a bug by pushing a tiny test that would fail on main but pass on this PR. Maybe by taking the testing dataset data which has MEG channels followed by EEG channels, and telling it to operate on just EEG channels or something? If it picks twice the second picking should fail because the indices will be out of range...
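To make the "indices out of range" reasoning concrete, here's a toy illustration with made-up shapes (not the actual SSD code):

```python
import numpy as np

data = np.empty((10, 1000))        # pretend channels 0-5 are MEG and 6-9 are EEG
picks = np.array([6, 7, 8, 9])     # EEG indices into the full channel axis

once = data[picks]                 # shape (4, 1000): the intended subset
try:
    twice = once[picks]            # picking again with the same indices...
except IndexError as err:
    print(err)                     # ...fails: index 6 is out of bounds for axis 0 with size 4
```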
> I think this should be okay. Is the sorting by variance explained or something? Before they're sorted, how are they ordered? If it's random, for example, it's fine to sort, but if it's not random and has some meaning, then maybe we shouldn't store them sorted some other way...
They are stored sorted by descending eigenvalues by default. But given the `sort_by_spectral_ratio` parameter, I think it is expected that the filters will be stored according to this sorting when the parameter is True.
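Something like the following is what I have in mind, i.e. sort once in `fit` so `transform` can just apply `self.filters_` (the names and the row-per-component layout are assumptions on my part, not the actual SSD internals):

```python
import numpy as np

def _sort_by_spectral_ratio(filters, patterns, evals, spec_ratio):
    """Sort components once (in fit) by descending spectral ratio."""
    order = np.argsort(spec_ratio)[::-1]
    # assumes filters/patterns are stored with one component per row
    return filters[order], patterns[order], evals[order]
```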
> Best thing you can do is prove to yourself (and us in review 😄 ) that it's a bug by pushing a tiny test that would fail on main but pass on this PR.
I pushed the test here, but how do I show you that it fails on main?
Is there anything left to do before I can remove the asserts and clean up the classes?
> I pushed the test here, but how do I show you that it fails on main?
If this were true test-driven development (TDD) you could for example make this your first commit, show it failed on CIs, then proceed to fix it in subsequent commits. But in practice this isn't typically done and would be annoying to do here, so what I would suggest is that you take whatever tiny test segment should fail on main, copy it, switch over to the main branch, paste it back in, and make sure it fails. In principle I could do this as well but better for you to give it a shot and say "yep it failed" and I'll trust you to have done it properly :smile:
> Is there anything left to do before I can remove the asserts and clean up the classes?
Currently I see a bunch of red CIs. So I would say yes, get those green first. An implicit sub-step of all steps is "make sure CIs are still green" before moving onto the next step (unless of course one of your steps someday is something like, "TDD: add breaking test and see CIs red" or whatever, but it's not that way here). Then proceeding with the plan in the top comment looks good to me, the next step of which is to remove the asserts, then redundant code, etc.!
> If this were true test-driven development (TDD) you could for example make this your first commit, show it failed on CIs, then proceed to fix it in subsequent commits. But in practice this isn't typically done and would be annoying to do here, so what I would suggest is that you take whatever tiny test segment should fail on main, copy it, switch over to the main branch, paste it back in, and make sure it fails. In principle I could do this as well but better for you to give it a shot and say "yep it failed" and I'll trust you to have done it properly 😄
Of course, I checked that it fails before I even started implementing the fix. Thanks for the explanation!
> Currently I see a bunch of red CIs. So I would say yes, get those green first. An implicit sub-step of all steps is "make sure CIs are still green" before moving onto the next step (unless of course one of your steps someday is something like, "TDD: add breaking test and see CIs red" or whatever, but it's not that way here). Then proceeding with the plan in the top comment looks good to me, the next step of which is to remove the asserts, then redundant code, etc.!
Yeah, I rushed a bit with a comment before checking that everything's green :) I handled the major problem, but the failure I get now doesn't make much sense to me. The issue seems to be order-related (which is fair given my last commits), but it appears neither on my Windows machine nor on most of the CIs here. Is there a way to get a more elaborate traceback?
> Is there a way to get a more elaborate traceback?
For debugging you can add stuff to the tests like print(...) and you'll see it in the captured stdout of the failed test. You could do this on CIs but it'll take ~50 minutes per try which is a big pain.
Fortunately it sounds like you have a Windows setup you can test on. I would look at
https://dev.azure.com/mne-tools/mne-python/_build/results?buildId=33458&view=logs&jobId=2bd7b19d-6351-5e7f-8417-63f327ab45bc&j=2bd7b19d-6351-5e7f-8417-63f327ab45bc&t=549c8367-5889-5507-1708-200535760ded
And make sure you're using the same NumPy with the same OpenBLAS version on your Windows machine. You could do this in a venv, or create and activate a basic conda env and then install via pip, which is what the CI run does. Things can change slightly based on the BLAS/LAPACK bindings (MKL vs. OpenBLAS, OpenBLAS version, etc.).
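A quick way to check what your local NumPy is linked against and compare it with the CI logs (plain NumPy introspection, nothing MNE-specific):

```python
import numpy as np

print(np.__version__)
np.show_config()  # reports the BLAS/LAPACK backend (OpenBLAS, MKL, ...) and its version
```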
Hmm, at least on my Windows VM I upgraded NumPy, SciPy, and scikit-learn and couldn't replicate the problem.
I think there are a few options:
- Replicate locally (somehow!) and fix the issue. Best, but probably hard.
- Add a bunch of `print` statements that maybe allow you to figure out the problem, and get the CIs to give you debugging info. Also a good solution, but would probably take too long.
- "Improve" the `assert` statements in (just) SSD to do some sorting to work around this issue. Assuming the decomposition actually works properly, then we live with it. These `assert` statements go away soon anyway, right?
- Temporarily add a:

```python
failing_checks = (
    "check_readonly_memmap_input",
    "check_estimators_fit_returns_self",
    "check_estimators_overwrite_params",
)
if (
    platform.system() == "Windows"
    and os.getenv("MNE_CI_KIND", "") == "pip"
    and check in failing_checks
):
    pytest.xfail("Broken on Windows pip CIs")
```

  which can also be removed once the internal `assert_allclose` statements go away.
For expedience I'd go with (4) assuming you can't actually replicate locally after giving it a shot.
Thanks Eric! I tried (1) and didn't replicate the failure with the CI's versions of Python, NumPy, and OpenBLAS (setting up conda was a pain :)). (4) seems to work, so I'm proceeding forward
Hmmm, the current failure doesn't seem to be relevant to this PR.
Anyway, I started thinking and playing with the last point in the plan:
> perhaps ssd should use mne.cov.compute_whitener() instead of its own whitener implementation. It won't be identical, but conceptually seems to do the same thing
Setting aside for now subtle algebraic differences between compute_whitener and current ssd whitener, this will be a two-line change and I'll be able to clean up some ssd-specific conditionals from ged functions.
The problem is that `_smart_eigh` throws a warning when there is no average reference in `info`, which breaks some SSD tests.
I can probably fix this in the tests by adding an average reference projector, but I'm not sure I agree with the seriousness of this warning. Perhaps it can be downgraded from warn to info? What do you think?
Also, is there a way to see which lines are not test-covered for csp/spoc/ssd/xdawn so I can target them specifically?
> The problem is that `_smart_eigh` throws a warning when there is no average reference in `info`, which breaks some SSD tests. I can probably fix this in the tests by adding an average reference projector, but I'm not sure I agree with the seriousness of this warning. Perhaps it can be downgraded from warn to info? What do you think?
This is really a warning relevant for source imaging. We should keep it as-is in `cov.py`, since it's widely used by MNE end users for source imaging, but in your code you can pass `verbose="error"` to `compute_whitener` and it won't emit a warning where you use it in SSD.
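Roughly like this; the cov/info here are toy stand-ins just to show the call, in SSD they would be whatever the fit already computed:

```python
import mne

# Toy objects only; in SSD the covariance and info come from the fitted data.
info = mne.create_info(["EEG 001", "EEG 002", "EEG 003"], sfreq=100.0, ch_types="eeg")
cov = mne.make_ad_hoc_cov(info)

# verbose="error" silences the missing-average-reference warning locally,
# without changing the behavior in cov.py for everyone else.
W, ch_names = mne.cov.compute_whitener(cov, info, verbose="error")
```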
> Also, is there a way to see which lines are not test-covered for csp/spoc/ssd/xdawn so I can target them specifically?
Yeah, this will lead you to, for example:
https://app.codecov.io/gh/mne-tools/mne-python/pull/13259?dropdown=coverage&src=pr&el=h1&utm_medium=referral&utm_source=github&utm_content=checks&utm_campaign=pr+comments&utm_term=mne-tools
Thanks!
> Yeah, this will lead you to, for example: https://app.codecov.io/gh/mne-tools/mne-python/pull/13259?dropdown=coverage&src=pr&el=h1&utm_medium=referral&utm_source=github&utm_content=checks&utm_campaign=pr+comments&utm_term=mne-tools
IIUC, this only shows the lines relevant to the PR; I thought I could check whether something important is missing in those files in general. I've found it now (here), and it seems they are covered pretty well, so I'll only fix what's left uncovered for this PR.
@larsoner, I think mostly everything I planned and didn't plan to do for this part is done.
About making the `_XdawnTransformer` class public: I moved it to decoding and removed the `_`. But if I understood you correctly the last time, are there additional steps to making a class public in mne?
> About making the `_XdawnTransformer` class public: I moved it to decoding and removed the `_`. But if I understood you correctly the last time, are there additional steps to making a class public in mne?
It will need to be added to `mne/decoding/__init__.pyi` (to be importable under `mne.decoding.XdawnTransformer`), and also to `doc/api/decoding.rst` (to be cross-referenceable in our documentation)
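For illustration, the `__init__.pyi` side of that would look roughly like this (the surrounding entries and the module name for the moved class are assumptions, not the actual file contents):

```python
# mne/decoding/__init__.pyi -- sketch only; existing entries and the module
# path for the moved class are placeholders, not the real file contents.
__all__ = [
    # ... existing names ...
    "CSP",
    "SPoC",
    "SSD",
    "XdawnTransformer",  # new public name
]

from .csp import CSP, SPoC
from .ssd import SSD
from .xdawn import XdawnTransformer  # hypothetical module name after the move
```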
@larsoner, @drammock, I can't make sense of the failures in inverse_sparse tests, do you think it's somehow related to my changes in the code? Other than that I think the PR is ready!
Pushed https://github.com/mne-tools/mne-python/pull/13315 (which was green) and merged the changes into this branch, if it doesn't come back green then it suggests there is something odd about this branch but I'd be surprised. I should be able to look Monday!
... I also clicked the "Ready for Review" button and changed the title
@larsoner, I checked in the last commits whether this MxNE-related failure is the only problem by skipping it, and the tests came back green
@Genuster can you replicate locally? I can't. If someone can, then they could git bisect or revert bits of changes to figure out what's causing the breakage. I'm not sure what it could be!
> @Genuster can you replicate locally?
I also can't replicate it on my Windows machine, but I don't have easy access to Ubuntu to replicate the CI setup that fails.
> If someone can, then they could git bisect or revert bits of changes to figure out what's causing the breakage. I'm not sure what it could be!
These failures first appeared when I added `XdawnTransformer` to `decoding.rst` and `__init__.pyi` and the changelog entry. That makes no sense at all with respect to the MxNE tests
Okay @Genuster one way to debug this would be to use https://github.com/mxschmitt/action-tmate . It allows you to SSH into the GitHub action runner. You could use this to debug the failure. Do you want to try?
If not, I can give it a shot hopefully today or tomorrow
> Okay @Genuster one way to debug this would be to use https://github.com/mxschmitt/action-tmate . It allows you to SSH into the GitHub action runner. You could use this to debug the failure. Do you want to try? If not, I can give it a shot hopefully today or tomorrow
There is still a lot of work to be done for the GSoC project, I'd be happy if you could help with this. I've started to work on the spatial filters visualization in the meantime.
Thanks @Genuster !