[BUG]: fold_groups argument doesn't accept column name string
pycaret version checks
- [X] I have checked that this issue has not already been reported here.
- [X] I have confirmed this bug exists on the latest version of pycaret.
- [ ] I have confirmed this bug exists on the master branch of pycaret (pip install -U git+https://github.com/pycaret/pycaret.git@master).
Issue Description
My DataFrame has a column called 'groups'.
When using 'groupkfold' as the fold_strategy, I passed 'groups' as the fold_groups argument. However, when I run compare_models(), the function finishes in less than a second and returns an empty list.
I then tried passing train_data['groups'] as the fold_groups argument instead, and it "seems" to work. I say "seems" because I have no way to verify whether it is actually performing GroupKFold correctly.
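One way to check the splitting behaviour independently is a small sketch with scikit-learn's GroupKFold, assuming PyCaret's 'groupkfold' strategy delegates to it. The toy DataFrame and the helper check_group_disjoint are hypothetical stand-ins for train_data; the check confirms that no group label ends up on both the training and the validation side of any fold.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold

# Hypothetical toy data standing in for train_data: 20 rows spread over 5 groups.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "x": rng.normal(size=20),
    "groups": np.repeat(["a", "b", "c", "d", "e"], 4),
    "label": rng.normal(size=20),
})


def check_group_disjoint(X, y, groups, n_splits):
    """Hypothetical helper: True if no group appears in both train and validation of any fold."""
    cv = GroupKFold(n_splits=n_splits)
    for train_idx, val_idx in cv.split(X, y, groups=groups):
        train_groups = set(groups.iloc[train_idx])
        val_groups = set(groups.iloc[val_idx])
        if train_groups & val_groups:  # any shared group means leakage across the split
            return False
    return True


# Group labels are used only for splitting, so they are dropped from the features.
features = df.drop(columns=["groups", "label"])
print(check_group_disjoint(features, df["label"], df["groups"], n_splits=5))
```

If PyCaret exposes the fold generator and group labels it actually used, the same loop could be pointed at those instead of the toy data.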
Also, when using groupkfold, do I also need to pass ['groups'] in the ignore_features argument?
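For context on the ignore_features question, here is a sketch of how plain scikit-learn treats group labels: they are handed to the splitter through the groups parameter of cross_val_score and are kept out of the feature matrix. This is not PyCaret's API; the data, column names, and the choice of LinearRegression are illustrative assumptions.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GroupKFold, cross_val_score

# Illustrative data (assumption): 30 rows, 6 groups, a near-linear target.
rng = np.random.default_rng(42)
df = pd.DataFrame({
    "x1": rng.normal(size=30),
    "x2": rng.normal(size=30),
    "groups": np.repeat(np.arange(6), 5),
})
df["label"] = 2 * df["x1"] - df["x2"] + rng.normal(scale=0.1, size=30)

# Group labels drive the splitting only; they are excluded from the features.
groups = df["groups"]
X = df.drop(columns=["groups", "label"])
y = df["label"]

scores = cross_val_score(
    LinearRegression(), X, y, groups=groups, cv=GroupKFold(n_splits=3), scoring="r2"
)
print(scores.round(3))
```

Keeping the group column out of X avoids the model learning from an identifier that, by construction, never appears in both the training and validation sets of a fold.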
Reproducible Example
from pycaret.regression import *

s = setup(
    train_data,
    target='label',
    test_data=test_data,
    imputation_type=None,
    normalize=True,
    html=True,
    profile=False,
    verbose=True,
    session_id=3964,
    feature_selection=True,
    feature_selection_method='sequential',
    n_features_to_select=40,
    polynomial_features=False,
    remove_multicollinearity=True,
    fold_strategy='groupkfold',
    fold_groups='groups',
    fold=10,
    remove_outliers=False,
    transformation=False,
    ignore_features=['groups'],
)
bests = compare_models(turbo=False, n_select=4, sort='R2')
Expected Behavior
compare_models() returns the four best trained models (n_select=4).
Actual Results
bests = []
Installed Versions
PyCaret required dependencies:
pip: 22.2.2
setuptools: 65.3.0
pycaret: 3.0.0.rc3
IPython: 7.33.0
ipywidgets: 8.0.1
tqdm: 4.64.0
numpy: 1.21.6
pandas: 1.4.3
jinja2: 3.1.2
scipy: 1.8.1
joblib: 1.1.0
sklearn: 1.1.2
pyod: Installed but version unavailable
imblearn: 0.9.1
category_encoders: 2.5.0
lightgbm: 3.3.2
numba: 0.55.2
requests: 2.28.1
matplotlib: 3.6.0rc2
scikitplot: 0.3.7
yellowbrick: 1.5
plotly: 5.10.0
kaleido: 0.2.1
statsmodels: 0.13.2
sktime: 0.11.4
tbats: Installed but version unavailable
pmdarima: 2.0.1
psutil: 5.9.2