
ENH: Random Forest Classifier oob scaling/parallel

[Open] tylerjereddy opened this issue 1 year ago • 2 comments

My team, working on a bioinformatics problem with a high feature count (columns/dimensions in X), noticed that the RandomForestClassifier out-of-bag (OOB) scoring doesn't scale with n_jobs. To be fair, the n_jobs documentation clearly states what it does support, but I do wonder whether the out-of-bag predictions under the hood might also benefit from parallel support. Someone on my team found that parallelizing does help, implemented externally to sklearn using the exposed base estimators. I suppose it might be nice to have that internally at some point, if there are no design reasons not to?
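For reference, here is a minimal sketch of what that external workaround might look like. This is my own illustration, not my colleague's actual code: the helper name parallel_oob_score is made up, and it leans on the private sklearn.ensemble._forest helpers _generate_unsampled_indices and _get_n_samples_bootstrap, whose locations and signatures may change between releases.

from joblib import Parallel, delayed
import numpy as np
# Private helpers; subject to change across sklearn releases.
from sklearn.ensemble._forest import (
    _generate_unsampled_indices,
    _get_n_samples_bootstrap,
)


def parallel_oob_score(clf, X, y, n_jobs=16):
    # Hypothetical helper: gather each tree's class-probability votes on the
    # rows that tree did NOT see in its bootstrap sample, in parallel.
    n_samples = X.shape[0]
    n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, clf.max_samples)

    def _oob_votes(tree):
        unsampled = _generate_unsampled_indices(
            tree.random_state, n_samples, n_samples_bootstrap
        )
        votes = np.zeros((n_samples, clf.n_classes_))
        votes[unsampled] = tree.predict_proba(X[unsampled])
        return votes

    # The per-tree prediction passes are independent, so they dispatch
    # cleanly through joblib.
    all_votes = Parallel(n_jobs=n_jobs)(
        delayed(_oob_votes)(tree) for tree in clf.estimators_
    )
    oob_proba = np.sum(all_votes, axis=0)
    # Rows never left out of any bootstrap draw get zero votes here; with
    # enough estimators that is rare (sklearn itself warns when it happens).
    oob_pred = clf.classes_[np.argmax(oob_proba, axis=1)]
    return float(np.mean(oob_pred == y))

Fitting with oob_score=False and then calling parallel_oob_score(clf, X, y) should closely reproduce clf.oob_score_ (per-row argmax is unchanged by dividing each row's votes by its own tree count) while spreading the prediction passes across cores.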

Sample reproducer code with the latest stable release (1.3.2) on a 16-core x86_64 Linux box (i9-13900K) is below the fold, and the scaling plot is underneath that. In practice we use far more estimators and features than this, so the delta is much greater there, but the scaling trend is the main observation in any case.

from time import perf_counter
import numpy as np
# sklearn 1.3.2
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


timings_oob = []
timings_base = []
feature_counts = np.linspace(10, 180_000, 20, dtype=np.int64)

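# For each feature count, time fit() with and without OOB scoring.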
for feature_count in feature_counts:
    X, y = make_classification(n_samples=1_000,
                               n_features=feature_count,
                               random_state=0)
    for use_oob, timing_list in zip([True, False], [timings_oob, timings_base]):
        start = perf_counter()
        clf = RandomForestClassifier(n_estimators=50,
                                     random_state=0,
                                     oob_score=use_oob,
                                     n_jobs=16)
        clf.fit(X, y)
        timing_list.append(perf_counter() - start)

fig, ax = plt.subplots(1, 1)
ax.set_title(f"Random Forest OOB scaling performance")
ax.plot(feature_counts,
        timings_oob,
        label="WITH OOB",
        marker=".")
ax.plot(feature_counts,
        timings_base,
        label="NO oob",
        marker=".")
ax.set_xlabel("num features")
ax.set_ylabel("Time (s)")
ax.legend()
fig.savefig("bench_feat.png", dpi=300)

[scaling plot "Random Forest OOB scaling performance": fit time (s) vs. num features, with and without OOB]

tylerjereddy · Jan 03 '24

We are open to speeding up this part of the algorithm. I think we did not dedicate too much attention to it because it is not the main usage, but if the changes are straightforward and we get a decent speed-up, I don't see why we would not be inclined to include them.

glemaitre · Jan 12 '24

@tylerjereddy just to clarify: is the issue that the OOB predictions are computed serially, without joblib parallelization?

Do you have a link to your fix, and are you interested in submitting a PR? I would be interested in reviewing it and getting this performance improvement into main.

If not, I am happy to submit a PR if you describe the idea. I also recently observed, qualitatively, that OOB runtime was worse than non-OOB.
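For anyone wanting to confirm where the time goes, here is a quick profiling sketch (standard-library cProfile only; the dataset sizes are arbitrary, and the context-manager form of Profile needs Python 3.8+):

import cProfile
import pstats
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1_000, n_features=50_000, random_state=0)
clf = RandomForestClassifier(n_estimators=50, oob_score=True, n_jobs=16,
                             random_state=0)

# Profile the whole fit; if the OOB path is the bottleneck, its prediction
# pass should dominate the cumulative times.
with cProfile.Profile() as profiler:
    clf.fit(X, y)

pstats.Stats(profiler).sort_stats("cumulative").print_stats(15)

Comparing the top cumulative entries between oob_score=True and oob_score=False runs should make the serial portion visible.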

adam2392 · Feb 15 '24