multiprocessing.Pool.map hangs (joblib.Parallel/delayed works)
Describe the bug
I'm trying to use multiprocessing.Pool.map. I have a reproducer below for the hang; locally I could sometimes get it to run, but then it ignored updates to the called function as well.
Running this notebook, the 4th cell hangs, and when interrupted (the frontend stays responsive) the stack trace is:
Traceback (most recent call last):
  File "/home/alon/src/marimo/marimo/_runtime/runner/cell_runner.py", line 302, in run
    return_value = execute_cell(cell, self.glbls)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/alon/src/marimo/marimo/_ast/cell.py", line 445, in execute_cell
    return eval(cell.last_expr, glbls)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/marimo_1792145/__marimo__cell_gXSm__output.py", line 1, in <module>
    list(pool.map(f, [1, 2, 3]))
    ^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/multiprocessing/pool.py", line 367, in map
    return self._map_async(func, iterable, mapstar, chunksize).get()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/multiprocessing/pool.py", line 768, in get
    self.wait(timeout)
  File "/usr/lib64/python3.12/multiprocessing/pool.py", line 765, in wait
    self._event.wait(timeout)
  File "/usr/lib64/python3.12/threading.py", line 655, in wait
    signaled = self._cond.wait(timeout)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.12/threading.py", line 355, in wait
    waiter.acquire()
  File "/home/alon/src/marimo/marimo/_runtime/handlers.py", line 31, in interrupt_handler
    raise MarimoInterrupt
marimo._runtime.control_flow.MarimoInterrupt
Environment
{
"marimo": "0.5.2",
"OS": "Linux",
"OS Version": "6.8.7-300.fc40.x86_64",
"Processor": "",
"Python Version": "3.12.3",
"Binaries": {
"Browser": "--",
"Node": "v20.12.2"
},
"Requirements": {
"click": "8.1.7",
"importlib-resources": "missing",
"jedi": "0.19.1",
"markdown": "3.6",
"pymdown-extensions": "10.8.1",
"pygments": "2.18.0",
"tomlkit": "0.12.4",
"uvicorn": "0.29.0",
"starlette": "0.37.2",
"websocket": "missing",
"typing-extensions": "4.11.0",
"black": "24.4.2"
}
}
Code to reproduce
import marimo

__generated_with = "0.5.3"
app = marimo.App()

@app.cell
def __():
    from multiprocessing import Pool
    return Pool,

@app.cell
def __(Pool):
    pool = Pool()
    return pool,

@app.cell
def __():
    def f(x):
        return 10 + x
    return f,

@app.cell
def __(f, pool):
    list(pool.map(f, [1, 2, 3]))
    return

if __name__ == "__main__":
    app.run()
When trying to submit this bug I managed to reproduce the second problem, namely staleness:
- case one, above: cell 4 fails to complete running, interrupting shows above stack trace.
- case two: same notebook. cell 4 runs. But then going to cell 3 and changing the function results in no change in the output (the marimo DAG logic is fine, but Pool.map runs the old function).
Note: For reproduction it would be nice to be able to simulate the complete notebook lifecycle. I think that's possible from the test code, but doing it from within a notebook (even behind a "beware, API quicksand" warning), e.g. marimo.edit_cell_by_id(id, new_code), would be nice.
Note 2: The hang above was produced with the current git version, which could be related. But I suspect a shared-memory issue, since when I closed the notebook I got a Python warning about shared memory (which I failed to copy).
Note 3: joblib's Parallel/delayed seems to work fine.
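For context, the joblib pattern referred to here looks roughly like this (a sketch, assuming joblib is installed; shown with joblib's default "loky" backend rather than backend='multiprocessing'):

```python
# Sketch of the joblib Parallel/delayed pattern mentioned above.
# Assumes joblib is installed. Its default "loky" backend serializes the
# function itself (via cloudpickle) and manages workers per call, so each
# call sees the current definition of `f` rather than a stale snapshot.
from joblib import Parallel, delayed

def f(x):
    return 10 + x

results = Parallel(n_jobs=2)(delayed(f)(x) for x in [1, 2, 3])
print(results)  # [11, 12, 13]
```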
Thanks for the thorough bug reports -- will look into it.
> I have a reproducer below for hanging,
Does it always hang, or only sometimes? I'm unable to reproduce the hanging on my machine unfortunately so far
> When trying to submit this bug I managed to reproduce the second problem, namely staleness:
> - case one, above: cell 4 fails to complete running, interrupting shows above stack trace.
> - case two: same notebook. cell 4 runs. But then going to cell 3 and changing the function results in no change in the output (the marimo DAG logic is fine, but Pool.map runs the old function).
Similarly is this something you can reproduce consistently, or only sometimes? I also couldn't reproduce this :/
EDIT: Just kidding -- seeing the staleness issue now ...
If I recreate the process pool, it uses the latest value of the function. It also doesn't hang.
I would suggest using the pool as a context manager; that way you don't have to think about managing and recreating it:
import marimo

__generated_with = "0.5.2"
app = marimo.App()

@app.cell
def __():
    from multiprocessing import Pool
    return Pool,

@app.cell
def __():
    def f(x):
        return 10 + x
    return f,

@app.cell
def __(Pool, f):
    with Pool() as pool:
        outputs = list(pool.map(f, [1, 2, 3]))
    outputs
    return outputs, pool

if __name__ == "__main__":
    app.run()
multiprocessing isn't well supported in interactive environments generally (I checked, and it doesn't work in Jupyter either).
For what it's worth, here's a Python script that fails in an analogous way:
from multiprocessing import Pool

def f(x):
    return x + 10

if __name__ == "__main__":
    pool = Pool()
    print(list(pool.map(f, [1, 2, 3])))

    # try to redefine `f` -- the pool won't pick it up.
    def f(x):
        return x + 11

    # uses the "old" value of `f`
    print(list(pool.map(f, [1, 2, 3])))

    # try to call `g` -- the pool will hang
    def g(x):
        return x + 11

    print(list(pool.map(g, [1, 2, 3])))
    print("I won't be printed")
In summary, my understanding is that Pool takes a snapshot of the __main__ module at construction time. So any changes made to the kernel state after its creation won't be discoverable by it.
I don't think there's anything we can do to fix this, or even fail gracefully.
Thanks, that works. It's still faster than joblib.Parallel(backend='multiprocessing') this way, and recreating the pool is probably what joblib did that made it work. You can close this (or should I? I'm not sure what workflow you prefer).
Great, thanks for confirming. I'll close the issue.