[Core] Consolidate prompt arguments to LLM engines
Currently, `LLM.generate` (and the similar methods in `LLMEngine` and `AsyncLLMEngine`) accepts `prompt`, `prompt_token_ids` and `multi_modal_data` as separate arguments. This PR consolidates them into `PromptInputs` so that only a single argument has to be passed in, using type annotations to ensure a consistent format. This reduces the chance of the user accidentally passing in `prompt`, `prompt_token_ids`, and `multi_modal_data` of mismatched lengths (the related checks have been removed to avoid redundant code). On the other hand, `sampling_params` remains untouched because it is common to pass only a single instance even for multiple prompts.
This would also make it easier to define the interface for processing the inputs using a HuggingFace processor, as mentioned in #4194.
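To illustrate the idea, a consolidated input type along these lines could be expressed with `TypedDict`s. The class names and fields below are a sketch inferred from the examples in this description, not necessarily the exact definitions added by the PR:

```python
# Sketch only: names and fields are illustrative, not the PR's actual code.
from typing import Any, List, Union
from typing_extensions import NotRequired, TypedDict

class TextPrompt(TypedDict):
    prompt: str
    # Placeholder for MultiModalData in this sketch.
    multi_modal_data: NotRequired[Any]

class TokensPrompt(TypedDict):
    prompt_token_ids: List[int]
    multi_modal_data: NotRequired[Any]

# A single argument covering all three call styles shown below.
PromptInputs = Union[str, TextPrompt, TokensPrompt]
```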
API changes
The existing API of `LLM.generate` is deprecated: the parameters `prompt`, `prompt_token_ids` and `multi_modal_data` will be replaced with `inputs`. For now we still maintain the old API, but it will be removed in a future major update. Users may update their code as follows:
Single prompt:
# No change required since the parameter is not referred to by name
llm.generate("Hello, my name is")
- llm.generate(prompt="Hello, my name is")
+ llm.generate("Hello, my name is")
- llm.generate(prompt_token_ids=[1, 2, 3])
+ llm.generate({"prompt_token_ids": [1, 2, 3]})
# image is a tensor in NCHW format where N=1
- llm.generate("Hello, my name is", multi_modal_data=MultiModalData(type=..., data=image))
+ llm.generate({"prompt": "Hello, my name is", "multi_modal_data": MultiModalData(type=..., data=image)})
Multiple prompts:
# No change required since the parameter is not referred to by name
llm.generate(["Hello, my name is", "The future of AI is"])
- llm.generate(prompt=["Hello, my name is", "The future of AI is"])
+ llm.generate(["Hello, my name is", "The future of AI is"])
- llm.generate(prompt_token_ids=[[1, 2, 3], [4, 5, 6]])
+ llm.generate([{"prompt_token_ids": [1, 2, 3]}, {"prompt_token_ids": [4, 5, 6]}])
# images is a tensor in NCHW format where N=len(prompts)
- prompts = ["Hello, my name is", "The future of AI is"]
- llm.generate(prompts, multi_modal_data=MultiModalData(type=..., data=images))
+ llm.generate([
+ {"prompt": prompt, "multi_modal_data": MultiModalData(type=..., data=images[i:i+1])}
+ for i, prompt in enumerate(prompts)
+ ])
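Putting the pieces together, a minimal end-to-end sketch of the new call style might look like the following; the model name and sampling settings are placeholders, not taken from this PR:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)

# A single SamplingParams instance is still shared across all prompts;
# only the prompt-related fields move into the per-prompt dicts.
outputs = llm.generate(
    [
        {"prompt": "Hello, my name is"},
        {"prompt_token_ids": [1, 2, 3]},
    ],
    sampling_params,
)
for output in outputs:
    print(output.outputs[0].text)
```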
Based on the examples in the documentation, most users should already prefer the first way of calling `LLM.generate`; those users need not make any changes.
@ywang96 Any thoughts about this?
Hey @DarkLight1337! Sorry I've been a bit busy lately, but I will surely take a look in the upcoming week! Apologies for the delay!
I managed to get the `entrypoints` test to run in a single command. However, I now get this warning when running `test_oot_registration_for_api_server` after changing `multiprocessing` to use the `'spawn'` start method instead of the default `'fork'`:
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
Does this mean the OS has to restart in order to clean up this memory leak? Hopefully it won't become a serious issue.
~~Edit: This change has been split out into #4759~~ Edit 2: #4759 has been discontinued since the entrypoints tests can no longer fit on a single GPU. However, the multiprocessing issue is still relevant, since I still reduced the number of commands to run by grouping multiple tests together using pytest markers.
#3734 caused quite a few merge conflicts. Hopefully I didn't break anything.
Note that by setting `gpu_memory_utilization` to a low value for tests that only use small models, we can now run multiple files in `entrypoints_tests` at the same time.
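As a rough illustration of that setup (the model name, fixture, and memory fraction below are made up for the example, not taken from the test suite):

```python
import pytest
from vllm import LLM

@pytest.fixture(scope="module")
def small_llm():
    # A small model with a low gpu_memory_utilization leaves room for
    # other test files to run engines on the same GPU concurrently.
    return LLM(model="facebook/opt-125m", gpu_memory_utilization=0.10)
```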
However, I now get this warning when running `test_oot_registration_for_api_server` after changing `multiprocessing` to use the `'spawn'` start method instead of the default `'fork'` (otherwise the test becomes very slow):
/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
Does this mean the OS has to restart in order to clean up this memory leak? Hopefully it won't become a significant issue.
Edit: ~~Looking at the stdlib code, it seems that Python will perform the cleanup during shutdown regardless, so it should be fine.~~
Edit 2: Or not. According to the documentation:
On POSIX using the spawn or forkserver start methods will also start a resource tracker process which tracks the unlinked named system resources (such as named semaphores or SharedMemory objects) created by processes of the program. When all processes have exited the resource tracker unlinks any remaining tracked object. Usually there should be none, but if a process was killed by a signal there may be some “leaked” resources. (Neither leaked semaphores nor shared memory segments will be automatically unlinked until the next reboot. This is problematic for both objects because the system allows only a limited number of named semaphores, and shared memory segments occupy some space in the main memory.)
Edit 3: Using `torch.multiprocessing` re-enables the `'fork'` start method without sacrificing speed, so it's all good now.
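For reference, a sketch of the two start-method choices discussed above; `run_server` is a stand-in for the server process launched by the test, not the actual test code:

```python
import multiprocessing

import torch.multiprocessing as torch_mp

def run_server():
    # Stand-in for the OOT-registration API server launched by the test.
    ...

if __name__ == "__main__":
    # Stdlib 'spawn' avoids the slowdown seen here with plain 'fork', but it
    # starts a resource tracker process -- the source of the leaked-semaphore
    # warning quoted above.
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=run_server)
    proc.start()
    proc.join()

    # Per the note above, torch.multiprocessing keeps the default 'fork'
    # start method usable without the slowdown, so no resource tracker
    # (and no warning) is involved.
    proc = torch_mp.Process(target=run_server)
    proc.start()
    proc.join()
```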
@DarkLight1337 Left comments & questions - PTAL!
Thanks for your time! I have responded to your comments.
@ywang96 I have finished addressing this round of comments.
I think we do not want to add the output type as a parameter in `_run_engine(self, output_type: Type[_O], *...)`. If the type-checking issue comes up, we can resolve it in some other way; passing the type in as a runtime argument is not very useful and hurts readability.
I have moved the type validation logic outside of those functions.
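For instance, a hedged sketch of what checking the output type at the public-API boundary (instead of passing `output_type` into `_run_engine`) could look like; the helper name is hypothetical:

```python
from typing import List, Sequence, Type, TypeVar, cast

_O = TypeVar("_O")

def validate_outputs(outputs: Sequence[object],
                     output_type: Type[_O]) -> List[_O]:
    """Hypothetical helper: narrow the outputs' type where the public method
    returns, instead of threading output_type through _run_engine."""
    if not all(isinstance(o, output_type) for o in outputs):
        raise TypeError(f"Expected outputs of type {output_type.__name__}")
    return cast(List[_O], list(outputs))
```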
The deprecation warning should tell users what to use instead of the deprecated arguments. Currently it doesn't show that.
Done.
I think we should keep the prompts/prompt_token_ids inputs for backward compatibility?

> I think we should keep the prompts/prompt_token_ids inputs for backward compatibility?

These arguments are still being maintained. I just decorated the relevant methods so we can deprecate them simply by turning on the corresponding flag.
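As an illustration of that pattern (the decorator and flag names here are hypothetical, not the exact helper used in the PR):

```python
import functools
import warnings

# Hypothetical switch: flipping this to True turns the warning on.
DEPRECATE_LEGACY_KWARGS = False

def deprecate_kwargs(*kw_names, replacement: str):
    """Warn when any of the named keyword arguments is passed explicitly."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if DEPRECATE_LEGACY_KWARGS:
                used = [name for name in kw_names if kwargs.get(name) is not None]
                if used:
                    warnings.warn(
                        f"{', '.join(used)} will be deprecated; "
                        f"please use {replacement} instead.",
                        DeprecationWarning,
                        stacklevel=2,
                    )
            return func(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwargs("prompts", "prompt_token_ids", "multi_modal_data",
                  replacement="the 'inputs' parameter")
def generate(inputs=None, *, prompts=None, prompt_token_ids=None,
             multi_modal_data=None):
    ...
```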
Pretty sure the CI will pass now.
@njhill I have responded to your comments.
@DarkLight1337 I just went through this PR again and made a change to move the offline API reference under the developer doc. #4710 was a great addition, but I think we should link to the developer doc from the examples instead of putting the API reference there directly.