EconML
_BaseGRF.oob_predict does not make use of multi-core processing
I found that despite setting n_jobs to -1 or to the number of CPU cores, CausalForestDML is still very slow to train. It turns out that it gets stuck in _BaseGRF.oob_predict(), since this method uses the threading joblib backend and therefore cannot take advantage of multiple cores.
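To see why the threading backend is the bottleneck, here is a standalone sketch (not EconML code): CPU-bound pure-Python work holds the GIL, so under `backend='threading'` the tasks effectively run one after another, whereas `backend='loky'` uses separate processes and real cores.

```python
from joblib import Parallel, delayed

def cpu_bound(n):
    # Pure-Python loop: it holds the GIL the whole time, so with
    # backend='threading' the tasks cannot actually run in parallel.
    total = 0
    for i in range(n):
        total += i * i
    return total

# Swapping backend='threading' for backend='loky' moves the same tasks
# into separate worker processes, which do run concurrently.
results = Parallel(n_jobs=2, backend='threading')(
    delayed(cpu_bound)(100_000) for _ in range(4))
```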
I can fix it with the following:
```python
# requires: import os, tempfile, numpy as np; from joblib import Parallel, delayed
def oob_predict(self, Xtrain: np.ndarray):
    ...
    # Parallel loop
    ## ORIGINAL CODE SNIPPET responsible for the sluggishness
    # lock = threading.Lock()
    # Parallel(n_jobs=self.n_jobs, verbose=self.verbose, backend='threading', require="sharedmem")(
    #     delayed(_accumulate_oob_preds)(tree, Xtrain, sinds, alpha_hat, jac_hat, counts, lock)
    #     for tree, sinds in zip(self.estimators_, subsample_inds))
    temp_folder = tempfile.mkdtemp()
    filename = os.path.join(temp_folder, 'joblib_test.mmap')
    try:
        if os.path.exists(filename):
            os.unlink(filename)
        # WARNING: this is unfortunate. Xtrain.dtype == `object`, which cannot be
        # memory-mapped; in our case all columns are int/float/bool, so the cast is safe.
        _X = Xtrain.astype(np.float32)
        _X.tofile(filename)
        X_memmap = np.memmap(filename, dtype=_X.dtype, mode='r', shape=_X.shape)

        def _accumulate_oob_preds_fast(tree, subsample_inds):
            nonlocal X_memmap
            # Out-of-bag rows for this tree: everything NOT in its subsample.
            mask = np.ones(X_memmap.shape[0], dtype=bool)
            mask[subsample_inds] = False
            alpha, jac = tree.predict_alpha_and_jac(X_memmap[mask])
            return mask, alpha, jac, os.getpid()

        job = Parallel(n_jobs=self.n_jobs, backend='loky', return_as='generator')
        for mask, alpha, jac, pid in job(
                delayed(_accumulate_oob_preds_fast)(tree, sinds)
                for tree, sinds in zip(self.estimators_, subsample_inds)):
            alpha_hat[mask] += alpha
            jac_hat[mask] += jac
            counts[mask] += 1
    finally:
        if os.path.exists(filename):
            os.unlink(filename)
```
Note that memory-mapping the large Xtrain is required to take advantage of all cores; if Xtrain is instead passed to the workers via a nonlocal reference, it still runs on only 3-4 cores concurrently. However, this unfortunately requires Xtrain.astype(np.float32) so that the array can be memory-mapped, so a general fix may need other changes to this method or its caller.
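A possibly cleaner alternative worth considering (a hedged sketch, not part of the patch above): joblib can memmap large numeric array arguments itself under the loky backend, via Parallel's documented `max_nbytes` and `mmap_mode` parameters, avoiding the manual tofile/np.memmap dance. It still requires a numeric dtype, so the float32 cast issue remains.

```python
import numpy as np
from joblib import Parallel, delayed

def row_sum(X, i):
    # Under loky with a small max_nbytes, X arrives in the worker as a
    # read-only memmap rather than a pickled copy.
    return float(X[i].sum())

X = np.arange(2000, dtype=np.float64).reshape(200, 10)
# max_nbytes=1 forces any array argument larger than 1 byte to be dumped
# to a temp file and handed to the workers as a memmap.
out = Parallel(n_jobs=2, backend='loky', max_nbytes=1, mmap_mode='r')(
    delayed(row_sum)(X, i) for i in range(5))
```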