dask-ml

ParallelPostFit excessive scheduler memory use and CancelledError

Open · rikturr opened this issue 3 years ago · 1 comment

Notebook with MCVE and all notes: https://nbviewer.jupyter.org/gist/rikturr/43336377678018d01d4f21f19dd7ef11

When using ParallelPostFit to train on pandas/NumPy objects and then predict on Dask objects, I noticed that scheduler memory use runs extremely high. I frequently got a CancelledError, with the scheduler dying, when calling .predict() on fairly small data (see the notebook for the full code and outputs):

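```python
from dask_ml.wrappers import ParallelPostFit
from sklearn.ensemble import RandomForestClassifier

X_train, X_test, y_train, y_test = ...  # in-memory train split, Dask test split (see notebook)
rf = ParallelPostFit(
    RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
)
_ = rf.fit(X_train, y_train)

preds = rf.predict(X_test)
_ = preds.compute()  # failure happens on this line after ~40 minutes
```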

The scheduler's memory balloons to roughly `X_test.npartitions` × the size of `rf`, which can reach multiple GB very quickly. Every further operation on `preds` triggers the same transfer again. I realize this is because `ParallelPostFit` uses `map_partitions` behind the scenes but does not broadcast the model object, so Dask sends the model through the scheduler to each worker every time you do something with `preds` (unless, of course, you have persisted it).
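To get a feel for how quickly this grows, here is a back-of-the-envelope sketch of the total model bytes pushed through the scheduler. It assumes, per the explanation above, one serialized copy of the model per partition, resent on every compute of an unpersisted `preds`. The helper `estimated_scheduler_bytes` and the 50 MB stand-in model are hypothetical; real scheduler usage also depends on Dask's own serialization and bookkeeping.

```python
import pickle


def estimated_scheduler_bytes(model, npartitions, n_computes=1):
    """Rough estimate of model bytes shipped through the scheduler (hypothetical helper).

    Assumes the fitted model is serialized into every partition's task and
    that the whole graph is resent on each compute of an unpersisted result.
    """
    model_bytes = len(pickle.dumps(model, protocol=pickle.HIGHEST_PROTOCOL))
    return model_bytes * npartitions * n_computes


if __name__ == "__main__":
    # Hypothetical stand-in for a fitted model holding ~50 MB of parameters.
    fake_model = {"weights": bytes(50_000_000)}
    total = estimated_scheduler_bytes(fake_model, npartitions=100)
    print(f"~{total / 1e9:.1f} GB through the scheduler")  # prints ~5.0 GB through the scheduler
```

Even a moderately sized forest combined with a few hundred partitions lands in the multi-GB range, and each extra `compute()` multiplies it again.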

A workaround is to broadcast the fitted model to every worker with `client.scatter(..., broadcast=True)` and then call `map_partitions` yourself instead of using the ParallelPostFit wrapper:

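```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# `client`, `seed`, X_train/y_train (in memory) and X_test (Dask DataFrame)
# are defined as in the notebook.
rf = RandomForestClassifier(n_estimators=100, random_state=seed, n_jobs=-1)
_ = rf.fit(X_train, y_train)

# Send one copy of the fitted model to every worker up front; the graph
# then carries a reference to the scattered data instead of the model.
rf_fut = client.scatter(rf, broadcast=True)


def dask_predict(df, model):
    # Runs once per partition with the already-scattered model.
    return model.predict(df)


preds = X_test.map_partitions(
    dask_predict,
    model=rf_fut,
    meta=np.array([1]),  # each partition returns a NumPy array of labels
)
```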

I plan to follow up with a PR to fix this in the ParallelPostFit class.

rikturr, Jul 01 '21 19:07

Just confirmed this happens with a LocalCluster too; it is easier to reproduce that way.

rikturr, Jul 02 '21 16:07