[Blocking Issue][HIP graph]: HIP error: operation not permitted when stream is capturing
Problem Description
Hi, developers,
We are AMD developers from the AI Group. We use vLLM to run LLM models and hit a HIP graph issue during HIP graph capture. This is currently a blocking issue for us; part of the log is shown below:
[rank0]:[E1030 15:32:32.970893361 ProcessGroupNCCL.cpp:2055] [PG ID 5 PG GUID 51 Rank 0] Process group watchdog thread terminated with exception: HIP error: operation not permitted when stream is capturing
Search for `hipErrorStreamCaptureUnsupported' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__HIPRT__TYPES.html for more information.
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Exception raised from c10_hip_check_implementation at /app/pytorch/c10/hip/HIPException.cpp:45 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9c (0x7f789de5f1bc in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x374e1 (0x7f78cec984e1 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so)
frame #2: c10::hip::c10_hip_check_implementation(int, char const*, char const*, int, bool) + 0x1f1 (0x7f78cec98371 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x5e (0x7f78d1bbddde in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x90 (0x7f78d1bcdc90 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so)
frame #5: c10d::ProcessGroupNCCL::Watchdog::runLoop() + 0x9de (0x7f78d1bd1a3e in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so)
frame #6: c10d::ProcessGroupNCCL::Watchdog::run() + 0x117 (0x7f78d1bd3d27 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so)
frame #7: <unknown function> + 0xdc253 (0x7f789c2ef253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f78e5d60ac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7f78e5df1a74 in /lib/x86_64-linux-gnu/libc.so.6)
The core error is HIP error: operation not permitted when stream is capturing, so we want to know which op cannot be captured by the HIP runtime. How can we debug such an error?
Thank you.
Operating System
Ubuntu 22.04
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD MI355 * 8
ROCm Version
ROCm 7.0.1
ROCm Component
HIP
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response
Here are the logs I get with the logging environment flags. How can I get more info?
51884395::3:hip_event.cpp :487 : 547211151990 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884402::3:hip_event.cpp :494 : 547211152018 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884426::3:hip_event.cpp :487 : 547211152050 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884428::3:hip_event.cpp :494 : 547211152061 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884472::3:hip_event.cpp :487 : 547211152132 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884476::3:hip_event.cpp :494 : 547211152139 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884502::3:hip_event.cpp :487 : 547211152177 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884506::3:hip_event.cpp :494 : 547211152181 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884509::3:hip_event.cpp :487 : 547211152186 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884533::3:hip_event.cpp :487 : 547211152212 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884540::3:hip_event.cpp :494 : 547211152217 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884572::3:hip_event.cpp :494 : 547211152192 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884596::3:hip_event.cpp :487 : 547211152265 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884598::3:hip_event.cpp :487 : 547211152282 us: ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884599::3:hip_event.cpp :494 : 547211152283 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
51884603::3:hip_event.cpp :494 : 547211152288 us: hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
Hi, @amd-nicknick @ppanchad-amd Could you help take a look? Or do you know who may be familiar with HIP graph issues? Thank you.
I dug deeper into the issue and filed it against RCCL: https://github.com/ROCm/rccl/issues/2022
@zejunchen-zejun you can get the full log using AMD_LOG_LEVEL=4 with minimal iterations (so that the log stays small) to see which APIs used around the graph capture are returning errors. From the snippet, I can see that hipEventQuery is returning an error, and this API is not permitted while a stream is capturing in global mode.
In ROCm 7.0, some of the error codes for operations that are not permitted during stream capture were changed to match CUDA.
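For example, a run to collect that log might look like the following (the script name and log file name here are just placeholders; the output is very verbose, so keep the workload to a few iterations):
AMD_LOG_LEVEL=4 python -u your_vllm_launch.py > amd_log_level4.txt 2>&1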
Updated my current findings in https://github.com/ROCm/rccl/issues/2022. It looks like an RCCL problem for now; let's keep tracking it there.
Hi, @satyanveshd @amd-nicknick
Let's track the issue here; I think RCCL is fine for now.
On our application side, the code below calls torch.dist.all_reduce entirely outside of any CUDA graph capture context, so this op should not be captured at all. However, the error operation not permitted when stream is capturing comes from this torch.dist.all_reduce op. It looks like the HIP runtime is not maintaining the HIP graph capture state correctly and tries to capture an op it shouldn't. https://github.com/ROCm/vllm/blob/dev/perf/vllm/v1/worker/dp_utils.py#L52
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
dist.all_reduce(tensor, group=group)
return tensor
When I add the code below to check whether the CUDA graph is capturing, it prints False, which is expected. Yet the dist.all_reduce is still unexpectedly treated as captured by the CUDA graph, and that is where the error happens.
tensor[1][dp_rank] = padded_num_tokens_per_ubatch
tensor[2][dp_rank] = 1 if should_ubatch else 0
tensor[3][dp_rank] = 1 if should_dp_pad else 0
+ print('[zejun] is capture = ', torch.cuda.is_current_stream_capturing(), flush=True)
dist.all_reduce(tensor, group=group)
+ print('[zejun] is capture = ', torch.cuda.is_current_stream_capturing(), flush=True)
return tensor
Thus, we wonder why this op is wrongly captured by the CUDA graph. Do you have any comments? I suspect the HIP runtime does not maintain the CUDA graph capture status correctly. @wuhuikx
Thank you.
Update on current debug status:
The offending call in vLLM is actually not NCCLInitComm, but a hipEventQuery during graph capture. This explains why the standalone reproducer cannot duplicate the failure while vLLM still can.
Logs collected with VLLM:
:3:hip_event.cpp :489 : 5087797112998 us: [pid:5789 tid: 0x768a00e4b6c0] hipEventQuery ( event:0x527a9ea0 )
:3:hip_event.cpp :483 : 5087797113036 us: [pid:5789 tid: 0x768a00e4b6c0] ihipEventQuery: Returned hipErrorStreamCaptureUnsupported :
:3:hip_event.cpp :490 : 5087797113043 us: [pid:5789 tid: 0x768a00e4b6c0] hipEventQuery: Returned hipErrorStreamCaptureUnsupported :
The invocation came from c10d's watchdog over the NCCL work queue.
https://github.com/pytorch/pytorch/blob/c297b02f12f7fd33bb47447c336acc6a78738a62/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L659
By getting rid of the work queueing and watchdog mechanism with TORCH_NCCL_BLOCKING_WAIT=true, I am able to get a positive result.
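For reference, a minimal sketch of how that flag can be applied (not vLLM's actual integration; the variable just needs to be visible before the NCCL process group is created, so exporting it in the launching shell works equally well):
import os

# Read when the NCCL process group is constructed, so it must be set before
# dist.init_process_group() runs; exporting it in the launching shell is equivalent.
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "true"

import torch.distributed as dist
# ... then dist.init_process_group(backend="nccl", ...) as usual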
Next step: construct a simpler reproducer to further investigate the NCCL work item. When a stream is capturing, the work shouldn't be pushed into the watchdog queue, since querying is not allowed.
It's also curious that Nvidia doesn't encounter this issue, since both HIP and CUDA disallow event queries while the underlying stream is capturing.
Hi, @amd-nicknick Thank you for the help. We have verified the env flag and it works! With this flag, hipEventQuery is no longer called from the torch.dist op, and our application launches successfully. It makes perfect sense.

> It's also curious that Nvidia doesn't encounter this issue, since both HIP and CUDA disallow event queries while the underlying stream is capturing.

We suspect we still need to figure out why this torch dist op and its underlying HIP runtime calls end up inside the CUDA graph capture context. They shouldn't be, yet the CUDA graph unexpectedly tries to capture them.
Thank you! We really appreciate your effort on this.
Hi @zejunchen-zejun,
Could you please help test the following script on an Nvidia system? This script reproduces the problem on the AMD platform; we just want to make sure the script passes on NV.
import os
import sys
import time
import torch
import torch.distributed as dist
import multiprocessing as mp

def worker(rank, world_size):
    print(f"Rank {rank}: PID {os.getpid()}")
    #time.sleep(5)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29502'
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    backend = "nccl"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    tensor = torch.ones((65536, 65536), device=device, dtype=torch.float32)
    group = dist.group.WORLD
    graph = torch.cuda.CUDAGraph()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        print(f"Rank {rank}: Begin NCCL reduce", file=sys.stderr)
        for i in range(1,100):
            dist.all_reduce(tensor, group=group)
        print(f"Rank {rank}: End NCCL reduce", file=sys.stderr)
        try:
            with torch.cuda.graph(graph):
                print(f"Rank {rank}: Begin Capture", file=sys.stderr)
                time.sleep(3)
            print(f"Rank {rank}: Unexpected success", file=sys.stderr)
            print(f"Rank {rank}: {tensor}", file=sys.stderr)
        except Exception as e:
            print(f"Rank {rank}: Expected failure - {e}", file=sys.stderr)
    dist.destroy_process_group()

def main():
    print(os.getpid())
    #time.sleep(5)
    world_size = 1
    if torch.cuda.device_count() < world_size:
        print(f"Error: {world_size} GPUs required, but only {torch.cuda.device_count()} available")
        return
    ctx = mp.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(target=worker, args=(rank, world_size))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

if __name__ == "__main__":
    main()
It would be great if you could adjust world_size and test the 1-card and 8-card combinations, as this issue is quite timing sensitive.
(Detail on why it's timing sensitive: the violation comes from the NCCL watchdog thread, which runs periodically. The loop doing all_reduce piles up unfinished NCCL work for the watchdog to dequeue, and the dequeue must not finish before the stream begins capture in order to reproduce the violation.)
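If the race turns out to be hard to hit on a given machine, one knob that should widen the window (my suggestion, not something the script above depends on) is to keep more NCCL work queued when capture starts, e.g. by raising the all_reduce loop count:
for i in range(1, 1000):  # more queued collectives leave more work for the watchdog to drain
    dist.all_reduce(tensor, group=group)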
Hi, @amd-nicknick
Thank you for the help! First of all, we will run your reproducer on B200 with world sizes from 1 to 8 and check whether NV reports the same error. Your suspicion also makes perfect sense: previously we attached rocgdb and ran the model to try to get any useful information when the error happens, but the issue disappeared under those conditions, so we also suspect the issue is timing sensitive. You mean NCCL has a watchdog thread which triggers the HIP event query during CUDA graph capture? We agree that could happen: NCCL queries the HIP event and the CUDA graph capture fails on this unsupported runtime call.
BTW, does this NCCL watchdog thread only run on ROCm devices? I think the watchdog behavior should be the same for NV and ROCm. Is it possible that NV has some specific mechanism that blocks the runtime event query during CUDA graph capture?
for i in range(1,100):
    dist.all_reduce(tensor, group=group)
print(f"Rank {rank}: End NCCL reduce", file=sys.stderr)
try:
    with torch.cuda.graph(graph):
        print(f"Rank {rank}: Begin Capture", file=sys.stderr)
        time.sleep(3)
    print(f"Rank {rank}: Unexpected success", file=sys.stderr)
    print(f"Rank {rank}: {tensor}", file=sys.stderr)
Let me try it on B200 and give feedback here this week.
The watchdog behavior on NV vs AMD is exactly the same, so if the reproducer does not fail on NV, it is likely a HIP runtime problem we need to align. I think a lot of conditions need to be in place for this error to trigger:
- A non-null stream context must be configured and in use for NCCL to run "synchronously" before capture (see the `with torch.cuda.stream(stream):` block being moved to encapsulate the NCCL ops).
- A thread separate from the one that launched the NCCL task queries that task on the same stream.
- Before everything gets cleared out on the watchdog thread, graph capture begins on the original thread & stream.
My suspicion is that HIP is not treating the per-thread capture mode flag correctly when the `hipEventQuery` occurred.
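To make that concrete, here is a minimal two-thread sketch of the pattern (my illustration under the assumptions above, not the reproducer script; whether the polling thread's query is rejected should depend on how the runtime applies the per-thread capture mode):
import threading
import time
import torch

def event_poller(event, stop):
    # Stands in for the c10d watchdog: a separate thread polling an event
    # that was recorded before capture started.
    while not stop.is_set():
        try:
            event.query()
        except RuntimeError as exc:
            print(f"event.query() raised: {exc}")
            return
        time.sleep(0.001)

def main():
    torch.cuda.set_device(0)
    stream = torch.cuda.Stream()
    event = torch.cuda.Event()

    with torch.cuda.stream(stream):
        x = torch.randn(4096, 4096, device="cuda")
        for _ in range(20):
            x = x @ x            # queue some work on the side stream
        event.record(stream)     # event pending on that stream, like queued NCCL work

    stop = threading.Event()
    poller = threading.Thread(target=event_poller, args=(event, stop), daemon=True)
    poller.start()

    graph = torch.cuda.CUDAGraph()
    try:
        # Default capture_error_mode is "global"; the question is whether the
        # poller thread's event query (hipEventQuery / cudaEventQuery) is
        # still permitted while this capture is in progress.
        with torch.cuda.graph(graph):
            time.sleep(1)        # keep the capture open long enough for the poller to fire
    finally:
        stop.set()
        poller.join()
    print("capture finished")

if __name__ == "__main__":
    main()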
Thanks for helping to verify the cause. Kindly let me know the result and we can see where to proceed from here :)
Hi, @amd-nicknick @satyanveshd Sorry for the late response. I verified the reproducer on the B200 machine and here is the log:
root@GPUA81E:/home/zejchen/graph_issue# python -u reproducer.py
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
2702
/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
Rank 0: PID 2768
Rank 0: Begin NCCL reduce
Rank 0: End NCCL reduce
Rank 0: Begin Capture
/usr/local/lib/python3.12/dist-packages/torch/cuda/graphs.py:128: UserWarning: The CUDA Graph is empty. This usually means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:138.)
super().capture_end()
Rank 0: Unexpected success
Rank 0: tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0')
[W1126 13:53:46.148684455 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[W1126 13:53:47.778238721 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
It looks like CUDA does not fail! The script finished successfully.
So the problem could be HIP runtime specific. It looks like the HIP runtime cannot distinguish the threads correctly. There are two threads: one is doing graph capture, the other is the watchdog querying event status. The HIP graph starts capturing, but it tries to capture the hipEventQuery that is called from the other thread, and then the error happens. Is my understanding right?
How can we further debug the issue? Or whom can we mention here to take a look at the graph capture behavior?
Thank you.
I think the thread-local capture mode is indeed ignored in some cases. Added a fix here: https://github.com/ROCm/rocm-systems/pull/2177
Hi, @iassiour @amd-nicknick
Thank you so much for providing a fix. May I know if the fix passes the reproducer here? Also, when will we get your fix? In the next ROCm release?
Thank you.
Hi @wuhuikx @ganyi1996ppo With the fix here, there is no need for the workaround on the framework side for DP mode! https://github.com/ROCm/vllm/blob/dev/perf/vllm/v1/worker/dp_utils.py#L54-L64
@zejunchen-zejun Yes, the fix passes the reproducer. I think it is unlikely to be included in the next ROCm release, but it should be there in future releases after that.
Thank you @iassiour When it is included in a ROCm release, we will verify your fix and remove the workaround at the application level. Thank you for your help!