TimeoutError in map

Open • Jiaxin-Wen opened this issue 1 year ago • 7 comments

Describe the bug

from datasets import Dataset

def worker(example):
    while True:  # simulate an example that never finishes
        continue
    example['a'] = 100  # never reached
    return example

data = Dataset.from_list([{"a": 1}, {"a": 2}])
data = data.map(worker)
print(data[0])

I'm implementing a worker function whose runtime depends on the specific example (e.g., most examples take 0.01 s in the worker, but a few may take 50 s).

Therefore, I would like to know how the current implementation handles subprocesses that take a long time (e.g., 5 minutes or more) or never finish at all.

I noticed that the current implementation sets a timeout of 0.05 seconds: https://github.com/huggingface/datasets/blob/c3ddb1ef00334a6f973679a51e783905fbc9ef0b/src/datasets/utils/py_utils.py#L674

However, the example code above still hangs.

Steps to reproduce the bug

Run the example above.

Expected behavior

I would like to be able to set a default worker that handles these timeout cases, instead of map getting stuck.

Environment info

main branch version

Jiaxin-Wen • Apr 06 '24

From my current understanding, this timeout is only used while waiting for results, to periodically check whether one of the following has happened:

  1. All tasks are done
  2. One worker died

Your function should work fine and it's definitely a bug if it doesn't.
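A simplified sketch of the polling pattern described above, assuming it works roughly like the linked code: the consumer waits on a result queue with a short timeout, and that timeout only controls how often it re-checks whether all workers have finished. It is not a per-task time limit, so a worker that never finishes keeps the loop polling forever. This sketch uses threads instead of processes, and poll_results and produce are hypothetical names, not datasets APIs; the real code additionally checks whether a worker process died (case 2).

```python
import queue
import threading
import time


def poll_results(q, workers, poll_interval=0.05):
    """Yield items from q until all workers are done and q is drained.

    poll_interval only sets how often worker status is re-checked; it does
    not limit how long any single task may run.
    """
    while True:
        try:
            yield q.get(timeout=poll_interval)
        except queue.Empty:
            # Nothing arrived in this interval: stop only if every worker has
            # exited and no results are left. A worker stuck in an infinite
            # loop stays alive, so this loop would keep polling forever.
            if all(not w.is_alive() for w in workers) and q.empty():
                break


def produce(q, value, delay):
    time.sleep(delay)  # simulate an example that takes `delay` seconds
    q.put(value)


q = queue.Queue()
workers = [threading.Thread(target=produce, args=(q, i, 0.2 * i)) for i in range(3)]
start = time.monotonic()
for w in workers:
    w.start()
results = sorted(poll_results(q, workers))
elapsed = time.monotonic() - start
print(results, f"{elapsed:.1f}s")
```

Even though poll_interval is 0.05 s, the slowest task (0.4 s) still completes and its result is collected; the short timeout just means the loop wakes up often to check the workers' state.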

Dref360 • Apr 08 '24

When one of map's worker processes crashes, the linked code raises an error and propagates it to the caller.

If your question is how to limit the execution time of long-running tasks or worker processes, datasets doesn't support this (yet), so you need to implement it yourself.

For example, you can implement it with the built-in signal module:

import time
import signal
from contextlib import contextmanager

from datasets import Dataset


@contextmanager
def max_exec_time(t):
    # Note: SIGALRM is only available on Unix, and signal handlers can only
    # be installed from a process's main thread.
    def raise_timeout_handler(signum, frame):
        raise TimeoutError

    orig_handler = signal.getsignal(signal.SIGALRM)
    signal.signal(signal.SIGALRM, raise_timeout_handler)
    try:
        signal.alarm(t)  # t must be a whole number of seconds
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, orig_handler)


def worker(example, rank):
    try:
        with max_exec_time(20):  # 20 sec execution limit
            if rank % 2 == 0:
                time.sleep(50)  # simulate a long-running task
            example["a"] = 100
    except TimeoutError:
        example["a"] = None  # Or return empty batches here in the "batched" mode
    return example


data = Dataset.from_list([{"a": 1}, {"a": 2}])
data = data.map(worker, num_proc=2, with_rank=True)
print(data[0])
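A variant of the max_exec_time approach above, sketched under the same assumptions (Unix only, called from the worker process's main thread). It swaps signal.alarm, which only accepts whole seconds, for signal.setitimer, which accepts fractional seconds, and wraps the worker so that a timed-out example is handed to a fallback function, roughly the "default worker" asked for in the issue. with_fallback, slow_worker, default_worker, and safe_worker are hypothetical names, not part of datasets.

```python
import signal
from contextlib import contextmanager


@contextmanager
def max_exec_time(seconds):
    """Raise TimeoutError if the body runs longer than `seconds` (Unix, main thread only)."""
    def raise_timeout_handler(signum, frame):
        raise TimeoutError(f"exceeded {seconds} s")

    orig_handler = signal.signal(signal.SIGALRM, raise_timeout_handler)
    try:
        # Unlike signal.alarm, setitimer accepts fractional seconds.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel any pending timer
        signal.signal(signal.SIGALRM, orig_handler)


def with_fallback(fn, fallback, seconds):
    """Hypothetical helper: run fn(example); on timeout, return fallback(example)."""
    def wrapped(example):
        try:
            with max_exec_time(seconds):
                return fn(example)
        except TimeoutError:
            return fallback(example)
    return wrapped


def slow_worker(example):
    if example["a"] == 2:
        while True:  # never finishes, like the example in the issue
            pass
    example["a"] = 100
    return example


def default_worker(example):
    example["a"] = None  # placeholder value for examples that timed out
    return example


safe_worker = with_fallback(slow_worker, default_worker, seconds=0.2)
print(safe_worker({"a": 1}))
print(safe_worker({"a": 2}))
```

With datasets, a wrapped function like safe_worker would be passed to Dataset.map in place of worker; this assumes datasets can serialize the closure for multiprocessing, which it should since it relies on dill.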

mariosasko • Apr 11 '24

> Your function should work fine and it's definitely a bug if it doesn't.

Thanks for responding! Can you reproduce the hang with the example code above?

Jiaxin-Wen • Apr 12 '24

> E.g., you can implement it using the built-in signal module like this: [...]

Thanks for responding! However, I don't think we should use signal in a multiprocessing context, since it sometimes crashes one of the processes and raises the following error: https://github.com/huggingface/datasets/blob/c3ddb1ef00334a6f973679a51e783905fbc9ef0b/src/datasets/utils/py_utils.py#L664

Jiaxin-Wen • Apr 12 '24

> thanks for responding! However, I don't think we should use signal in the context of multiprocessing since sometimes it will crash one process and raise the following error

The code above uses try/except to catch the error raised by the handler. Or do you get an error other than TimeoutError?

mariosasko • Apr 12 '24

> The above code has try/except to catch the error from the handler. Or do you get an error other than TimeoutError?

Yup, it raises this RuntimeError: https://github.com/huggingface/datasets/blob/c3ddb1ef00334a6f973679a51e783905fbc9ef0b/src/datasets/utils/py_utils.py#L667C19-L670C22

raise RuntimeError(
    "One of the subprocesses has abruptly died during map operation."
    "To debug the error, disable multiprocessing."
)
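One general way a worker can "abruptly die" under a signal-based timeout: the default action for SIGALRM is to terminate the process, so if the timer fires while no Python handler is installed, the process is killed instead of raising TimeoutError, and a pool would then see a dead worker. The sketch below uses plain multiprocessing rather than datasets, and child_without_handler, child_with_handler, and exit_codes are hypothetical names. It only illustrates the mechanism and is not a diagnosis of the crash reported here; it assumes Unix and the "fork" start method.

```python
import multiprocessing as mp
import signal
import time


def child_without_handler():
    # No Python handler: SIGALRM's default action terminates the process.
    signal.signal(signal.SIGALRM, signal.SIG_DFL)
    signal.setitimer(signal.ITIMER_REAL, 0.1)
    time.sleep(5)  # killed after ~0.1 s, never returns


def child_with_handler():
    def raise_timeout_handler(signum, frame):
        raise TimeoutError

    signal.signal(signal.SIGALRM, raise_timeout_handler)
    signal.setitimer(signal.ITIMER_REAL, 0.1)
    try:
        time.sleep(5)
    except TimeoutError:
        pass  # the timeout is caught, so the process exits normally


def exit_codes():
    # "fork" (Unix only) lets the children run functions defined in this
    # script without pickling them.
    ctx = mp.get_context("fork")
    codes = {}
    for target in (child_without_handler, child_with_handler):
        p = ctx.Process(target=target)
        p.start()
        p.join()
        # A process killed by signal N reports exitcode -N.
        codes[target.__name__] = p.exitcode
    return codes


if __name__ == "__main__":
    print(exit_codes())
```

The child that catches TimeoutError exits with code 0, while the one without a handler is reported with a negative exit code equal to minus the SIGALRM signal number, which is the kind of failure the RuntimeError above is meant to surface.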

Jiaxin-Wen • Apr 13 '24

What @mariosasko proposed is very useful for debugging. Thank you!

williamberrios • Aug 13 '24