
Type incompatibility between sklearn's API and `SlidingEstimator`

Open hoechenberger opened this issue 1 year ago • 8 comments

I wanted to pass a SlidingEstimator instance to sklearn's cross_val_predict(), but it gets red squiggly lines in my editor (screenshot omitted: "Screenshot 2024-07-25 at 09 17 41").

The error message is:

Argument of type "SlidingEstimator" cannot be assigned to parameter "estimator" of type "BaseEstimator" in function "cross_val_predict"
"SlidingEstimator" is incompatible with "BaseEstimator"

The reason is that SlidingEstimator doesn't inherit from sklearn's BaseEstimator, but from mne.fixes.BaseEstimator.

My current workaround is cast()ing the type:

from typing import cast

from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_predict

# cast() is a no-op at runtime; it only tells the type checker to treat the
# SlidingEstimator as a sklearn BaseEstimator.
predictions = cross_val_predict(
    cast(BaseEstimator, model),
    X,
    labels["cue"],
    cv=cv,
    n_jobs=-1,
)

hoechenberger avatar Jul 25 '24 07:07 hoechenberger

I think historically we had our own mne.fixes.BaseEstimator because

  1. We wanted to leave sklearn as an optional dependency, and
  2. We imported all submodules on mne import

Now that we no longer fully do (2) (thanks to lazy loading), we can probably move toward un-nesting all of our sklearn imports, at least in mne/decoding; a rough sketch of the difference follows below. If you want to use anything in that submodule, I think it's fair to require sklearn.
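
A rough, hypothetical illustration of the two import styles (the helper names here are made up and not actual MNE code):

# Un-nested (module-level) import: sklearn is required as soon as this module
# is imported. With lazy submodule loading, that only happens when mne.decoding
# itself is imported, so the rest of MNE stays usable without sklearn.
from sklearn.base import clone


def _fit_module_level(est, X, y):
    return clone(est).fit(X, y)


def _fit_nested(est, X, y):
    # Nested import: sklearn is only needed if and when this function runs,
    # which is how sklearn has been kept optional so far.
    from sklearn.base import clone as _clone
    return _clone(est).fit(X, y)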

The two other places that use mne.fixes.BaseEstimator, though, are mne/preprocessing/xdawn.py and mne/cov.py. Stuff from mne/cov.py ends up in the root namespace, so we can't just import sklearn there, at least not without a try/except (a sketch of that fallback is below). So I don't see a clear/obvious path for getting rid of all mne.fixes.BaseEstimator uses without changing our dependency structure.
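
For the cov.py case, a minimal sketch of the try/except fallback mentioned above could look like this (hypothetical, not what mne/cov.py does today):

try:
    from sklearn.base import BaseEstimator  # use the real sklearn base class when available
except ImportError:
    from mne.fixes import BaseEstimator  # fall back to the vendored shim otherwise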

larsoner avatar Jul 25 '24 17:07 larsoner

You get red squiggly lines due to typing, but does it crash? If it's only a typing issue you could work out a solution, maybe using a Protocol definition, but honestly I would not fight for this.


agramfort avatar Jul 26 '24 06:07 agramfort

Thanks @agramfort, good point. It runs just fine, so it seems to adhere to the sklearn API perfectly; the type checker just doesn't know that. I'm wondering if we could simply ship a type definition that makes the type checker treat this like a "proper" sklearn estimator and be done with it... otherwise, defining a Protocol seems to be the only alternative left ;/ (a rough sketch of what that could look like is below).
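
For illustration, a Protocol-based structural type might look roughly like this (the class name and exact method set are hypothetical, chosen only to show the idea; whether sklearn's own type hints would accept such a protocol is a separate question):

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SklearnEstimatorLike(Protocol):
    # Structural ("duck") type covering the parts of the sklearn estimator API
    # that cross_val_predict relies on.
    def fit(self, X: Any, y: Any = None) -> "SklearnEstimatorLike": ...

    def get_params(self, deep: bool = True) -> dict[str, Any]: ...

    def set_params(self, **params: Any) -> "SklearnEstimatorLike": ...

A user-facing annotation could then accept any object matching this protocol, including SlidingEstimator, without a cast().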

hoechenberger avatar Jul 26 '24 06:07 hoechenberger

I won't prevent you from fighting this battle. Your choice :)


agramfort avatar Jul 26 '24 07:07 agramfort

@larsoner How about we switch everything in mne.decoding over to using sklearn's BaseEstimator, like you proposed, and let XDawn and Cov be for now? This would already help in many use cases, and be trivial to implement.

hoechenberger avatar Jul 26 '24 07:07 hoechenberger

Please don't make sklearn suddenly required in some places just for typing reasons and red squiggly lines.

agramfort avatar Jul 26 '24 07:07 agramfort

@agramfort I have an idea :) will propose something later today

hoechenberger avatar Jul 26 '24 07:07 hoechenberger

Bumping the milestone on this one

larsoner avatar Aug 07 '24 19:08 larsoner

Okay, this is causing us some maintenance burden; see for example https://github.com/scikit-learn/scikit-learn/pull/29677, which is now causing:

mne/decoding/tests/test_search_light.py:373: in test_sklearn_compliance
    check(est)
../virtualenvs/base/lib/python3.12/site-packages/sklearn/utils/estimator_checks.py:3893: in check_estimator_tags_renamed
    assert not hasattr(estimator_orig, "_more_tags"), (
E   AssertionError: ('_more_tags() was removed in 1.6. Please use __sklearn_tags__ instead.',)
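
For context, failures like this come from sklearn's estimator-check machinery. A minimal illustration of invoking it directly on a plain sklearn estimator (this is not the MNE test itself, just the underlying mechanism):

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_estimator

# Runs sklearn's full battery of API-compliance checks, including the
# check_estimator_tags_renamed check quoted above.
check_estimator(LogisticRegression())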

We would need to adapt and maybe vendor two code paths in our mne.fixes.BaseEstimator :scream: for stuff like this. It seems very fragile.

My proposal would be to make everything in mne.decoding use sklearn with un-nested imports. With lazy loading this is easy. I think the impact on users should be very minimal -- anyone using decoding can/will probably have sklearn installed anyway. One possible exception is maybe XDawn, which can also be a preprocessing step, but I think that's okay.

Assuming that's okay, I'll have to figure out some way to handle EmpiricalCovariance in mne.cov needing to also be an estimator. But I think that'll be doable.

Thoughts @agramfort @drammock ?

larsoner avatar Sep 06 '24 14:09 larsoner

I have no objection to putting sklearn un-nested only within the decoding module. As you say, lazy loading provides a nice benefit here, in that it shouldn't affect anyone who isn't using/importing from the decoding module.

drammock avatar Sep 06 '24 14:09 drammock