
[Core][AMD] Migrate fully transparent sleep mode to ROCm platform

Open HollowMan6 opened this issue 10 months ago • 46 comments

FIX #10714 for AMD GPUs

Related to #11743

HollowMan6 avatar Feb 03 '25 15:02 HollowMan6

👋 Hi! Thank you for contributing to the vLLM project. Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

github-actions[bot] avatar Feb 03 '25 15:02 github-actions[bot]

Hi @youkaichao! Thank you for your amazing work in #11743. I've checked all the APIs and found that all the CuMem calls used here have equivalent HIP versions, so I tried to migrate the feature to ROCm as well. I almost succeeded: all the code compiles, but we hit a strange OOM issue with hipMemAddressReserve when we try to allocate within allocator.use_memory_pool(). I just want to check with you / anyone familiar with this to see how we can resolve it.
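The migration is essentially a compatibility shim that maps the CUDA driver symbols onto their HIP counterparts. A rough sketch of that mapping, based on the public HIP headers, looks like this (illustration only; the actual compat header in this PR may remap a slightly different set of symbols):

// Sketch of a CUDA-to-HIP compatibility shim for the CuMem driver calls.
// Note: hipDeviceptr_t is a pointer type on the AMD platform, while
// CUdeviceptr is an integer on CUDA, so pointer arithmetic needs care.
#include <hip/hip_runtime.h>

#define CUresult                            hipError_t
#define CUDA_SUCCESS                        hipSuccess
#define CUdevice                            hipDevice_t
#define CUcontext                           hipCtx_t
#define CUdeviceptr                         hipDeviceptr_t
#define CUmemGenericAllocationHandle        hipMemGenericAllocationHandle_t
#define CUmemAllocationProp                 hipMemAllocationProp
#define CUmemAccessDesc                     hipMemAccessDesc
#define CU_MEM_ALLOCATION_TYPE_PINNED       hipMemAllocationTypePinned
#define CU_MEM_LOCATION_TYPE_DEVICE         hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE  hipMemAccessFlagsProtReadWrite

#define cuMemCreate                         hipMemCreate
#define cuMemRelease                        hipMemRelease
#define cuMemAddressReserve                 hipMemAddressReserve
#define cuMemAddressFree                    hipMemAddressFree
#define cuMemMap                            hipMemMap
#define cuMemUnmap                          hipMemUnmap
#define cuMemSetAccess                      hipMemSetAccess
#define cuMemGetAllocationGranularity       hipMemGetAllocationGranularity
#define cuDevicePrimaryCtxRetain            hipDevicePrimaryCtxRetain
#define cuCtxGetDevice                      hipCtxGetDevice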

For easier reproduction, I modified the demo and pushed it here: https://github.com/HollowMan6/vllm_allocator_adaptor

The error log when running test.py

amdgpu.ids: No such file or directory
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
[vllm_allocator_adaptor_c] create_and_map: device=0x3, size=20971520
[vllm_allocator_adaptor_c] CUDA Error: out of memory at vllm_allocator_adaptor_c.cpp:197
[vllm_allocator_adaptor_c] CUDA Error: out of memory at vllm_allocator_adaptor_c.cpp:137
[vllm_allocator_adaptor_c] CUDA Error: invalid argument at vllm_allocator_adaptor_c.cpp:138
[vllm_allocator_adaptor_c] CUDA Error: invalid argument at vllm_allocator_adaptor_c.cpp:145
[vllm_allocator_adaptor_c] create_and_map: device=0, size=20971520, d_mem=0, p_memHandle=0x80b7a00
Traceback (most recent call last):
  File "vllm_allocator_adaptor/test.py", line 53, in <module>
    y = torch.empty(shape, device='cuda')
torch.OutOfMemoryError: HIP out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 63.98 GiB of which 63.73 GiB is free. Of the allocated memory 4.00 MiB is allocated by PyTorch, and 18.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_HIP_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

HollowMan6 avatar Feb 03 '25 15:02 HollowMan6

Hi @HollowMan6, just a quick feedback data point as I was testing this PR. For V0, inference seems to work and the "from vllm.device_allocator.cumem import CuMemAllocator" import error I was facing on main is fixed, but for V1 I get this error:

INFO 02-03 16:07:28 model_runner.py:1118] Loading model weights took 14.9888 GB
ERROR 02-03 16:07:28 core.py:210] EngineCore hit an exception: Traceback (most recent call last):
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/v1/engine/core.py", line 202, in run_engine_core
ERROR 02-03 16:07:28 core.py:210]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 02-03 16:07:28 core.py:210]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/v1/engine/core.py", line 156, in __init__
ERROR 02-03 16:07:28 core.py:210]     super().__init__(vllm_config, executor_class)
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/v1/engine/core.py", line 54, in __init__
ERROR 02-03 16:07:28 core.py:210]     num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
ERROR 02-03 16:07:28 core.py:210]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/v1/engine/core.py", line 75, in _initialize_kv_caches
ERROR 02-03 16:07:28 core.py:210]     kv_cache_spec = self.model_executor.get_kv_cache_spec()
ERROR 02-03 16:07:28 core.py:210]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/v1/executor/abstract.py", line 68, in get_kv_cache_spec
ERROR 02-03 16:07:28 core.py:210]     output = self.collective_rpc("get_kv_cache_spec")
ERROR 02-03 16:07:28 core.py:210]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/executor/uniproc_executor.py", line 51, in collective_rpc
ERROR 02-03 16:07:28 core.py:210]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 02-03 16:07:28 core.py:210]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-03 16:07:28 core.py:210]   File "/home/mreso/vllm/vllm/utils.py", line 2206, in run_method
ERROR 02-03 16:07:28 core.py:210]     raise NotImplementedError(f"Method {method!r} is not"
ERROR 02-03 16:07:28 core.py:210] NotImplementedError: Method 'get_kv_cache_spec' is not implemented.
ERROR 02-03 16:07:28 core.py:210]
CRITICAL 02-03 16:07:28 core_client.py:158] Got fatal signal from worker processes, shutting down. See stack trace above for root cause issue.
Killed

Not sure if you're aware of this, so I thought I'd quickly leave a comment. Let me know if you need more info to repro.

EDIT: Oh and I had to remove "list(APPEND CUMEM_LIBS amdhip64)" in my env, otherwise I got a linker error: /usr/bin/ld: cannot find -lamdhip64.

mreso avatar Feb 04 '25 00:02 mreso

but we had a strange OOM issue with hipMemAddressReserve when we tried to allocate within allocator.use_memory_pool()

This usually means the error happens in hipMemAddressReserve; it is not really an OOM error.

When I implemented the PR, the most difficult part was making sure the calls to the CUDA driver API succeed. I don't know whether simply replacing some functions is enough to migrate these API calls. We need ROCm experts to confirm the API usage, as these calls are quite low-level (and error-prone).
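As an aside, a small standalone probe can surface the actual status code from hipMemAddressReserve instead of the translated "out of memory" message. A minimal sketch, assuming the AMD HIP platform where hipDeviceptr_t is a pointer type (the 2 MiB size and zero alignment are placeholders and must respect the device's allocation granularity):

#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  // Initialize the runtime / primary context before the VMM-style calls.
  if (hipSetDevice(0) != hipSuccess) {
    std::printf("hipSetDevice failed\n");
    return 1;
  }
  hipDeviceptr_t ptr = nullptr;
  const size_t size = 2ull << 20;  // placeholder: 2 MiB
  hipError_t err = hipMemAddressReserve(&ptr, size, 0, nullptr, 0);
  if (err != hipSuccess) {
    // Print the real status; it is not necessarily a true out-of-memory condition.
    std::printf("hipMemAddressReserve failed: %s (%s)\n",
                hipGetErrorName(err), hipGetErrorString(err));
    return 1;
  }
  std::printf("reserved %zu bytes of device virtual address space\n", size);
  hipMemAddressFree(ptr, size);
  return 0;
}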

youkaichao avatar Feb 04 '25 09:02 youkaichao

Hi @HollowMan6, just a quick feedback data point as I was testing this PR. For V0 inference seems to work

EDIT: Oh and I had to remove "list(APPEND CUMEM_LIBS amdhip64)" in my env, otherwise I got a linker error: /usr/bin/ld: cannot find -lamdhip64.

Hi @mreso! Thank you for your feedback! Are you sure sleep/wake works without the strange OOM error on AMD with this PR? It looks like you failed to link against the ROCm library, so sleep/wake shouldn't work here. Removing list(APPEND CUMEM_LIBS amdhip64) is not the correct solution, as it stops the build from linking to the library correctly. Your issue looks more like an environment issue, so maybe try something like export LD_LIBRARY_PATH=/opt/rocm/lib:$LD_LIBRARY_PATH for now. There may be a better way to handle this properly in the CMakeLists.txt, but let's first get hipMemAddressReserve working.

The " from vllm.device_allocator.cumem import CuMemAllocator" import error I was facing on main is fixed but for V1 I get this error:

Oh really? It sounds like there is some faulty logic in checking whether you are using NVIDIA or AMD, since in the original codebase CuMemAllocator is disabled for AMD GPUs. I can't reproduce this on my side, though, so maybe you want to file a separate issue about it. Regarding the V1 error, neither this PR nor #11743 modified that code path, so it should be a separate issue, too.

HollowMan6 avatar Feb 04 '25 12:02 HollowMan6

Okay, it turns out to be a hardware limitation on my side. I'm using MI250x, and it looks like it doesn't support virtual memory management, as confirmed by https://github.com/ROCm/hip-tests/blob/84a460d96bafb00615969304cc5eaddc3b20bc3d/catch/unit/memory/hip_vmm_common.hh#L27-L37
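For anyone who wants to check their own GPUs, a small probe along the lines of that helper might look like this (an unverified sketch; it queries the same attribute the linked hip_vmm_common.hh checks):

#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  int count = 0;
  if (hipGetDeviceCount(&count) != hipSuccess || count == 0) {
    std::printf("no HIP devices visible\n");
    return 1;
  }
  for (int dev = 0; dev < count; ++dev) {
    int supported = 0;
    // Ask the driver whether this device supports the VMM APIs at all.
    hipDeviceGetAttribute(&supported,
                          hipDeviceAttributeVirtualMemoryManagementSupported,
                          dev);
    std::printf("device %d: virtual memory management %s\n", dev,
                supported ? "supported" : "NOT supported");
  }
  return 0;
}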

Maybe this can work on more advanced AMD hardware, such as MI300+? Unfortunately, I don't have access to such hardware, so I can't help with this PR any further. I will mark this PR as ready, since directly mapping those CUDA calls seems to be OK; we also have some test cases here for reference: https://github.com/ROCm/hip-tests/blob/84a460d96bafb00615969304cc5eaddc3b20bc3d/catch/unit/virtualMemoryManagement/hipMemSetGetAccess.cc#L213-L256

Feel free to take this PR over, @youkaichao or anyone who is interested and has those GPUs. I have set this PR to allow edits and access to secrets by maintainers.

HollowMan6 avatar Feb 04 '25 13:02 HollowMan6

@hongxiayang are you able to help with this?

DarkLight1337 avatar Feb 04 '25 13:02 DarkLight1337

@hongxiayang are you able to help with this?

Thanks for your contribution. @HollowMan6

@DarkLight1337 I will check this when I have more bandwidth.

hongxiayang avatar Feb 04 '25 15:02 hongxiayang

Thanks @HollowMan6, you're right, that will be an env issue. I was not specifically testing sleep/wake, so it would most likely crash if I did.

Oh really? it sounds like they have some failed logic in checking whether you are using NVIDIA or AMD, as in the original codebase, CuMemAllocator is disabled for AMD GPUs.

Interesting that this does not work in my case. This line is causing an exception on my end, and I assume that if CUDA is not present it's supposed to swallow the ModuleNotFoundError and disable the CuMemAllocator that way, right? But in my case an AssertionError is thrown here, which is not caught. If I raise a ModuleNotFoundError instead, it works.

Is there another mechanism I am missing that should disable CuMemAllocator?

mreso avatar Feb 05 '25 00:02 mreso

@mreso

I was not specifically testing sleep/wake explicitly so it will most likely crash if I did.

What hardware are you using? If it's MI300 or above, it likely supports virtual memory management and the sleep/wake mode will not crash with that OOM error. You are welcome to test this out and let us know what you find!

Interesting that this does not work in my case. This line is causing an exception on my end and I assume if cuda is not present its supposed to swallow the ModuleNotFoundError and disable the CuMemAllocator that way, right? But in my case an AssertionError is thrown here which is not caught. If I raise a ModuleNotFoundError error instead its working.

Is there another mechanism I am missing that should disable CuMemAllocator?

In main, vllm.cumem_allocator shouldn't exist in the first place for AMD, which is why it catches only the ModuleNotFoundError. That module comes from the C side, so you would need to find out why you have vllm/cumem_allocator.abi3.so (it gets compiled by your CMakeLists.txt), even though that CMakeLists.txt specifies it should only be compiled for CUDA. Most likely an env issue again: did you somehow install CUDA on your AMD machine? I would also suggest cleaning up your workdir to ensure there is no leftover cumem_allocator.abi3.so, especially when you switch from this PR back to main.

HollowMan6 avatar Feb 05 '25 08:02 HollowMan6

In main, vllm.cumem_allocator shouldn't exist in the first place for AMD, which is why it catches only the ModuleNotFoundError. [...] I would also suggest cleaning up your workdir to ensure there is no leftover cumem_allocator.abi3.so, especially when you switch from this PR back to main.

That's it, thanks a lot! It was an unclean build env. There is no trace of CUDA on the machines, but it must have created the file anyway, so the ModuleNotFoundError was not raised... After removing all files and rebuilding, it works.

I am on MI300s and will give it a try.

mreso avatar Feb 05 '25 18:02 mreso

Gave it a quick try using this test script:

import torch
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_sleep_mode=True)

def run_inference(prompt):
    outputs = llm.generate(prompt)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


print("CUDA Memory Usage (after inference):")
torch.cuda.empty_cache()
print(f"{torch.cuda.memory_allocated()=}")

run_inference("San Francisco is")
llm.sleep()

print("CUDA Memory Usage (after sleep):")
torch.cuda.empty_cache()
print(f"{torch.cuda.memory_allocated()=}")

llm.wake_up()

print("CUDA Memory Usage (after wakeup):")
torch.cuda.empty_cache()
print(f"{torch.cuda.memory_allocated()=}")

run_inference("Paris is")

but it seems like sleep did not free any memory, and after wakeup I got an OOM error:

INFO 02-05 11:19:42 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 54.63 seconds
CUDA Memory Usage (after inference):
torch.cuda.memory_allocated()=168654098432
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  3.85it/s, est. speed input: 15.41 toks/s, output: 61.64 toks/s]
Prompt: 'San Francisco is', Generated text: ' a top tourist destination, with millions of visitors each year. Whether you're traveling'
INFO 02-05 11:19:46 worker.py:133] Sleep mode freed 0.00 GiB memory, 160.18 GiB memory is still in use.
CUDA Memory Usage (after sleep):
torch.cuda.memory_allocated()=168654097920
CUDA Error: out of memory at /home/mreso/vllm/csrc/cumem_allocator.cpp:56
Segmentation fault (core dumped)

This is my hw:

=========================================== Concise Hardware Info ============================================
GPU  NODE  DID     GUID   GFX VER  GFX RAS  SDMA RAS  UMC RAS  VBIOS             BUS           PARTITION ID
0    2     0x74a1  51140  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:08:00.0  0
1    5     0x74a1  36564  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:28:00.0  0
2    4     0x74a1  18917  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:48:00.0  0
3    3     0x74a1  28918  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:68:00.0  0
4    8     0x74a1  15111  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:88:00.0  0
5    9     0x74a1  57880  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:A8:00.0  0
6    7     0x74a1  44328  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:C8:00.0  0
7    6     0x74a1  13826  gfx942   ENABLED  ENABLED   ENABLED  113-M3000100-102  0000:E9:00.0  0
==============================================================================================================

Let me know if the test script is incorrect for this or I can try anything else.

mreso avatar Feb 05 '25 19:02 mreso

Thank you @mreso! It looks like MI300 supports virtual memory management, but it gives an OOM error at a different line and then ends with a segmentation fault, so maybe the HIP APIs need to be called differently from the CUDA ones. Unfortunately, I can't help with this PR any further as I don't have an MI300 at hand, but thank you anyway, @mreso!

HollowMan6 avatar Feb 05 '25 19:02 HollowMan6

Gave it a quick try using this test script:

[... same test script and output as in the comment above ...]

Let me know if the test script is incorrect for this or I can try anything else.

I can also confirm that the current solution does not work on MI300 (on v0.7.2).

INFO 02-13 15:46:13 model_runner.py:1110] Starting to load model meta-llama/Llama-3.1-8B-Instruct...
INFO 02-13 15:46:18 weight_utils.py:252] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.77it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:02<00:03,  1.66s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:05<00:02,  2.12s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:08<00:00,  2.34s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:08<00:00,  2.08s/it]

INFO 02-13 15:46:27 model_runner.py:1115] Loading model weights took 14.9888 GB
INFO 02-13 15:47:10 worker.py:284] Memory profiling takes 43.04 seconds
INFO 02-13 15:47:10 worker.py:284] the current vLLM instance can use total_gpu_memory (191.45GiB) x gpu_memory_utilization (0.90) = 172.31GiB
INFO 02-13 15:47:10 worker.py:284] model weights take 14.99GiB; non_torch_memory takes 2.21GiB; PyTorch activation peak memory takes 13.50GiB; the rest of the memory reserved for KV Cache is 141.61GiB.
INFO 02-13 15:47:10 executor_base.py:110] # CUDA blocks: 72501, # CPU blocks: 2048
INFO 02-13 15:47:10 executor_base.py:115] Maximum concurrency for 131072 tokens per request: 8.85x
INFO 02-13 15:48:44 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:20<00:00,  1.73it/s]
INFO 02-13 15:49:04 model_runner.py:1562] Graph capturing finished in 20 secs, took 0.20 GiB
INFO 02-13 15:49:04 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 157.36 seconds
CUDA Memory Usage (after inference):
torch.cuda.memory_allocated()=168163364864
Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.84s/it, est. speed input: 1.41 toks/s, output: 5.64 toks/s]
Prompt: 'San Francisco is', Generated text: ' a top tourist destination, with millions of visitors each year. Whether you're traveling'
CUDA Memory Usage (after sleep):
torch.cuda.memory_allocated()=168163364352
CUDA Memory Usage (after wakeup):
torch.cuda.memory_allocated()=168163364352
Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.47s/it, est. speed input: 1.22 toks/s, output: 6.49 toks/s]
Prompt: 'Paris is', Generated text: ' a battleground for the culinary world. Street vendors, high-end restaurants, bakers'
[rank0]:[W213 15:49:11.474759635 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of an

YangWang92 avatar Feb 13 '25 15:02 YangWang92

By the way, feel free to tell me what I can help with on this pull.

YangWang92 avatar Feb 13 '25 15:02 YangWang92

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @HollowMan6.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 06 '25 01:03 mergify[bot]

Hi @HollowMan6, just a quick feedback data point as I was testing this PR. [...] EDIT: Oh and I had to remove "list(APPEND CUMEM_LIBS amdhip64)" in my env, otherwise I got a linker error: /usr/bin/ld: cannot find -lamdhip64.

Hi @mreso! Thank you for your feedback! [...] Removing list(APPEND CUMEM_LIBS amdhip64) is not the correct solution, as it stops the build from linking to the library correctly. The issue you have is more likely an environment issue [...]

@HollowMan6 @hongxiayang I can't seem to compile due to the linker not finding -lamdhip64 when building with DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm-ellm:sleepamd .

 deprecated: This API is marked as deprecated and might not be supported in future releases. For more details please refer to https://github.com/ROCm/HIP/blob/develop/docs/reference/deprecated_api_list.md [-Wdeprecated-declarations]
57.87    57 |   return hipDevicePrimaryCtxRetain(ctx, dev);
57.87       |          ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~
57.87 In file included from /app/vllm/csrc/cumem_allocator_compat.h:7,
57.87                  from /app/vllm/csrc/cumem_allocator.cpp:6:
57.87 /opt/rocm/include/hip/hip_runtime_api.h:5437:12: note: declared here
57.87  5437 | hipError_t hipDevicePrimaryCtxRetain(hipCtx_t* pctx, hipDevice_t dev);
57.87       |            ^~~~~~~~~~~~~~~~~~~~~~~~~
58.04 [2/28] Linking CXX shared module cumem_allocator.abi3.so
58.04 FAILED: cumem_allocator.abi3.so
58.04 : && /usr/bin/c++ -fPIC -Wno-unused-result -O2 -g -DNDEBUG   -shared  -o cumem_allocator.abi3.so CMakeFiles/cumem_allocator.dir/csrc/cumem_allocator.cpp.o  -Wl,-rpath,/usr/local/lib/python3.12/dist-packages/torch/lib:/opt/rocm-6.3.1/lib:/opt/rocm/lib:  -lamdhip64  /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch.so  /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so  -Wl,--no-as-needed,"/usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cpu.so" -Wl,--as-needed  -Wl,--no-as-needed,"/usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_hip.so" -Wl,--as-needed  /usr/local/lib/python3.12/dist-packages/torch/lib/libc10_hip.so  /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so  /opt/rocm-6.3.1/lib/libMIOpen.so.1.0.60301  /opt/rocm/lib/libhiprtc.so.6.3.60301  -ldl  /opt/rocm-6.3.1/lib/libhipblas.so.2.3.60301  /opt/rocm-6.3.1/lib/libhipfft.so.0.1.60301  /opt/rocm-6.3.1/lib/libhiprand.so.1.1.60301  /opt/rocm-6.3.1/lib/librocrand.so.1.1.60301  /opt/rocm-6.3.1/lib/libhipsparse.so.1.1.0.60301  /opt/rocm-6.3.1/lib/libhipsolver.so.0.3.60301  /opt/rocm-6.3.1/lib/libhipblaslt.so.0.13  /opt/rocm/lib/libamdhip64.so.6.3.60301  -Wl,--no-as-needed,"/usr/local/lib/python3.12/dist-packages/torch/lib/libtorch.so" -Wl,--as-needed  -Wl,-rpath-link,/opt/rocm-6.3.1/lib && :
58.04 /usr/bin/ld: cannot find -lamdhip64: No such file or directory
58.04 collect2: error: ld returned 1 exit status

This compilation error is resolved through:

sudo ln -sf /opt/rocm/lib/libamdhip64.so.6.3.60301 /usr/lib/libamdhip64.so
sudo ldconfig

HIP-TESTs of VirtualMemoryManagement

I am using the HIP and ROCm from the vLLM docker image to build the tests.

I am getting two failures when running the VirtualMemoryManagement tests:

# ./VirtualMemoryManagementTest 

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
VirtualMemoryManagementTest is a Catch v2.13.4 host application.
Run with -? for options

-------------------------------------------------------------------------------
Unit_hipMemCreate_MapNonContiguousChunks
-------------------------------------------------------------------------------
/app/hip-tests/catch/unit/virtualMemoryManagement/hipMemCreate.cc:272
...............................................................................

/app/hip-tests/catch/unit/virtualMemoryManagement/hipMemCreate.cc:323: FAILED:
  REQUIRE( true == std::equal(B_h.begin(), B_h.end(), C_h.data()) )
with expansion:
  true == false

Memory access fault by GPU node-9 (Agent handle: 0xd13970) on address 0x7f3214c47000. Reason: Unknown.
GPU core dump created: gpucore.28281
-------------------------------------------------------------------------------
Unit_hipMemSetAccess_Vmm2UnifiedMemCpy
-------------------------------------------------------------------------------
/app/hip-tests/catch/unit/virtualMemoryManagement/hipMemSetGetAccess.cc:552
...............................................................................

/app/hip-tests/catch/unit/virtualMemoryManagement/hipMemSetGetAccess.cc:552: FAILED:
  {Unknown expression after the reported line}
due to a fatal error condition:
  SIGABRT - Abort (abnormal termination) signal

===============================================================================
test cases:    24 |    22 passed | 2 failed
assertions: 73852 | 73850 passed | 2 failed

Aborted

HIP-TEST setup

export ROCM_BRANCH=rocm-6.3.x
git clone -b "$ROCM_BRANCH" https://github.com/ROCm/hip-tests.git
export HIPTESTS_DIR="$(readlink -f hip-tests)"
cd "$HIPTESTS_DIR"
mkdir -p build; cd build
cmake ../catch -DHIP_PLATFORM=amd -DHIP_PATH=/opt/rocm
make build_tests
cd catch_tests/
cd unit/
cd virtualMemoryManagement/
./VirtualMemoryManagementTest

tjtanaa avatar Mar 13 '25 03:03 tjtanaa

Update No.2 03122025

@HollowMan6 @YangWang92 The VRAM did drop from 84% to 1% when it goes to sleep (if you check using rocm-smi). There might be a problem with using torch to get the VRAM allocation. To give yourself some time to view the output from rocm-smi, you can add time.sleep(10) after llm.sleep().

The segfault is something caused by the LLM object at the end of the script.

I think the segfault problem is with PyTorch. The memory allocation is now handled manually outside of PyTorch, so PyTorch might not be retrieving real-time information. When the program ends, torch tries to deallocate memory that no longer belongs to it.

tjtanaa avatar Mar 13 '25 09:03 tjtanaa

Update No.3 03122025

@HollowMan6 I have fixed your test script; this time, we are able to read the GPU VRAM usage accurately on ROCm. It requires the amdsmi package:

Fixed Script:

import torch
import time
from contextlib import contextmanager
from amdsmi import (amdsmi_get_gpu_vram_usage,
                    amdsmi_get_processor_handles, amdsmi_init,
                    amdsmi_shut_down)
from vllm import LLM
import os

@contextmanager
def _rocm():
    try:
        amdsmi_init()
        yield
    finally:
        amdsmi_shut_down()

def get_physical_device_indices(devices):
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return devices

    visible_indices = [int(x) for x in visible_devices.split(",")]
    index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
    return [index_mapping[i] for i in devices if i in index_mapping]

def print_gpu_memory():
    with _rocm():
        devices = list(range(torch.cuda.device_count())) 
        devices = get_physical_device_indices(devices)
        start_time = time.time()

        output: dict[int, str] = {}
        output_raw: dict[int, float] = {}
        for device in devices:
            dev_handle = amdsmi_get_processor_handles()[device]
            mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
            gb_used = mem_info["vram_used"] / 2**10
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time

        time.sleep(5)

def print_memory_usage(stage):
    torch.cuda.synchronize()  # Ensure all operations are complete
    print(f"CUDA Memory Usage ({stage}):")
    print_gpu_memory()

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_sleep_mode=True)

def run_inference(prompt):
    outputs = llm.generate(prompt)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print_memory_usage("initial")

run_inference("San Francisco is")
llm.sleep()

time.sleep(10)
print_memory_usage("after sleep")

llm.wake_up()

time.sleep(10)
print_memory_usage("after wakeup")

run_inference("Paris is")

The printed output:

INFO 03-13 12:08:17 [__init__.py:256] Automatically detected platform rocm.
CUDA Memory Usage (initial):
**gpu memory used (GB): 5=0.28;**
INFO 03-13 12:08:41 [config.py:578] This model supports multiple tasks: {'embed', 'generate', 'classify', 'score', 'reward'}. Defaulting to 'generate'.
INFO 03-13 12:08:41 [config.py:1521] Disabled the custom all-reduce kernel because it is not supported on AMD GPUs.
WARNING 03-13 12:08:41 [arg_utils.py:1248] The model has a long context length (131072). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.
INFO 03-13 12:08:41 [llm_engine.py:235] Initializing a V0 LLM engine (v0.1.dev4964+g8a1e046) with config: model='meta-llama/Llama-3.1-8B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=meta-llama/Llama-3.1-8B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False,
INFO 03-13 12:08:43 [rocm.py:130] None is not supported in AMD GPUs.                                                    
INFO 03-13 12:08:43 [rocm.py:131] Using ROCmFlashAttention backend.                                                     
INFO 03-13 12:08:43 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0       
INFO 03-13 12:08:43 [model_runner.py:1110] Starting to load model meta-llama/Llama-3.1-8B-Instruct...                   
INFO 03-13 12:08:45 [weight_utils.py:257] Using model weights format ['*.safetensors']                                  
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]                                            
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:02<00:06,  2.16s/it]                                    
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:04<00:04,  2.33s/it]                                    
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:07<00:02,  2.37s/it]                                    
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.78s/it]                                    
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:07<00:00,  1.98s/it]  

INFO 03-13 12:08:53 [loader.py:422] Loading weights took 7.90 seconds
INFO 03-13 12:08:53 [model_runner.py:1117] Model loading took 15.3027 GB and 10.225612 seconds
INFO 03-13 12:09:03 [worker.py:267] Memory profiling takes 8.93 seconds
INFO 03-13 12:09:03 [worker.py:267] the current vLLM instance can use total_gpu_memory (191.98GiB) x gpu_memory_utilization (0.90) = 172.79GiB
INFO 03-13 12:09:03 [worker.py:267] model weights take 15.30GiB; non_torch_memory takes 0.57GiB; PyTorch activation peak memory takes 13.50GiB; the rest of the memory reserved for KV Cache is 143.41GiB.
INFO 03-13 12:09:03 [executor_base.py:111] # rocm blocks: 73425, # CPU blocks: 2048
INFO 03-13 12:09:03 [executor_base.py:116] Maximum concurrency for 131072 tokens per request: 8.96x
INFO 03-13 12:09:20 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:13<00:00,  2.57it/s]
INFO 03-13 12:09:34 [model_runner.py:1570] Graph capturing finished in 14 secs, took 0.23 GiB
INFO 03-13 12:09:34 [llm_engine.py:441] init engine (profile, create kv cache, warmup model) took 40.21 seconds
+ CUDA Memory Usage (initial): <<<<<<<<<<<<<<<<<<<<<<<<<<
+ gpu memory used (GB): 5=161.68; <<<<<<<<<<<<<<<<<<<<<<<
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  4.43it/s, est. speed input: 17.73 toks/s, output: 70.92 toks/s]
Prompt: 'San Francisco is', Generated text: ' a top tourist destination, with millions of visitors each year. Whether you're traveling'
- INFO 03-13 12:09:42 [worker.py:133] Sleep mode freed 158.44 GiB memory, 1.40 GiB memory is still in use.
INFO 03-13 12:09:42 [executor_base.py:208] It took 1.980214 seconds to fall asleep.
+ CUDA Memory Usage (after sleep): <<<<<<<<<<<<<<<<<<<<<
+ gpu memory used (GB): 5=3.31; <<<<<<<<<<<<<<<<<<<<<<<
INFO 03-13 12:09:57 [executor_base.py:219] It took 0.297716 seconds to wake up.
+ CUDA Memory Usage (after wakeup): <<<<<<<<<<<<<<<<<<<<<
+ gpu memory used (GB): 5=161.75; <<<<<<<<<<<<<<<<<<<<<<<
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  4.48it/s, est. speed input: 13.44 toks/s, output: 71.70 toks/s]
Prompt: 'Paris is', Generated text: ' a battleground for the left in France. Now the years since 2011 will'
[rank0]:[W313 12:10:13.471004605 ProcessGroupNCCL.cpp:1505] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Segmentation fault

The lines marked with + (green) are the print results from amdsmi; the line marked with - (red) is the printout from the vLLM sleep feature.

TODO:

  • [x] Pass the unit test tests/basic_correctness/test_cumem.py. Blocked by the fork_new_process_for_each_test decorator: we cannot use os.fork and have to use multiprocessing or torch.multiprocessing with the spawn method.
  • [x] Find out what is causing the Segfault when the LLM object is destroyed.

tjtanaa avatar Mar 13 '25 12:03 tjtanaa

Thank you @tjtanaa! Unfortunately, I don't have the hardware to continue working on this PR. Feel free to send me your patch here / via a separate PR at https://github.com/HollowMan6/vllm/tree/sleep_amd, and I would be very happy to get it integrated into this PR.

HollowMan6 avatar Mar 13 '25 12:03 HollowMan6

Thank you @tjtanaa! Unfortunately, I don't have the hardware to continue working on this PR. Feel free to send me your patch here / via a separate PR at https://github.com/HollowMan6/vllm/tree/sleep_amd, and I would be very happy to get it integrated into this PR.

Sure. After I am done, I will open a PR. I am moving your branch to my fork and continuing with the rest. Not sure if I can solve it, as it might be a very low-level issue. Could it be related to https://github.com/ROCm/hip/issues/3762 and https://github.com/ROCm/hip/issues/3763?

tjtanaa avatar Mar 13 '25 12:03 tjtanaa

Update No.1 03142025

We have reached the stage of being able to run the feature with:

  1. Mixtral-8x7B-Instruct
  2. Llama-3.1-8B-Instruct
  3. Llama-3.3-70B-Instruct
  4. Qwen-2.5-7B

Last Roadblock:

The facebook/opt-125m model used in the unit tests is failing.

tjtanaa avatar Mar 14 '25 14:03 tjtanaa

@HollowMan6 could you invite me as a collaborator so that I can easily fix the changes in your branch and merge it with main whenever necessary? Or I can move to another PR to make it easier for my team to work on it as well. All your commits are in the PR, as we branched out from your branch. We will mention you when the PR is ready.

tjtanaa avatar Mar 18 '25 03:03 tjtanaa

@tjtanaa invited, thanks!

HollowMan6 avatar Mar 18 '25 07:03 HollowMan6

@HollowMan6 Can you invite one more person, @kliuae? Thank you.

tjtanaa avatar Mar 18 '25 10:03 tjtanaa

@kliuae invited

HollowMan6 avatar Mar 18 '25 10:03 HollowMan6

It seems that allocating virtual memory with hipMemCreate throws an OOM error in sleep mode when a large contiguous memory space is requested in one go, especially when allocating KV cache space for small models like facebook/opt-125m on MI300 devices. To circumvent this, ROCm's allocator is modified to segment large allocations into smaller chunks, so that memory allocation and mapping take place chunk by chunk. The additional overhead seems fine for small TP sizes, but for large TP sizes it may make the first allocation and deallocation appear slow, which may need further optimization in a follow-up PR. This feature requires ROCm 6.3.4 and updates the ROCm base image.
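For illustration, here is a minimal standalone sketch of the chunked reserve-then-map scheme described above (not the PR's actual allocator code; the 512 MiB total and 64 MiB chunk size are made-up numbers, and it assumes the AMD platform where hipDeviceptr_t is a pointer type):

#include <hip/hip_runtime.h>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define HIP_CHECK(expr)                                                     \
  do {                                                                      \
    hipError_t e_ = (expr);                                                 \
    if (e_ != hipSuccess) {                                                 \
      std::printf("HIP error %s at %s:%d\n", hipGetErrorString(e_),         \
                  __FILE__, __LINE__);                                      \
      std::exit(1);                                                         \
    }                                                                       \
  } while (0)

int main() {
  const int device = 0;
  HIP_CHECK(hipSetDevice(device));

  hipMemAllocationProp prop{};
  prop.type = hipMemAllocationTypePinned;
  prop.location.type = hipMemLocationTypeDevice;
  prop.location.id = device;

  size_t granularity = 0;
  HIP_CHECK(hipMemGetAllocationGranularity(&granularity, &prop,
                                           hipMemAllocationGranularityMinimum));

  const size_t total = 512ull << 20;  // pretend KV-cache size (multiple of granularity)
  const size_t chunk = 64ull << 20;   // map in 64 MiB pieces instead of one hipMemCreate

  hipDeviceptr_t base = nullptr;
  HIP_CHECK(hipMemAddressReserve(&base, total, granularity, nullptr, 0));

  hipMemAccessDesc access{};
  access.location = prop.location;
  access.flags = hipMemAccessFlagsProtReadWrite;

  std::vector<hipMemGenericAllocationHandle_t> handles;
  for (size_t off = 0; off < total; off += chunk) {
    const size_t sz = std::min(chunk, total - off);
    hipMemGenericAllocationHandle_t h;
    HIP_CHECK(hipMemCreate(&h, sz, &prop, 0));
    // hipDeviceptr_t is void* on the AMD platform, hence the char* arithmetic.
    hipDeviceptr_t p = (hipDeviceptr_t)((char*)base + off);
    HIP_CHECK(hipMemMap(p, sz, 0, h, 0));
    HIP_CHECK(hipMemSetAccess(p, sz, &access, 1));
    handles.push_back(h);
  }
  std::printf("mapped %zu MiB in %zu chunks\n", total >> 20, handles.size());

  // "Sleep": unmap and release the physical chunks but keep the VA reservation,
  // so the same pointer could be re-backed later on wake-up.
  for (size_t off = 0, i = 0; off < total; off += chunk, ++i) {
    const size_t sz = std::min(chunk, total - off);
    HIP_CHECK(hipMemUnmap((hipDeviceptr_t)((char*)base + off), sz));
    HIP_CHECK(hipMemRelease(handles[i]));
  }
  HIP_CHECK(hipMemAddressFree(base, total));
  return 0;
}

The point is that callers still see one contiguous device pointer, while no single hipMemCreate ever has to find one huge contiguous physical block.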

kliuae avatar Mar 28 '25 08:03 kliuae

@youkaichao @DarkLight1337 This feature is also ready for review.

tjtanaa avatar Mar 28 '25 16:03 tjtanaa

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @HollowMan6.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 01 '25 08:04 mergify[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @HollowMan6.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 01 '25 15:04 mergify[bot]