mne-python
TimeFrequency Estimator modifies parameters in constructor
Describe the bug
The mne.decoding.TimeFrequency transformer modifies its constructor arguments in __init__, which violates the scikit-learn guidance that estimators store constructor parameters unmodified. This leads to a cloning error when the transformer is used in a pipeline. I was able to resolve the issue by moving the _check_tfr_param call to the transform method, in line with the other checks performed at that time. See the changes made to mne.decoding.time_frequency.py.
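To illustrate the constraint outside of MNE, here is a minimal sketch using only scikit-learn and NumPy. The two transformer classes and the `can_clone` helper are hypothetical names made up for this example and are not MNE code. The sketch assumes the failure comes from converting the argument to a new array inside `__init__`, so that the stored attribute is no longer the object that was passed in, which is what `sklearn.base.clone` checks for.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone


class EagerTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that converts its argument in __init__."""

    def __init__(self, n_cycles=7.0):
        # np.array(...) builds a new object, so the stored attribute is not
        # the object the caller passed in. clone() detects this mismatch.
        self.n_cycles = np.array(n_cycles, dtype=float)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


class LazyTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical transformer that stores the argument untouched."""

    def __init__(self, n_cycles=7.0):
        # Store exactly what was passed; no conversion or validation here.
        self.n_cycles = n_cycles

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # Convert and validate at use time, on a local copy.
        n_cycles = np.array(self.n_cycles, dtype=float)
        if np.any(n_cycles <= 0):
            raise ValueError("n_cycles must be positive")
        return X


def can_clone(estimator):
    """Return True if sklearn.base.clone accepts the estimator."""
    try:
        clone(estimator)
    except RuntimeError:
        return False
    return True


n_cycles = np.array([1.0])
print("eager:", can_clone(EagerTransformer(n_cycles)))
print("lazy:", can_clone(LazyTransformer(n_cycles)))
```

The second class mirrors the proposed fix in spirit: the conversion still happens, but only when the data is transformed, so the constructor parameters survive a round trip through `get_params` and `clone`.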
Steps to reproduce
import mne
import numpy as np
from sklearn import pipeline, linear_model

tfr_data = np.ones((100, 10, 1000))
freqs = np.array([5.])
estimator = pipeline.make_pipeline(
    mne.decoding.TimeFrequency(freqs, 10, "morlet", freqs/5., output="power"),
    mne.decoding.Vectorizer(),
    linear_model.LogisticRegression(),
)
mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))
Expected results
Successful completion of cross validation.
Actual results
Traceback (most recent call last):
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 822, in dispatch_one_batch
tasks = self._ready_batches.get(block=False)
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/queue.py", line 168, in get
raise Empty
_queue.Empty
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-1fb399c6f0bb>", line 1, in <cell line: 1>
runfile('error.py', wdir='/Users/daniel/Documents/Coding_Projects/GitHub/mne-python')
File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/Users/daniel/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/221.6008.17/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "error.py", line 15, in <module>
mne.decoding.cross_val_multiscore(estimator, tfr_data, np.random.binomial(1, 0.5, 100))
File "<decorator-gen-447>", line 12, in cross_val_multiscore
File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 435, in cross_val_multiscore
scores = parallel(
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 1043, in __call__
if self.dispatch_one_batch(iterator):
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/joblib/parallel.py", line 833, in dispatch_one_batch
islice = list(itertools.islice(iterator, big_batch_size))
File "/Users/daniel/Documents/Coding_Projects/GitHub/mne-python/mne/decoding/base.py", line 437, in <genexpr>
estimator=clone(estimator), X=X, y=y, scorer=scorer, train=train,
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 87, in clone
new_object_params[name] = clone(param, safe=False)
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
return estimator_type([clone(e, safe=safe) for e in estimator])
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
return estimator_type([clone(e, safe=safe) for e in estimator])
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in clone
return estimator_type([clone(e, safe=safe) for e in estimator])
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 65, in <listcomp>
return estimator_type([clone(e, safe=safe) for e in estimator])
File "/Users/daniel/miniconda3/envs/mne-python/lib/python3.10/site-packages/sklearn/base.py", line 96, in clone
raise RuntimeError(
RuntimeError: Cannot clone object TimeFrequency(None), as the constructor either does not set or modifies parameter n_cycles
Additional information
Platform: macOS-11.6.6-x86_64-i386-64bit
Python: 3.10.5 | packaged by conda-forge | (main, Jun 14 2022, 07:09:13) [Clang 13.0.1 ]
Executable: /Users/daniel/miniconda3/envs/mne-python/bin/python
CPU: i386: 4 cores
Memory: 16.0 GB
mne: 0.23.4
numpy: 1.22.4 {blas=NO_ATLAS_INFO, lapack=lapack}
scipy: 1.8.1
matplotlib: 3.5.2 {backend=module://backend_interagg}
sklearn: 1.1.1
numba: 0.55.2
nibabel: 4.0.1
nilearn: 0.6.2
dipy: 1.5.0
cupy: Not found
pandas: 1.4.3
mayavi: 4.8.0
pyvista: 0.35.2 {pyvistaqt=0.9.0, OpenGL 4.1 ATI-4.6.21 via AMD Radeon R9 M295X OpenGL Engine}
vtk:
PyQt5: 5.12.3
@Dod12 agreed this seems like a bug, would you be up for making a PR to fix it? The minimal example above is already a good start for a unit test!
@larsoner Sure, I'll work on the tests over the weekend.