[Bug]: vLLM sleep experiences segmentation fault when used in TRL
Your current environment
I am using vLLM sleep mode in Hugging Face TRL to manage GPU memory efficiently between training and generation. See my draft PR. Training completes successfully, but a segmentation fault occurs at the end. I do not see this error with vLLM==0.7.3, but I do see it with 0.8.0, 0.8.1, and later versions.
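Based only on the versions reported here (a clean exit on vLLM 0.7.3, a crash on 0.8.0 and later), a launch script could flag the risky combination before starting a long run. This is a minimal sketch: `parse_version`, `sleep_mode_exit_crash_likely`, and the `(0, 8, 0)` threshold are hypothetical and drawn from this report, not from any documented vLLM compatibility rule.

```python
# Hypothetical threshold taken from this report: 0.7.3 exits cleanly, 0.8.0+ crashes at exit.
FIRST_AFFECTED = (0, 8, 0)


def parse_version(text: str) -> tuple[int, int, int]:
    """Return the leading major.minor.patch numbers, e.g. '0.8.5.post1' -> (0, 8, 5)."""
    numbers: list[int] = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or '+cu121'
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
    while len(numbers) < 3:
        numbers.append(0)  # treat '0.8' as '0.8.0'
    return (numbers[0], numbers[1], numbers[2])


def sleep_mode_exit_crash_likely(vllm_version: str, enable_sleep_mode: bool) -> bool:
    """True when sleep mode is on and the version falls in the range reported as affected."""
    return enable_sleep_mode and parse_version(vllm_version) >= FIRST_AFFECTED
```

The installed version string can be obtained with `importlib.metadata.version("vllm")` and passed in as `vllm_version`.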
The logs of my run `err.log`
2025-04-22 15:34:26 - INFO - __main__ - *** Save model ***
[INFO|trainer.py:3984] 2025-04-22 15:34:28,362 >> Saving model checkpoint to trainer_output
[INFO|configuration_utils.py:419] 2025-04-22 15:34:28,366 >> Configuration saved in trainer_output/config.json
[INFO|configuration_utils.py:911] 2025-04-22 15:34:28,367 >> Configuration saved in trainer_output/generation_config.json
[rank4]:[W422 15:34:30.032759954 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W422 15:34:30.076030919 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W422 15:34:30.136720629 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[INFO|modeling_utils.py:3572] 2025-04-22 15:34:31,111 >> Model weights saved in trainer_output/model.safetensors
[INFO|tokenization_utils_base.py:2510] 2025-04-22 15:34:31,113 >> tokenizer config file saved in trainer_output/tokenizer_config.json
[INFO|tokenization_utils_base.py:2519] 2025-04-22 15:34:31,114 >> Special tokens file saved in trainer_output/special_tokens_map.json
2025-04-22 15:34:31 - INFO - __main__ - Model saved to trainer_output
[INFO|configuration_utils.py:419] 2025-04-22 15:34:31,369 >> Configuration saved in trainer_output/config.json
[grpo-experiment-mert-experiments-master-0:19399:0:19399] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid: 19399) ====
0 0x0000000000042520 __sigaction() ???:0
=================================
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8f2196c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f8f21915b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7f8ed248e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7f8f21d45b78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7f8f21d4620e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7f8f21d5cafa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7f8f21d48329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7f8f19a864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7f8f191a5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7f8f191a64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x55f46de26fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x55f46dde9c52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x55f46ddea991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x55f46ddea78c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x55f46ddea877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x55f46de26aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x55f46de7ea87 in /usr/bin/python)
frame #17: <unknown function> + 0x129ebc (0x55f46ddddebc in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x55f46df18970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x55f46df144c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x55f46df05913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x55f46dedc02d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7f8f5a34bd90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7f8f5a34be40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x55f46dedbf25 in /usr/bin/python)
[grpo-experiment-mert-experiments-master-0:19394:0:19394] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
terminate called after throwing an instance of 'c10::Error'
[grpo-experiment-mert-experiments-master-0:19398:0:19398] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x400)
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fee9536c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7fee95315b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7fee45e8e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7fee95745b78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7fee9574620e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7fee9575cafa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7fee95748329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7fee8d4864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7fee8cba5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7fee8cba64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x564c79c2afd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x564c79bedc52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x564c79bee991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x564c79bee78c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x564c79bee877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x564c79c2aaa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x564c79c82a87 in /usr/bin/python)
frame #17: <unknown function> + 0x129ebc (0x564c79be1ebc in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x564c79d1c970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x564c79d184c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x564c79d09913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x564c79ce002d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7feecdd44d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7feecdd44e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x564c79cdff25 in /usr/bin/python)
[grpo-experiment-mert-experiments-master-0:19395:0:19395] Caught signal 11 (Segmentation fault: Sent by the kernel at address (nil))
==== backtrace (tid: 19394) ====
0 0x0000000000042520 __sigaction() ???:0
=================================
[grpo-experiment-mert-experiments-master-0:19397:0:19397] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid: 19395) ====
0 0x0000000000042520 __sigaction() ???:0
1 0x0000000000022b76 c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::release_block() CUDACachingAllocator.cpp:0
2 0x000000000002320e c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::release_blocks() CUDACachingAllocator.cpp:0
3 0x0000000000039afa c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::release_cached_blocks() :0
4 0x0000000000025329 c10::cuda::MemPool::~MemPool() ???:0
5 0x0000000000df74f0 pybind11::class_<c10::cuda::MemPool, std::shared_ptr<c10::cuda::MemPool> >::dealloc() :0
6 0x0000000000516907 pybind11::detail::clear_instance() :0
7 0x00000000005174d1 pybind11_object_dealloc() :0
8 0x0000000000172fd1 PyObject_DelItem() ???:0
9 0x0000000000135c52 _Py_CheckFunctionResult() ???:0
10 0x0000000000136991 _Py_CheckFunctionResult() ???:0
11 0x000000000013678c _Py_CheckFunctionResult() ???:0
12 0x0000000000136877 _Py_CheckFunctionResult() ???:0
13 0x0000000000172aa0 PyObject_DelItem() ???:0
14 0x00000000001caa87 PyDict_Clear() ???:0
15 0x0000000000129ebc PyObject_GC_Del() ???:0
16 0x0000000000264970 PyMarshal_ReadLongFromFile() ???:0
17 0x00000000002604c8 Py_FinalizeEx() ???:0
18 0x0000000000251913 Py_RunMain() ???:0
19 0x000000000022802d Py_BytesMain() ???:0
20 0x0000000000029d90 __libc_init_first() ???:0
21 0x0000000000029e40 __libc_start_main() ???:0
22 0x0000000000227f25 _start() ???:0
=================================
==== backtrace (tid: 19397) ====
0 0x0000000000042520 __sigaction() ???:0
=================================
[rank0]:[W422 15:34:34.190553574 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[grpo-experiment-mert-experiments-master-0:19392:0:19392] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid: 19392) ====
0 0x0000000000042520 __sigaction() ???:0
=================================
🐛 Describe the bug
A segmentation fault occurs after distributed training completes.
I tried to reproduce it with a simpler script (see below). I don't get the segmentation fault there, but I do see the following warnings, which may be a hint:
[rank0]:[W422 15:49:17.205132298 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
Simple vLLM inference script `inf.py`:
from vllm import LLM, SamplingParams
import time
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=2,
distributed_executor_backend="external_launcher",
enable_sleep_mode=True,
seed=1
)
outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
llm.sleep(level=2)
CC @youkaichao, @fabianlim, @fingertap
I also experienced this. To reproduce, add a single line to examples/offline_inference/torchrun_example.py:
from vllm import LLM, SamplingParams
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-125m",
tensor_parallel_size=2,
distributed_executor_backend="external_launcher",
seed=0,
enable_sleep_mode=True  # Add this to trigger c10::Error
)
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}\n"
          f"Generated text: {generated_text!r}")
    print("-" * 50)
I did not get a full trace of this. It seems that the cuMemAllocator frees some memory before PyTorch attempts to free it again.
This happens during the cleanup phase, after the script has finished. My trace is slightly different from the one in this issue, but similar: PyTorch complains that it is trying to free memory that was not allocated by torch.
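The double-free hypothesis above can be illustrated with a pure-Python toy model. Nothing here is vLLM or PyTorch code: `ToyPluggableAllocator`, `ToyMemPool`, `release_all`, and `simulate` are hypothetical stand-ins that only mimic the ownership check behind the `Trying to free a pointer not allocated here` message raised from `CUDAPluggableAllocator::raw_delete`.

```python
class ToyPluggableAllocator:
    """Tracks live pointers and refuses to free one it does not own (like raw_delete's check)."""

    def __init__(self) -> None:
        self._live: set[int] = set()
        self._next = 0x1000

    def malloc(self, size: int) -> int:
        ptr = self._next
        self._next += size
        self._live.add(ptr)
        return ptr

    def free(self, ptr: int) -> None:
        if ptr not in self._live:
            raise RuntimeError("Trying to free a pointer not allocated here")
        self._live.remove(ptr)

    def release_all(self) -> None:
        # Hypothetical: the allocator drops every allocation behind the pool's back.
        self._live.clear()


class ToyMemPool:
    """Caches blocks from the allocator and hands them back when destroyed."""

    def __init__(self, allocator: ToyPluggableAllocator) -> None:
        self.allocator = allocator
        self.cached_blocks: list[int] = []

    def allocate(self, size: int) -> int:
        ptr = self.allocator.malloc(size)
        self.cached_blocks.append(ptr)
        return ptr

    def destroy(self) -> None:
        # Loosely mirrors ~MemPool -> release_cached_blocks -> raw_delete in the traceback.
        for ptr in self.cached_blocks:
            self.allocator.free(ptr)
        self.cached_blocks.clear()


def simulate(allocator_cleans_up_first: bool) -> str:
    allocator = ToyPluggableAllocator()
    pool = ToyMemPool(allocator)
    pool.allocate(256)
    pool.allocate(512)
    if allocator_cleans_up_first:
        allocator.release_all()
    try:
        pool.destroy()
    except RuntimeError as err:
        return str(err)
    return "ok"
```

In the toy, the error is a catchable exception. In the real process it is thrown from a C++ destructor during `Py_FinalizeEx`, and an exception escaping a destructor calls `std::terminate`, which is consistent with the `terminate called after throwing an instance of 'c10::Error'` lines in the logs.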
@youkaichao Can you give this mini-repro a try and see if it also reproduces the error on your side?
@fingertap I cannot reproduce it (I'm using Python 3.12).
At first glance, it might be related to the order in which Python objects are garbage-collected.
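If the crash depends on teardown order during interpreter shutdown, one general mitigation is to release objects explicitly, in a chosen order, before the script returns, instead of leaving them to `Py_FinalizeEx` (which appears in every traceback above). The sketch below is a pure-Python illustration with a hypothetical `Resource` class that logs when it is finalized; it shows the technique (explicit `del` plus `gc.collect()` in a fixed order), not a confirmed fix for vLLM.

```python
import gc

# Records the order in which the hypothetical resources are finalized.
teardown_log: list[str] = []


class Resource:
    """Hypothetical stand-in for an object whose finalizer releases GPU memory."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __del__(self) -> None:
        teardown_log.append(self.name)


def run_and_shutdown(order: tuple[str, ...]) -> list[str]:
    """Create the resources, then release them explicitly in the given order."""
    teardown_log.clear()
    resources = {"llm_engine": Resource("llm_engine"), "mem_pool": Resource("mem_pool")}
    # ... generation or training would run here ...
    for name in order:
        # Dropping the last reference finalizes the object immediately in CPython;
        # gc.collect() also reclaims any reference cycles before the next release.
        del resources[name]
        gc.collect()
    return list(teardown_log)
```

In the repro scripts in this thread, the analogous step would be `del llm` followed by `torch.distributed.destroy_process_group()` (which the NCCL warning asks for) before the script exits; whether that avoids this particular crash is untested.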
@fingertap, I only get a warning when I run your mini repro with `torchrun --nproc-per-node=2 inf.py`, on both Python 3.10 and Python 3.12.
Warning:
[rank0]:[W428 20:32:01.845828927 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
But the bug is real: I see it consistently at the end of every TRL/GRPO training run with sleep mode enabled (launched with accelerate).
See the logs of my run `err.log`
2025-04-26 23:11:53 - INFO - __main__ - *** Save model ***
[INFO|trainer.py:3978] 2025-04-26 23:14:40,943 >> Saving model checkpoint to /workspace/data/experiment/mert72bsleep1new/results/hf
[INFO|configuration_utils.py:420] 2025-04-26 23:14:40,956 >> Configuration saved in /workspace/data/experiment/mert72bsleep1new/results/hf/config.json
[INFO|configuration_utils.py:917] 2025-04-26 23:14:40,962 >> Configuration saved in /workspace/data/experiment/mert72bsleep1new/results/hf/generation_config.json
terminate called after throwing an instance of 'c10::Error'
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fdb1a36c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7fdb1a315b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7fdacb28e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7fdb1a76bb78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7fdb1a76c20e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7fdb1a782afa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7fdb1a76e329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7fdb128864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7fdb11fa5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7fdb11fa64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x562c85f01fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x562c85ec4c52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x562c85ec5991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x562c85ec578c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x562c85ec5877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x562c85f01aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x562c85f59a87 in /usr/bin/python)
frame #17: <unknown function> + 0x129fa1 (0x562c85eb8fa1 in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x562c85ff3970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x562c85fef4c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x562c85fe0913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x562c85fb702d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7fdb52f58d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7fdb52f58e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x562c85fb6f25 in /usr/bin/python)
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8a6896c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f8a68915b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7f8a1988e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7f8a68d6bb78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7f8a68d6c20e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7f8a68d82afa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7f8a68d6e329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7f8a60e864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7f8a605a5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7f8a605a64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x5604fc987fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x5604fc94ac52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x5604fc94b991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x5604fc94b78c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x5604fc94b877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x5604fc987aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x5604fc9dfa87 in /usr/bin/python)
frame #17: <unknown function> + 0x129fa1 (0x5604fc93efa1 in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x5604fca79970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x5604fca754c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x5604fca66913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x5604fca3d02d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7f8aa15dad90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7f8aa15dae40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x5604fca3cf25 in /usr/bin/python)
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fd022f6c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7fd022f15b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7fcfd3e8e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7fd02336bb78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7fd02336c20e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7fd023382afa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7fd02336e329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7fd01b4864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7fd01aba5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7fd01aba64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x5609d47d0fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x5609d4793c52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x5609d4794991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x5609d479478c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x5609d4794877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x5609d47d0aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x5609d4828a87 in /usr/bin/python)
frame #17: <unknown function> + 0x129fa1 (0x5609d4787fa1 in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x5609d48c2970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x5609d48be4c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x5609d48af913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x5609d488602d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7fd05bbcbd90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7fd05bbcbe40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x5609d4885f25 in /usr/bin/python)
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f9cbef6c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f9cbef15b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7f9c6fa8e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7f9cbf345b78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7f9cbf34620e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7f9cbf35cafa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7f9cbf348329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7f9cb70864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7f9cb67a5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7f9cb67a64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x56066e6c6fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x56066e689c52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x56066e68a991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x56066e68a78c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x56066e68a877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x56066e6c6aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x56066e71ea87 in /usr/bin/python)
frame #17: <unknown function> + 0x129fa1 (0x56066e67dfa1 in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x56066e7b8970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x56066e7b44c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x56066e7a5913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x56066e77c02d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7f9cf784cd90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7f9cf784ce40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x56066e77bf25 in /usr/bin/python)
[grpo-experiment-mert-mert72bsleep1new-master-0:934 :0:934] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
[grpo-experiment-mert-mert72bsleep1new-master-0:939 :0:939] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
==== backtrace (tid: 934) ====
0 0x0000000000042520 __sigaction() ???:0
=================================
==== backtrace (tid: 939) ====
0 0x0000000000042520 __sigaction() ???:0
=================================
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f9ac3f6c1b6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f9ac3f15b3f in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7f9a74e8e667 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7f9ac436bb78 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7f9ac436c20e in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39afa (0x7f9ac4382afa in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7f9ac436e329 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdf74f0 (0x7f9abc4864f0 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x516907 (0x7f9abbba5907 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5174d1 (0x7f9abbba64d1 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x172fd1 (0x557fa16c1fd1 in /usr/bin/python)
frame #11: <unknown function> + 0x135c52 (0x557fa1684c52 in /usr/bin/python)
frame #12: <unknown function> + 0x136991 (0x557fa1685991 in /usr/bin/python)
frame #13: <unknown function> + 0x13678c (0x557fa168578c in /usr/bin/python)
frame #14: <unknown function> + 0x136877 (0x557fa1685877 in /usr/bin/python)
frame #15: <unknown function> + 0x172aa0 (0x557fa16c1aa0 in /usr/bin/python)
frame #16: <unknown function> + 0x1caa87 (0x557fa1719a87 in /usr/bin/python)
frame #17: <unknown function> + 0x129fa1 (0x557fa1678fa1 in /usr/bin/python)
frame #18: <unknown function> + 0x264970 (0x557fa17b3970 in /usr/bin/python)
frame #19: Py_FinalizeEx + 0x148 (0x557fa17af4c8 in /usr/bin/python)
frame #20: Py_RunMain + 0x173 (0x557fa17a0913 in /usr/bin/python)
frame #21: Py_BytesMain + 0x2d (0x557fa177702d in /usr/bin/python)
frame #22: <unknown function> + 0x29d90 (0x7f9afcb87d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #23: __libc_start_main + 0x80 (0x7f9afcb87e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: _start + 0x25 (0x557fa1776f25 in /usr/bin/python)
OK. I think the issue is on the TRL side. My issue may have been due to my local environment. I no longer see any errors in my own RL architecture.
@fingertap @toslali-ibm @youkaichao I was able to reproduce the bug with vLLM alone, in an A100 environment on Python 3.12. Therefore, I do not believe this issue is caused by other codebases (e.g., TRL); you can reproduce the problem with only vLLM using the commands below.
- Note that the reproduction steps set VLLM_USE_V1=0. I found by chance that the error appears very clearly with this setting. With V1, I believe the error is still there: it does not trigger the CUDA error, but the process hangs for a while and then raises an exception. Therefore, I think the problem also exists with V1.
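To compare runs with VLLM_USE_V1=0 against V1 (or across vLLM versions), it can help to classify each run's stderr by the failure signatures seen in this thread. The sketch below is a hypothetical helper, `classify_exit`; the signature strings are copied from the logs in this issue, and the category names are illustrative only.

```python
import re

# Substrings copied from the logs in this thread; category names are illustrative.
SIGNATURES = {
    "double_free": "Trying to free a pointer not allocated here",
    "mempool_teardown": "c10::cuda::MemPool::~MemPool()",
    "no_destroy_pg": "destroy_process_group() was not called before program exit",
}

SEGFAULT_RE = re.compile(r"Caught signal 11 \(Segmentation fault")


def classify_exit(stderr_text: str) -> dict[str, int]:
    """Count how often each known failure signature appears in a run's stderr."""
    counts = {name: stderr_text.count(sig) for name, sig in SIGNATURES.items()}
    counts["segfault"] = len(SEGFAULT_RE.findall(stderr_text))
    return counts
```

A hang like the one described for V1 leaves no signature in stderr; catching it would need a timeout on the launching process.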
Reproduction steps:
# create virtual environment and install stuff
pip install virtualenv
python -m virtualenv $HOME/venv/open-r1-mert
source $HOME/venv/open-r1-mert/bin/activate
pip install vllm==0.8.5
## RUN this command
VLLM_USE_V1=0 torchrun --nproc-per-node=1 -m inf
where `inf` refers to the simplified script below:
Script
from vllm import LLM, SamplingParams
import time
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
llm = LLM(
model="facebook/opt-125m",
# tensor_parallel_size=1,
distributed_executor_backend="external_launcher",
enable_sleep_mode=True,
seed=1,
enforce_eager=True,
)
outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
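Because the same repro can either crash at exit (V0) or hang and later raise (V1, as described above), a launcher with a timeout keeps the two outcomes distinguishable. This is a minimal sketch using only the standard library; `run_repro` and `python_cmd` are hypothetical helpers, and the timeout value is up to the user.

```python
import os
import subprocess
import sys


def run_repro(cmd: list[str], use_v1: bool, timeout_s: float) -> tuple[str, str]:
    """Run one repro attempt with VLLM_USE_V1 set; return (status, stderr)."""
    env = {**os.environ, "VLLM_USE_V1": "1" if use_v1 else "0"}
    try:
        proc = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=timeout_s)
    except subprocess.TimeoutExpired as exc:
        # On timeout the child is killed; any captured stderr may be None or bytes.
        stderr = exc.stderr or ""
        if isinstance(stderr, bytes):
            stderr = stderr.decode(errors="replace")
        return "timeout", stderr
    # On POSIX, a negative return code means the process was killed by that signal.
    status = "ok" if proc.returncode == 0 else f"exit {proc.returncode}"
    return status, proc.stderr


def python_cmd(code: str) -> list[str]:
    """Build a command that runs a Python snippet; handy for checking run_repro without a GPU."""
    return [sys.executable, "-c", code]
```

For the repro above, the call would look like `run_repro(["torchrun", "--nproc-per-node=1", "-m", "inf"], use_v1=False, timeout_s=600)`, run once with `use_v1=False` and once with `use_v1=True`.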
Output
This gave:
Prompt: 'The president of the United States is', Generated text: ' declaring war on the business community in the name of the economy.\n\nIn'
Prompt: 'The capital of France is', Generated text: ' home to nearly two million people. Its population is around 23 million. The French'
Prompt: 'The future of AI is', Generated text: ' looking very bright, thanks to a handful of recent breakthroughs. Here are some'
[rank0]:[W429 16:41:38.807570414 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x149e316511b6 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x149e315fab3f in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x149e351ff667 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x149e31707b78 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x149e3170820e in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39b0d (0x149e3171eb0d in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x149e3170a329 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdffa90 (0x149e7c831a90 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x517c37 (0x149e7bf49c37 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x518881 (0x149e7bf4a881 in /gpfs/users/flim/venv/open-r1-mert/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #23: <unknown function> + 0x295d0 (0x149e84f1c5d0 in /lib64/libc.so.6)
frame #24: __libc_start_main + 0x80 (0x149e84f1c680 in /lib64/libc.so.6)
frame #25: _start + 0x25 (0x401075 in /gpfs/users/flim/venv/open-r1-mert/bin/python)
I can confirm that I'm able to reproduce the bug using @fabianlim's script. I also tested it across multiple vLLM versions (0.8.0, 0.8.1, 0.8.2, 0.8.3, 0.8.4, and 0.8.5), and the error persists in all of them. However, the issue does not occur in version 0.7.3, which suggests that the bug may have been introduced in 0.8.0.
Also, on my end (TRL), V1 also produces the c10 error.
I'm facing the same issue on version 0.8.2.
I have tried to clean up manually with an atexit handler:
import atexit
import gc

import torch
import torch.distributed as dist

# Ensure proper teardown
def cleanup():
    print("Goodbye, cleaning up...")
    try:
        if dist.is_initialized():
            dist.destroy_process_group()
        gc.collect()
        torch.cuda.empty_cache()
        print("---cleaned")
    except Exception as e:
        print(f"Error during cleanup: {e}")

atexit.register(cleanup)
But I am still getting the c10 error:
---cleaned
terminate called after throwing an instance of 'c10::Error'
what(): Trying to free a pointer not allocated here
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f434416c1b6 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f4344115b3f in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #2: torch::cuda::CUDAPluggableAllocator::CUDAPluggableAllocator::raw_delete(void*) + 0x1a7 (0x7f42f5087667 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so)
frame #3: <unknown function> + 0x22b78 (0x7f434456bb78 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x2320e (0x7f434456c20e in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #5: <unknown function> + 0x39b0d (0x7f4344582b0d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #6: c10::cuda::MemPool::~MemPool() + 0x1b9 (0x7f434456e329 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_cuda.so)
frame #7: <unknown function> + 0xdffa90 (0x7f433c687a90 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0x517c37 (0x7f433bd9fc37 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x518881 (0x7f433bda0881 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #10: /usr/bin/python() [0x59f6d3]
frame #11: /usr/bin/python() [0x5919fa]
frame #12: /usr/bin/python() [0x57637e]
frame #13: /usr/bin/python() [0x5760cc]
frame #14: /usr/bin/python() [0x579ed2]
frame #15: /usr/bin/python() [0x59f4e9]
frame #16: PyDict_Clear + 0x14e (0x57a9ae in /usr/bin/python)
frame #17: /usr/bin/python() [0x5a0ea1]
frame #18: /usr/bin/python() [0x61d88d]
frame #19: /usr/bin/python() [0x6bd379]
frame #20: Py_FinalizeEx + 0xad (0x6b136d in /usr/bin/python)
frame #21: Py_RunMain + 0x281 (0x6bcd61 in /usr/bin/python)
frame #22: Py_BytesMain + 0x2d (0x6bc97d in /usr/bin/python)
frame #23: <unknown function> + 0x2a1ca (0x7f43451011ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #24: __libc_start_main + 0x8b (0x7f434510128b in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #25: _start + 0x25 (0x6584a5 in /usr/bin/python)
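In the log above, ---cleaned is printed before the crash, and the trace runs through Py_FinalizeEx and PyDict_Clear. One plausible reading is that atexit handlers run before the interpreter clears module globals, so a handler like cleanup() finishes before the objects that own the memory pool are destroyed and does not control the order in which they are destroyed. The stdlib-only sketch below demonstrates that ordering in a child interpreter; it is not vLLM code, and the Owner class is a hypothetical stand-in for an object holding the pool.

```python
import subprocess
import sys
import textwrap

# Child program: registers an atexit handler and keeps a module-level object
# whose finalizer only runs when the interpreter tears down __main__'s globals.
CHILD = textwrap.dedent(
    """
    import atexit
    import os

    class Owner:  # hypothetical stand-in for an object holding a memory pool
        def __del__(self, _write=os.write):
            # os.write is bound as a default so it is still reachable at shutdown.
            _write(1, b"finalizer\\n")

    owner = Owner()
    atexit.register(lambda: os.write(1, b"atexit\\n"))
    """
)


def shutdown_order() -> list[str]:
    """Run CHILD in a fresh interpreter and return the lines it wrote, in order."""
    result = subprocess.run(
        [sys.executable, "-c", CHILD], capture_output=True, text=True, check=True
    )
    return result.stdout.split()


if __name__ == "__main__":
    print(shutdown_order())  # ['atexit', 'finalizer']
```

Because the atexit handler always runs first, extra cleanup registered there cannot reach objects that are only released later during module teardown.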
Hey @youkaichao, just wanted to follow up on this issue to see if you have any recommendations or pointers.
I ran into this error again. Do you know what may be causing it? @youkaichao
Exception raised from raw_delete at /pytorch/torch/csrc/cuda/CUDAPluggableAllocator.cpp:151 (most recent call first):
An error like this is usually caused by Python's garbage collection order, which is not guaranteed. Sleep mode depends on PyTorch's CUDAPluggableAllocator with a memory pool implementation, and the interaction between the memory pool and the tensors allocated from it is complicated.
One possible workaround is to reimplement sleep mode with a dispatch mode, so that we can easily hook all tensor allocations without using a memory pool.
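For context on the dispatch-mode idea, PyTorch provides TorchDispatchMode (in the private module torch.utils._python_dispatch), which intercepts every ATen op, including factory functions, and sees its outputs. The sketch below assumes that API and only records the storage behind each tensor created while the mode is active. AllocationRecorder and count_new_storages are hypothetical names, and a real sleep mode built this way would also need to offload and restore that memory, which this sketch does not attempt.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten


class AllocationRecorder(TorchDispatchMode):
    """Hypothetical helper: records the storage of every tensor an ATen op returns.

    It only observes allocations; it does not offload, discard, or restore memory.
    """

    def __init__(self):
        super().__init__()
        self.storage_ptrs = set()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Run the real op; the mode is not re-entered while this handler runs.
        out = func(*args, **(kwargs or {}))
        flat_outputs, _ = tree_flatten(out)
        for t in flat_outputs:
            if isinstance(t, torch.Tensor) and t.numel() > 0:
                # Views and in-place ops reuse an existing storage, so the set
                # only grows when an op returns a tensor backed by new memory.
                self.storage_ptrs.add(t.untyped_storage().data_ptr())
        return out


def count_new_storages() -> int:
    """Create two fresh tensors and one view under the mode; return distinct storages."""
    with AllocationRecorder() as rec:
        a = torch.ones(4)  # new storage
        b = a + a  # new storage
        b.view(2, 2)  # view of b: same storage, not counted again
    return len(rec.storage_ptrs)


print(count_new_storages())  # 2
```

Because the hook sees every op output rather than relying on a pool's destructor, tracking does not depend on the order in which Python later frees the pool and its tensors.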
one possible workaround is to reimplement the sleep mode with dispatch mode
Sounds like a lot of work. Can we skip the garbage collection for the tensors recorded in the allocator?
Can we skip the garbage collection for the tensors recorded in the allocator?
I'm not sure; you can give it a try and see if it works.
one possible workaround is to reimplement the sleep mode with dispatch mode
Sounds like a lot of work. Can we skip the garbage collection for the tensors recorded in the allocator?
If you give it a try, can you let me know as well?
This GC bug breaks my CI environment. Is there any progress?
This should be fixed by https://github.com/vllm-project/vllm/pull/23477