Serialization problem with GridSearchCV
Hello, I am a beginner with GPU-accelerated computing and I can't find what is wrong with my code. I am getting a serialization error and don't understand why. Here is my code:
import numpy as np
import cudf
from dask.distributed import Client
from sklearn.metrics import classification_report
import pandas as pd
from dask_cuda import LocalCUDACluster
from cuml.dask.ensemble import RandomForestClassifier as cuRF
import dask_cudf
from cuml.dask.common.utils import persist_across_workers
import pickle
import cloudpickle
import dask_ml.model_selection as dcv
def generate_synthetic_data(n_samples=10000):
    np.random.seed(42)
    Y = np.random.randn(n_samples)
    A = np.random.randint(0, 5, size=n_samples)
    B = np.random.randint(0, 5, size=n_samples)
    C = np.random.randint(1, 3, size=n_samples)
    DATE = pd.date_range(start='1/1/2022', periods=n_samples, freq='min')
    data = {
        'A': A,
        'B': B,
        'C': C,
        'DATE': DATE,
        'Y': Y,
    }
    return pd.DataFrame(data)

def main():
    # Initialize Dask client for GPU with LocalCUDACluster
    cluster = LocalCUDACluster()
    client = Client(cluster)

    # Load and preprocess data
    df_data = generate_synthetic_data()

    # Data preprocessing
    df_data['DATE'] = pd.to_datetime(df_data['DATE'], errors='coerce')
    df_data.fillna(0, inplace=True)
    df_data['C'] = df_data['C'].astype(float)
    df_data.drop_duplicates(inplace=True)
    df_data = df_data.loc[df_data['C'] == 1]
    df_data['Y_category'] = df_data['Y'].apply(lambda x: 'over 0' if x > 0 else ('under 0' if x < 0 else 'equal to 0'))
    df_encoded = df_data.drop(columns=['Y'])
    label_mapping = {'over 0': 2, 'under 0': 1, 'equal to 0': 0}
    df_encoded['Y_category'] = df_encoded['Y_category'].map(label_mapping)
    df_encoded.sort_values(by='DATE', inplace=True)
    df_encoded = cudf.DataFrame.from_pandas(df_encoded)

    # Split data into features and target
    X = df_encoded.drop(columns=['DATE', 'Y_category']).astype('float32')
    y = df_encoded['Y_category'].astype('int32')
    split_point = int(len(df_encoded) * 0.8)
    X_train, X_test = X.iloc[:split_point], X.iloc[split_point:]
    y_train, y_test = y.iloc[:split_point], y.iloc[split_point:]

    # Balance the classes using undersampling
    y_train_counts = y_train.value_counts().to_pandas()
    min_samples = y_train_counts.min()
    sampled_indices = []
    for label in y_train_counts.index:
        indices = y_train[y_train == label].index.to_pandas().to_series()
        sampled = indices.sample(n=min_samples, random_state=42).tolist()
        sampled_indices.extend(sampled)
    sampled_indices = np.array(sampled_indices)

    # Ensure indices are unique and within bounds
    sampled_indices = np.unique(sampled_indices)
    sampled_indices = sampled_indices[sampled_indices < len(X_train)]
    X_train_balanced = X_train.iloc[sampled_indices]
    y_train_balanced = y_train.iloc[sampled_indices]

    # Convert to Dask DataFrame directly
    X_train_dask = dask_cudf.from_cudf(X_train_balanced, npartitions=10).persist(optimize_graph=True)
    y_train_dask = dask_cudf.from_cudf(y_train_balanced, npartitions=10).persist(optimize_graph=True)
    X_train_dask, y_train_dask = persist_across_workers(client,
                                                        [X_train_dask,
                                                         y_train_dask])

    # Define the parameter grid
    param_grid = {
        'max_depth': [10, 20, 30],
        'max_features': [0.1, 0.5, 0.75, "auto"],
        'n_estimators': [10, 20, 30]
    }

    # Initialize and fit the model using GridSearchCV
    model_rf = cuRF(random_state=42)
    grid_search = dcv.GridSearchCV(model_rf, param_grid, cv=5, scoring='f1_weighted')
    grid_search.fit(X_train_dask, y_train_dask)  # Fit with Dask arrays

    # Train the model with the best parameters
    best_rf = cuRF(**grid_search.best_params_, random_state=42)
    best_rf.fit(X_train_dask, y_train_dask)

    # Predict on the test set
    X_test_dask = dask_cudf.from_cudf(X_test, npartitions=1).to_dask_array(lengths=True)
    y_pred_best = best_rf.predict(X_test_dask)

    # Evaluate the model
    report_best = classification_report(y_test.to_pandas(), y_pred_best.compute().get())
    print(report_best)

if __name__ == "__main__":
    main()
The error I get is this:
/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/dask_expr/_collection.py:301: UserWarning: Dask annotations {'workers': ['tcp://127.0.0.1:45649']} detected. Annotations will be ignored when using query-planning.
warnings.warn(
/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cuml/internals/api_decorators.py:344: UserWarning: For reproducible results in Random Forest Classifier or for almost reproducible results in Random Forest Regressor, n_streams=1 is recommended. If n_streams is > 1, results may vary due to stream/thread timing differences, even when random_state is set
return func(**kwargs)
2024-08-20 14:17:44,015 - distributed.protocol.pickle - ERROR - Failed to serialize <ToPickle: HighLevelGraph with 1 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f18e4fadff0>
0. 139744897497344
>.
Traceback (most recent call last):
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 63, in dumps
result = pickle.dumps(x, **dump_kwargs)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 68, in dumps
pickler.dump(x)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 81, in dumps
result = cloudpickle.dumps(x, **dump_kwargs)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cuml/dask/common/base.py", line 60, in __getstate__
internal_model = self._get_internal_model().result()
AttributeError: 'NoneType' object has no attribute 'result'
Traceback (most recent call last):
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 63, in dumps
result = pickle.dumps(x, **dump_kwargs)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 68, in dumps
pickler.dump(x)
_pickle.PicklingError: Can't pickle <function _concat at 0x7f19b0cfe050>: it's not the same object as dask.dataframe.core._concat
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 366, in serialize
header, frames = dumps(x, context=context) if wants_context else dumps(x)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 78, in pickle_dumps
frames[0] = pickle.dumps(
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 81, in dumps
result = cloudpickle.dumps(x, **dump_kwargs)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/cuml/dask/common/base.py", line 60, in __getstate__
internal_model = self._get_internal_model().result()
AttributeError: 'NoneType' object has no attribute 'result'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/mnt/c/Users/user/PythonProjects/Snowflake/Stats/Clean RF - Forums.py", line 164, in <module>
main()
File "/mnt/c/Users/user/PythonProjects/Snowflake/Stats/Clean RF - Forums.py", line 149, in main
grid_search.fit(X_train_dask, y_train_dask) # Fit with Dask arrays
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/dask_ml/model_selection/_search.py", line 1266, in fit
futures = scheduler(
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/client.py", line 3456, in get
futures = self._graph_to_futures(
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/client.py", line 3351, in _graph_to_futures
header, frames = serialize(ToPickle(dsk), on_error="raise")
File "/home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/protocol/serialize.py", line 392, in serialize
raise TypeError(msg, str_x) from exc
TypeError: ('Could not serialize object of type HighLevelGraph', '<ToPickle: HighLevelGraph with 1 layers.\n<dask.highlevelgraph.HighLevelGraph object at 0x7f18e4fadff0>\n 0. 139744897497344\n>')
2024-08-20 14:17:44,026 - distributed.worker.state_machine - WARNING - Async instruction for <Task cancelled name="execute(('frompandas-a91f82e90b4590ee2b57246953d5e528', 7))" coro=<Worker.execute() done, defined at /home/user/miniconda3/envs/ML/lib/python3.10/site-packages/distributed/worker_state_machine.py:3615>> ended with CancelledError
2024-08-20 14:17:44,027 - distributed.scheduler - WARNING - Removing worker 'tcp://127.0.0.1:45649' caused the cluster to lose already computed task(s), which will be recomputed elsewhere: {('frompandas-9de00375790a2a30c476ba68f4ea2723', 8), '_construct_rf-44891487-927e-43da-8d8d-2bd8190520f0', ('frompandas-760adf4c445d5f36154f13a21c47df23', 9), ('frompandas-9de00375790a2a30c476ba68f4ea2723', 7), ('frompandas-760adf4c445d5f36154f13a21c47df23', 8), ('frompandas-9de00375790a2a30c476ba68f4ea2723', 9)} (stimulus_id='handle-worker-cleanup-1724177864.0271368')
I created a case on the Dask forum, but they told me to post my problem here since it seemed to be caused by an incompatibility between dask-cuda and dask-ml.
Here is the info on the system I use:
Python version: 3.10.9 | packaged by conda-forge | (main, Feb 2 2023, 20:20:04) [GCC 11.3.0]
CUDA version (nvcc):
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
cuDF version: 24.08.00a405
cuML version: 24.08.00a50
Dask version: 2024.7.1
CUDA version (nvidia-smi, Tue Aug 20 14:17:43 2024): NVIDIA-SMI 550.106, Driver Version 552.86, CUDA Version 12.4
Also, I am using a WSL2 environment.
Thanks a lot for your help!!
I just updated to cuDF 24.08.02 and cuML 24.08.00, but it still does not work.
Thanks @hchekired, I have moved the issue to cuml. It looks like a cuML issue: for some reason, _get_internal_model() returns None when Dask uses cloudpickle.dumps() to serialize cuML's BaseEstimator.
I won't be around for the next couple of weeks, but @viclafargue will take a look here
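For reference, here is a minimal sketch (not from the original report) of the failure mode described above: serializing a cuml.dask estimator that has not been fit yet takes the same BaseEstimator.__getstate__ path seen in the traceback. It assumes a GPU with dask-cuda and cuML installed, and the exact exception text may vary across versions.

# Hedged minimal sketch: pickling an *unfitted* cuml.dask estimator.
import cloudpickle
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
from cuml.dask.ensemble import RandomForestClassifier as cuRF

cluster = LocalCUDACluster()
client = Client(cluster)

model = cuRF(random_state=42)   # no .fit() yet, so there is no internal model
try:
    # Expected to hit BaseEstimator.__getstate__ and raise:
    # AttributeError: 'NoneType' object has no attribute 'result'
    cloudpickle.dumps(model)
except AttributeError as e:
    print(e)

client.close()
cluster.close()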
Hello @viclafargue, hope everything is going well for you. Do you have an idea where the problem comes from?
Thanks.
Hello @hchekired, sorry for the late reply. Thanks for opening the issue. I could reproduce it successfully. It looks like GridSearchCV serializes the estimator prior to training, which triggers the bug. I will open a PR to fix the serialization of MNMG estimators prior to training. There is also another issue that I am looking into to make things work. However, I would recommend either using sklearn's GridSearchCV with cuML Dask estimators, or (if a single GPU can handle it) dask-ml's GridSearchCV with local cuML estimators.
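For anyone reading along, here is a hedged sketch of the second workaround (dask-ml's GridSearchCV driving a local, single-GPU cuML estimator). It assumes the balanced training set fits on one GPU, reuses the X_train_balanced / y_train_balanced cuDF objects from the script above, and drops the "auto" value for max_features from the original grid.

# Hedged sketch of the single-GPU workaround: dask-ml only schedules the
# parameter sweep, while each fit/predict runs on a local cuML estimator.
import dask_ml.model_selection as dcv
from cuml.ensemble import RandomForestClassifier as cuRFLocal

param_grid = {
    'max_depth': [10, 20, 30],
    'max_features': [0.1, 0.5, 0.75],
    'n_estimators': [10, 20, 30],
}

model_rf = cuRFLocal(random_state=42)
grid_search = dcv.GridSearchCV(model_rf, param_grid, cv=5, scoring='f1_weighted')

# X_train_balanced / y_train_balanced are the cuDF objects built earlier;
# moving them to host numpy keeps dask-ml's CV splitting on types it knows.
grid_search.fit(X_train_balanced.to_numpy(), y_train_balanced.to_numpy())
print(grid_search.best_params_)

The existing LocalCUDACluster client can still be useful here: dask-ml farms out the individual fits as tasks, but each task trains on a single GPU instead of a distributed (MNMG) estimator.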
Thanks for the reply, let me know if you find something else to make things work.
Thanks
I fixed the issue that prevented serialization prior to training. But again, I am not quite sure it is a good idea to use a Dask estimator with a Dask GridSearchCV. Maybe you should try using one of them with Dask and the other without. I will look into solving this when I have more time.
Hello @viclafargue, thanks for fixing the issue. How can I get access to the corrected version?
Also, why is it not a good idea to use a Dask estimator with a Dask GridSearchCV?
Thanks!