Unable to pickle trained UMAP instance
I am building a composite predictive model that uses UMAP as an embedding component. I implemented the model as a class that stores trained UMAP objects as instance attributes. When I try to use an instance of this class in joblib parallel processes, I get an exception because the object fails to deserialize in the worker process. The problem arises only when the UMAP instances are trained; my code works fine as long as they remain untrained.
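Below is a stripped-down sketch of this setup. It does not use UMAP. The class `CompositeModel` and its `fitted_data` attribute are hypothetical stand-ins for a model whose trained components hold a large float32 array internally, as an NNDescent index presumably does with its training data (an assumption). The sketch checks, inside each joblib worker, whether that array arrived writeable. With joblib's defaults (`max_nbytes="1M"`, `mmap_mode="r"`), arrays larger than 1 MiB, including arrays nested inside object attributes, are expected to reach the workers as read-only memory maps, while small ones are pickled normally. If that is what happens here, it would explain why only trained UMAP instances, which hold fitted arrays, fail.

```python
import numpy as np
from joblib import Parallel, delayed


class CompositeModel:
    """Hypothetical stand-in for a model with fitted state (not the UMAP API).

    `fitted_data` plays the role of the training data that a fitted
    UMAP/NNDescent object is assumed to keep internally.
    """

    def __init__(self, n_rows):
        rng = np.random.default_rng(0)
        self.fitted_data = rng.random((n_rows, 8), dtype=np.float32)


def attribute_is_writeable(model):
    # Runs inside a worker process, after the model has been unpickled there.
    return model.fitted_data.flags.writeable


small = CompositeModel(n_rows=100)      # 3,200 bytes: pickled normally
large = CompositeModel(n_rows=100_000)  # 3,200,000 bytes: above max_nbytes="1M"

flags = Parallel(n_jobs=2)(
    delayed(attribute_is_writeable)(m) for m in (small, large)
)
print(flags)  # [True, False]: the large attribute arrives as a read-only memmap
```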
I am using umap 0.5.5 (cloned from GitHub), numba 0.57.1, and pynndescent 0.5.10.
I get the following error message regardless of what I actually try to do with the objects in parallel. The message suggests that the issue is caused by the NNDescent object inside each UMAP object and stems from a numba typing error:
Thu Feb 22 13:30:22 2024 Building and compiling search function
Thu Feb 22 13:30:22 2024 Building and compiling search function
Thu Feb 22 13:30:22 2024 Building and compiling search function
Thu Feb 22 13:30:22 2024 Building and compiling search function
Thu Feb 22 13:30:22 2024 Building and compiling search function
joblib.externals.loky.process_executor._RemoteTraceback:
"""
Traceback (most recent call last):
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/pynndescent/pynndescent_.py", line 969, in __setstate__
self._init_search_function()
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/pynndescent/pynndescent_.py", line 1366, in _init_search_function
inds, dists, _ = self._search_function(
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
error_rewrite(e, 'typing')
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No conversion from readonly array(float32, 1d, C) to array(float32, 1d, C) for 'current_query', defined at None
File "../mambaforge/envs/umap/lib/python3.10/site-packages/pynndescent/pynndescent_.py", line 1288:
def search_closure(query_points, k, epsilon, visited, rng_state):
<source elided>
else:
current_query = query_points[i]
^
During: typing of intrinsic-call at /home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/pynndescent/pynndescent_.py (1288)
File "../mambaforge/envs/umap/lib/python3.10/site-packages/pynndescent/pynndescent_.py", line 1288:
def search_closure(query_points, k, epsilon, visited, rng_state):
<source elided>
else:
current_query = query_points[i]
^
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home1/istvan/mfl/cross_validate.py", line 178, in cross_validate_model
output = Parallel(n_jobs=n_jobs, return_as="list", **kwargs)(
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/parallel.py", line 1944, in __call__
return output if self.return_generator else list(output)
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/parallel.py", line 1587, in _get_outputs
yield from self._retrieve()
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/parallel.py", line 1691, in _retrieve
self._raise_error_fast()
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/parallel.py", line 1726, in _raise_error_fast
error_job.get_result(self.timeout)
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/parallel.py", line 735, in get_result
return self._return_or_raise()
File "/home1/istvan/mambaforge/envs/umap/lib/python3.10/site-packages/joblib/parallel.py", line 753, in _return_or_raise
raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
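The typing error can be connected to read-only data with plain NumPy. Assume, without having checked the pynndescent source, that `query_points` in `search_closure` ends up derived from data that was reopened as a read-only memory map. Indexing a row, as `query_points[i]` does, then yields a view whose `writeable` flag is `False`. numba types such a view as `readonly array(float32, 1d, C)`, which it presumably cannot unify with the writeable `array(float32, 1d, C)` that `current_query` receives elsewhere in the function. The temporary file name below is arbitrary.

```python
import os
import tempfile

import numpy as np

# Save a small float32 matrix and reopen it read-only, which mirrors what
# joblib's default mmap_mode="r" does for large arguments.
path = os.path.join(tempfile.mkdtemp(), "raw_data.npy")
np.save(path, np.arange(12, dtype=np.float32).reshape(4, 3))
raw_data = np.load(path, mmap_mode="r")

# A row view inherits the read-only flag of its parent buffer.
row = raw_data[1]
print(raw_data.flags.writeable, row.flags.writeable)  # False False
print(row.dtype, row.ndim, row.flags.c_contiguous)    # float32 1 True

# An explicit copy owns its memory and is writeable again, so numba would
# type it as a plain array(float32, 1d, C).
row_copy = np.array(row)
print(row_copy.flags.writeable)  # True
```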
I have tried hard to create a minimal working example, without success so far. I suspect the issue arises only under rather special circumstances that I cannot reproduce in a small piece of code. I hope, though, that the error message above will pinpoint the source of the issue for someone with deep knowledge of the source code.
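If the read-only memmap explanation is correct (an assumption, not confirmed here), it would also explain why a small example does not fail. joblib only memory-maps arrays larger than `max_nbytes`, which defaults to `"1M"` (1 MiB), so toy datasets are pickled normally. The sketch below compares the default with two existing `Parallel` options: `max_nbytes=None` disables memmapping, and `mmap_mode="c"` opens the memmaps copy-on-write. Either could be passed through the `**kwargs` of `Parallel` in `cross_validate_model`. Whether pynndescent then compiles its search function without errors is untested.

```python
import numpy as np
from joblib import Parallel, delayed


def is_writeable(arr):
    # Executed in a worker process after the argument has been unpickled.
    return arr.flags.writeable


# 300,000 float32 values = 1,200,000 bytes, just above the default
# max_nbytes="1M" (1,048,576 bytes).
data = np.ones(300_000, dtype=np.float32)
tasks = [delayed(is_writeable)(data) for _ in range(2)]

default = Parallel(n_jobs=2)(tasks)
no_memmap = Parallel(n_jobs=2, max_nbytes=None)(tasks)
copy_on_write = Parallel(n_jobs=2, mmap_mode="c")(tasks)

print(default)        # [False, False]: read-only memmap in the workers
print(no_memmap)      # [True, True]: the array is pickled and copied
print(copy_on_write)  # [True, True]: writeable copy-on-write memmap
```

`max_nbytes=None` copies every large array into each task's payload, which can cost memory for big training sets; `mmap_mode="c"` keeps the shared file and copies memory pages only when a worker writes to them.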