[core][experimental] ray.get of accelerated DAG result may not throw exception for MultiOutputNode
What happened + What you expected to happen
When calling `ray.get()` on the outputs of an accelerated DAG with a `MultiOutputNode`, we currently block on the results one at a time. This means that if one of the actor tasks threw an exception, we do not throw it until we have received the results of all preceding actors, according to the order in the `MultiOutputNode`. This is a problem for tensor-parallel inference, because all of the workers execute in lockstep, and if one actor throws an exception, the others may hang. Depending on the order of the actors, `ray.get()` may never return.
We should fix this by trying to get all of the outputs in the `MultiOutputNode` simultaneously. One way is to try to `ReadAcquire` the channels in round-robin order.
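As a rough sketch of the round-robin idea (an assumption, not the actual implementation: `ReadAcquire` is the underlying C++ primitive, and here it is approximated with the Python-level `channel.read(timeout)` API, assuming the read raises `RayChannelTimeoutError` when nothing is ready):

```python
from ray.exceptions import RayChannelTimeoutError, RayTaskError

def read_outputs_round_robin(channels, poll_interval=0.01):
    # Visit every output channel in round-robin order with a short timeout so
    # that an exception from any actor surfaces immediately, instead of
    # blocking on the first channel while a later channel already holds an error.
    results = {}
    while len(results) < len(channels):
        for idx, channel in enumerate(channels):
            if idx in results:
                continue
            try:
                value = channel.read(timeout=poll_interval)
            except RayChannelTimeoutError:
                continue  # Nothing ready on this channel yet; try the next one.
            if isinstance(value, RayTaskError):
                raise value  # Fail fast without waiting for the other outputs.
            results[idx] = value
    return [results[i] for i in range(len(channels))]
```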
Versions / Dependencies
3.0dev
Reproduction script
--
Issue Severity
None
Marking as P0 because it's important for the vLLM use case. But it's okay to release the developer preview without it.
@ruisearch42 are you tracking this?
Thanks @anyscalesam. Yes, I will work on this.
I have a reproduction script here: https://gist.github.com/kevin85421/a7f14ea38d64420b105fbd79fd31fb8a
```python
with InputNode() as inp:
    # fail: raise an exception, sleep: sleep $inp seconds
    dag = MultiOutputNode([b.fail.bind(inp), a.sleep.bind(inp)])

compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(10)
ray.get(ref, timeout=30)
```
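For context, the actor definitions that this snippet assumes would look roughly like the following (the real definitions live in the gist; this is only a sketch to make the snippet self-contained):

```python
import time

import ray
from ray.dag import InputNode, MultiOutputNode


@ray.remote
class Actor:
    def fail(self, x):
        # Raise immediately so the corresponding output channel carries a RayTaskError.
        raise RuntimeError("injected failure")

    def sleep(self, x):
        # Sleep for `x` seconds before returning.
        time.sleep(x)
        return x


a = Actor.remote()
b = Actor.remote()
```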
There are two input channels for `_dag_output_fetcher`: one for `b.fail` and one for `a.sleep`.
Current behavior
Although `b.fail` throws an exception, the read operation still needs to wait until `a.sleep` finishes (in this example, 10 seconds) or until a timeout occurs (in this example, 30 seconds).
Use two threads to read these two input channels (TL;DR: not recommended)
My first thought was to use two threads to read from these two input channels in `_read_list`, return immediately if either thread encounters a `RayTaskError`, and then stop any other threads. I implemented something like the following:
```python
import concurrent.futures
from ray.exceptions import RayTaskError

def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
    results = [None] * len(self._input_channels)

    def _read_channel(idx: int, channel: ChannelInterface, timeout: Optional[float]):
        results[idx] = channel.read(timeout)
        return results[idx]

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for idx, channel in enumerate(self._input_channels):
            futures.append(executor.submit(_read_channel, idx, channel, timeout))
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            # If the result is a `RayTaskError`, cancel any other pending futures
            # and return immediately.
            if isinstance(result, RayTaskError):
                for f in futures:
                    f.cancel()
                return results
    return results
```
However, I found that `concurrent.futures.Future.cancel()` can't cancel threads that are already running. Therefore, we still need to wait 10 seconds for `a.sleep` to finish before the function can return.
We currently have two options if we use multiple threads:

- Use `threading.Event()` to control when the other threads stop reading (a Python-level sketch follows this list). However, this requires many changes, from Python and Cython to C++.
- Use `future.result(timeout=...)`. However, we already support a timeout.

In addition, in my benchmark, multi-threading causes a several-fold performance degradation in `_read_list` when the input size is not large enough, and only shows small performance improvements with larger input sizes.

TL;DR: not recommended.
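For reference, here is a Python-level sketch of the `threading.Event()` idea (illustrative only; in practice the stop signal would have to be plumbed through Cython into the C++ read path, and the `RayChannelTimeoutError` exception name is an assumption):

```python
import threading
import time

from ray.exceptions import RayChannelTimeoutError

stop_event = threading.Event()

def _read_channel_cooperatively(channel, timeout, poll_interval=0.01):
    # Read with short timeouts so the worker thread can notice stop_event
    # shortly after another thread hits a RayTaskError.
    deadline = None if timeout is None else time.monotonic() + timeout
    while not stop_event.is_set():
        remaining = poll_interval
        if deadline is not None:
            remaining = min(poll_interval, max(deadline - time.monotonic(), 0))
        try:
            return channel.read(timeout=remaining)
        except RayChannelTimeoutError:
            if deadline is not None and time.monotonic() >= deadline:
                raise  # The user-specified timeout has been exceeded.
            continue  # Otherwise retry and re-check stop_event.
    return None  # Another thread failed; give up without a result.
```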
Recommended solution

1. Use the original `_read_list` and check whether the result is a `RayTaskError` after each read operation. This way, some cases will return earlier if `fail` is read before `sleep` (see the sketch below).
2. For cases not covered by (1), wait for a timeout.

The drawback is that users need to specify the timeout themselves.
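A minimal sketch of option (1), keeping the existing sequential loop and only adding the error check (whether to raise immediately or return the error for the caller to raise is a design choice; this is not the actual patch):

```python
from ray.exceptions import RayTaskError

def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
    results = []
    for channel in self._input_channels:
        result = channel.read(timeout)
        # Surface the failure as soon as it is read instead of blocking on
        # the remaining channels.
        if isinstance(result, RayTaskError):
            raise result
        results.append(result)
    return results
```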
> Use the original `_read_list` and check whether the result is a `RayTaskError` after each read operation. This way, some cases will return earlier if `fail` is read before `sleep`.

Does it work when you do `dag = MultiOutputNode([a.sleep.bind(inp), b.fail.bind(inp)])`?
I think what we need is to try batch waiting with a short timeout (and keep checking until every object is ready), similar to how `ray.get` is implemented in Ray (see `Status CoreWorkerMemoryStore::GetImpl`).
> Does it work when you do `dag = MultiOutputNode([a.sleep.bind(inp), b.fail.bind(inp)])`?

No.
> I think what we need is to try batch waiting with a short timeout (and keep checking until every object is ready), similar to how `ray.get` is implemented in Ray (see `Status CoreWorkerMemoryStore::GetImpl`).

Are you referring to https://sourcegraph.com/github.com/ray-project/ray@6309e4be65fe94ed9489f3b035a6ba1215e71095/-/blob/src/ray/core_worker/store_provider/memory_store/memory_store.cc?L364-L375?
Do you mean:

- Using multiple threads to read each object in `MultiOutputNode`.
- Specifying a shorter timeout for `channel.read`.
- If `channel.read` times out, check whether it exceeds the timeout specified by the user. If not, retry (see the sketch below).
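To make points (2) and (3) concrete, a single-loop sketch might look like the following (the poll interval and the `RayChannelTimeoutError` exception name are assumptions, and this variant leaves the multi-threading question aside):

```python
import time

from ray.exceptions import RayChannelTimeoutError, RayTaskError

def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
    poll_interval = 0.1  # assumed short per-read timeout
    deadline = None if timeout is None else time.monotonic() + timeout
    results: List[Any] = [None] * len(self._input_channels)
    pending = set(range(len(self._input_channels)))
    while pending:
        for idx in list(pending):
            remaining = poll_interval
            if deadline is not None:
                remaining = min(poll_interval, max(deadline - time.monotonic(), 0))
            try:
                result = self._input_channels[idx].read(timeout=remaining)
            except RayChannelTimeoutError:
                # Keep retrying until the user-specified timeout is exceeded.
                if deadline is not None and time.monotonic() >= deadline:
                    raise
                continue
            if isinstance(result, RayTaskError):
                raise result  # Fail fast without waiting for the other channels.
            results[idx] = result
            pending.discard(idx)
    return results
```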