TimeoutError in map
Describe the bug
from datasets import Dataset

def worker(example):
    while True:  # loop forever to simulate an example that never finishes
        continue
    example['a'] = 100
    return example

data = Dataset.from_list([{"a": 1}, {"a": 2}])
data = data.map(worker)
print(data[0])
I'm implementing a worker function whose runtime depends on the specific example (e.g., most examples take 0.01 s in the worker, but several examples may take 50 s).
Therefore, I would like to know how the current implementation handles subprocesses that need a long time (e.g., >= 5 min) or even never finish.
I notice that the current implementation sets a timeout of 0.05 seconds: https://github.com/huggingface/datasets/blob/c3ddb1ef00334a6f973679a51e783905fbc9ef0b/src/datasets/utils/py_utils.py#L674
However, the example code above still gets stuck.
Steps to reproduce the bug
Run the example above.
Expected behavior
I want to be able to set a default handler for these timeout cases instead of getting stuck.
Environment info
main branch version
From my current understanding, this timeout is only used while waiting to get the results, i.e., the loop keeps polling until one of the following happens:
- All tasks are done
- One worker died
Your function should work fine, and it's definitely a bug if it doesn't.
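To illustrate the difference (a minimal sketch, not the actual datasets internals; all names below are made up for the example): a short poll timeout like 0.05 s only bounds how long each individual wait for a result lasts before checking again, so a slow task is simply re-polled and never aborted.

import multiprocessing

def double(x):
    return x * 2

def run_tasks(func, inputs, num_proc=2, poll_timeout=0.05):
    # Toy polling loop: the timeout bounds each wait, not the task itself.
    with multiprocessing.Pool(num_proc) as pool:
        pending = [pool.apply_async(func, (x,)) for x in inputs]
        results = []
        while pending:
            still_pending = []
            for res in pending:
                try:
                    # Wait at most `poll_timeout` seconds for this result...
                    results.append(res.get(timeout=poll_timeout))
                except multiprocessing.TimeoutError:
                    # ...otherwise keep it and poll again on the next pass.
                    still_pending.append(res)
            pending = still_pending
        return results

if __name__ == "__main__":
    print(run_tasks(double, [1, 2, 3]))  # [2, 4, 6] (order may vary)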
When one of map's worker processes crashes, the linked code re-raises the error from the crash and propagates it to the caller.
If your question is how to limit the time of long-running tasks/worker processes, such functionality doesn't exist in datasets (yet), which means you need to implement it yourself.
E.g., you can implement it using the built-in signal module like this:
import time
import signal
from contextlib import contextmanager
from datasets import Dataset

@contextmanager
def max_exec_time(t):
    def raise_timeout_handler(signum, frame):
        raise TimeoutError
    orig_handler = signal.getsignal(signal.SIGALRM)
    signal.signal(signal.SIGALRM, raise_timeout_handler)
    try:
        signal.alarm(t)
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, orig_handler)

def worker(example, rank):
    try:
        with max_exec_time(20):  # 20 sec execution limit
            if rank % 2 == 0:
                time.sleep(50)  # simulate a long-running task
            example["a"] = 100
    except TimeoutError:
        example["a"] = None  # or return empty batches here in the "batched" mode
    return example

data = Dataset.from_list([{"a": 1}, {"a": 2}])
data = data.map(worker, num_proc=2, with_rank=True)
print(data[0])
Thanks for responding! Can you reproduce the hang with the example code above?
Thanks for responding! However, I don't think we should use signal in the context of multiprocessing, since it can sometimes crash one of the processes and raise the following error:
https://github.com/huggingface/datasets/blob/c3ddb1ef00334a6f973679a51e783905fbc9ef0b/src/datasets/utils/py_utils.py#L664
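For reference, when the slow code is a loop you control, a signal-free alternative is a cooperative deadline check inside the loop. This is only a sketch: the loop body and the 2-second limit below are placeholders for the real computation.

import time
from datasets import Dataset

TIME_LIMIT = 2.0  # seconds per example (placeholder value)

def worker(example):
    deadline = time.monotonic() + TIME_LIMIT
    # The loop stands in for the real long-running computation;
    # the key part is the explicit deadline check inside it.
    while True:
        if time.monotonic() > deadline:
            example["a"] = None  # fallback value on timeout, as in the signal version
            return example
        if example["a"] % 2 == 0:  # pretend odd values never finish on their own
            break
    example["a"] = 100
    return example

data = Dataset.from_list([{"a": 1}, {"a": 2}])
data = data.map(worker, num_proc=2)
print(data[0], data[1])  # {'a': None} {'a': 100}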
The above code has a try/except to catch the error from the handler. Or do you get an error other than TimeoutError?
Yup, it will raise the RuntimeError: https://github.com/huggingface/datasets/blob/c3ddb1ef00334a6f973679a51e783905fbc9ef0b/src/datasets/utils/py_utils.py#L667C19-L670C22
raise RuntimeError(
    "One of the subprocesses has abruptly died during map operation."
    "To debug the error, disable multiprocessing."
)
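For context, that RuntimeError is what map raises when a worker process dies without reporting a result. A minimal sketch that is expected to trigger it (os._exit kills the worker with no Python-level exception, so there is nothing for a try/except in the worker to catch):

import os
from datasets import Dataset

def crashing_worker(example):
    os._exit(1)  # simulate an abrupt worker death: no exception, no result
    return example

data = Dataset.from_list([{"a": 1}, {"a": 2}])
# Expected to fail with "One of the subprocesses has abruptly died during map operation."
data = data.map(crashing_worker, num_proc=2)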
What @mariosasko proposed is very useful for debugging. Thank you!