Type incompatibility between sklearn's API and `SlidingEstimator`
I wanted to pass a SlidingEstimator instance to sklearn's cross_val_predict(), but it gets red squiggly lines:
The error message is:
```
Argument of type "SlidingEstimator" cannot be assigned to parameter "estimator" of type "BaseEstimator" in function "cross_val_predict"
  "SlidingEstimator" is incompatible with "BaseEstimator"
```
The reason is that SlidingEstimator doesn't inherit from sklearn's BaseEstimator, but from mne.fixes.BaseEstimator.
My current workaround is cast()ing the type:
```python
from typing import cast

from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_predict

# Silence the type checker: SlidingEstimator follows the sklearn estimator
# API, it just doesn't inherit from sklearn's BaseEstimator.
predictions = cross_val_predict(
    cast(BaseEstimator, model),
    X,
    labels["cue"],
    cv=cv,
    n_jobs=-1,
)
```
I think historically we had `mne.fixes.BaseEstimator` because:
1. we wanted to keep sklearn an optional dependency, and
2. we imported all submodules on `import mne`.
Now that we no longer fully do (2) (thanks to lazy loading), we can probably move toward un-nesting all our sklearn imports, at least in mne/decoding. If you want to use anything in that submodule, I think it's fair to require sklearn.
The two other places that use mne.fixes.BaseEstimator, though, are mne/preprocessing/xdawn.py and mne/cov.py. Stuff from mne/cov.py ends up in the root namespace so we can't just import sklearn there, at least not without a try/except. So I don't see a clear/obvious path for getting rid of all mne.fixes.BaseEstimator uses without changing our dependency structure.
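For context, the try/except route would look roughly like this sketch (hypothetical fallback code, nothing that exists in MNE today):

```python
# Hypothetical guarded import: use sklearn's real BaseEstimator when it is
# available, and fall back to a minimal stand-in otherwise, so that plain
# `import mne` keeps working without sklearn installed.
try:
    from sklearn.base import BaseEstimator
except ImportError:

    class BaseEstimator:  # type: ignore[no-redef]
        """Minimal fallback, only used when scikit-learn is absent."""

        def get_params(self, deep=True):
            return {}

        def set_params(self, **params):
            for key, value in params.items():
                setattr(self, key, value)
            return self
```

The downside is that the fallback has to track whatever subset of the sklearn API our classes actually rely on, which is exactly the vendoring burden we'd like to avoid.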
you get red squiggly lines due to typing but does it crash? if it's a typing issue you can work out a solution using maybe a protocol definition but honestly I would not fight for this.
Thanks @agramfort, good point. It runs just fine, so it seems to adhere to the sklearn API perfectly, but the type checker doesn't know about this. I'm wondering if we could simply ship a type definition that makes the type checker treat this like a "proper" sklearn estimator, and be done with it... otherwise, defining a protocol seems to be the only alternative left ;/
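For what it's worth, the protocol route would be fairly little code. A hedged sketch (the name `EstimatorLike` and the exact set of required methods are illustrative, not an agreed-upon API):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class EstimatorLike(Protocol):
    """Structural type for anything that quacks like a sklearn estimator.

    The required methods here are illustrative; the real set would need
    discussion (e.g. whether to also require predict/transform).
    """

    def fit(self, X: Any, y: Any = None) -> "EstimatorLike": ...

    def get_params(self, deep: bool = ...) -> dict: ...

    def set_params(self, **params: Any) -> "EstimatorLike": ...
```

Because a protocol is structural, the type checker would accept any object with these methods regardless of its base class, and `runtime_checkable` additionally allows `isinstance(model, EstimatorLike)` duck-typed checks.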
I won't prevent you from fighting this battle. Your choice :)
@larsoner How about we switch everything in mne.decoding over to using sklearn's BaseEstimator, like you proposed, and let XDawn and Cov be for now? This would already help in many use cases, and be trivial to implement.
please don't make sklearn suddenly required in some places just for typing reasons and red squiggly lines
@agramfort I have an idea :) will propose something later today
Bumping the milestone on this one
Okay this is causing us some maintenance burden... see for example https://github.com/scikit-learn/scikit-learn/pull/29677 which is now causing:
```
mne/decoding/tests/test_search_light.py:373: in test_sklearn_compliance
    check(est)
../virtualenvs/base/lib/python3.12/site-packages/sklearn/utils/estimator_checks.py:3893: in check_estimator_tags_renamed
    assert not hasattr(estimator_orig, "_more_tags"), (
E   AssertionError: ('_more_tags() was removed in 1.6. Please use __sklearn_tags__ instead.',)
```
We would need to adapt and maybe vendor two code paths in our `mne.fixes.BaseEstimator` :scream: for stuff like this. It seems very fragile.
My proposal would be to make everything in mne.decoding use sklearn with un-nested imports. With lazy loading this is easy. I think the impact on users should be very minimal -- anyone using decoding can/will probably have sklearn installed anyway. One possible exception is maybe Xdawn, which can also be a preprocessing step, but I think that's okay.
Assuming that's okay, I'll have to figure out some way to handle EmpiricalCovariance in mne.cov needing to also be an estimator. But I think that'll be doable.
Thoughts @agramfort @drammock ?
I have no objection to putting sklearn un-nested only within the decoding module. As you say, lazy loading provides a nice benefit here, in that it shouldn't affect anyone who isn't using/importing from the decoding module.