vllm_async_service: Inject custom output formatters into VLLMHandler
Description
Hello,
Currently, the method preprocess_request() in VLLMHandler (vllm_async_service.py) initializes predefined stream and non-stream output formatters. When vllm_async_service is used as the entry point in AWS LMI containers, defining a custom_output_formatter in model.py (appropriately decorated with @output_formatter) does not override the output formatters set by the service.
Current limitations: specifying a custom output formatter is supported for text generation only, as stated in the documentation:
TextGenerationOutput: This subclass of RequestOutput is specific to text generation tasks. Right now this is the only task supported for custom output formatter. Each text generation task can generate multiple sequences.
The output formatters used by the current async service, however, operate on a richer set of protocols, such as ChatCompletionResponse and CompletionResponse.
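To make the proposal concrete, here is a minimal sketch of what a user-supplied non-stream formatter could look like. The function name, the dict-based stand-in for a CompletionResponse-style payload, and the string return type are all assumptions for illustration; the actual formatter signature expected by djl_python may differ.

```python
import json


def custom_non_stream_formatter(response: dict) -> str:
    """Hypothetical non-stream formatter: flattens a CompletionResponse-like
    payload into a minimal JSON body. Here `response` is a plain dict used as
    a stand-in for the real response protocol object."""
    # Pull the generated text out of the first choice, if present.
    choices = response.get("choices", [])
    text = choices[0].get("text", "") if choices else ""
    return json.dumps({"generated_text": text})
```

A formatter of this shape is what the service would need to accept in place of its built-in vllm_non_stream_output_formatter.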
Will this change the current API? How?
The API will likely need to be adapted to accept user-supplied formatters.
Who will benefit from this enhancement?
Users who want finer control over the service output.
@frankfliu , @siddvenk
I would need your expertise on the following problem, related to the enhancement proposed above.
I am trying to write a model.py file that will be used as the entry point for the latest AWS LMI container. Within this file, I am trying to subclass VLLMHandler in order to inject a custom vllm_non_stream_output_formatter (as an alternative to the one available here).
To this end, I am considering the following approach:
class CustomProcessedRequest(ProcessedRequest):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Override formatter
        self.non_stream_output_formatter = custom_non_stream_formatter


class CustomVLLMHandler(VLLMHandler):

    def preprocess_request(self, inputs: Input) -> CustomProcessedRequest:
        processed_request = super().preprocess_request(inputs)
        return CustomProcessedRequest(
            processed_request.vllm_request,
            processed_request.inference_invoker,
            custom_non_stream_formatter,  # My formatter
            processed_request.stream_output_formatter,
            processed_request.accumulate_chunks,
            processed_request.include_prompt,
        )
Is such a thing possible? From my understanding, model.py will be copied to /opt/ml/code inside the container, which means I will have to use absolute import paths to reference ProcessedRequest and VLLMHandler. Are these classes publicly available or hidden? I cannot seem to locate them in djl_python.
Thank you for your time!
Yes, that should work. You can include additional files within the model directory and import those using relative imports in the model.py. For example, in /opt/ml/model/ you can have:
/opt/ml/model/
- model.py
- custom_post_processors.py
- ... <other files like model/tokenizer artifacts>
Then in your model.py you can use them like:
from djl_python.lmi_vllm.vllm_async_service import VLLMHandler
# Sibling modules in the model directory can be imported by name
# (note: "from .custom_post_processors.py import ..." is not valid Python).
from custom_post_processors import custom_non_stream_processor, custom_stream_processor


class CustomVllmHandler(VLLMHandler):

    def preprocess_request(self, inputs: Input) -> ProcessedRequest:
        processed_request = super().preprocess_request(inputs)
        processed_request.non_stream_output_formatter = custom_non_stream_processor
        processed_request.stream_output_formatter = custom_stream_processor
        return processed_request
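For completeness, custom_post_processors.py could look roughly like the sketch below. The dict-based payloads are assumptions standing in for the vLLM response protocol objects, and the function signatures are illustrative; the actual signatures expected by VLLMHandler may differ.

```python
# Sketch of custom_post_processors.py (hypothetical helper names; the exact
# formatter signatures used by djl_python are not shown in this thread).
import json


def custom_non_stream_processor(response: dict) -> str:
    """Serialize a completed response, keeping only the generated texts."""
    choices = response.get("choices", [])
    return json.dumps({"outputs": [c.get("text", "") for c in choices]})


def custom_stream_processor(chunk: dict) -> str:
    """Serialize a single streamed chunk as a server-sent-events style line."""
    return "data: " + json.dumps(chunk) + "\n\n"
```

Because the handler override above only swaps attribute references on the returned ProcessedRequest, these functions just need to match whatever call signature the service invokes formatters with.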
This should work for now, but it's not the best CX. We'd like to offer something more robust and ergonomic, but that will take some more thought, as I'd like to consider whether a generic mechanism could work across different engines (like vllm and trtllm).
@siddvenk Thank you for looking into this!
Looking forward to using the improved solution once it is ready.
On a separate note, I tested some SageMaker endpoints hosting the LMI 15 container and noticed that awscurl versions later than 0.29.0 don't accommodate the completions schema, raising the following error:
Exception in thread "main" java.lang.IllegalArgumentException: Infinity is not a valid double value as per JSON specification. To override this behavior, use GsonBuilder.serializeSpecialFloatingPointValues() method.
    at com.google.gson.Gson.checkValidFloatingPoint(Gson.java:509)
    at com.google.gson.Gson$1.write(Gson.java:471)
    at com.google.gson.Gson$1.write(Gson.java:454)
    at com.google.gson.internal.bind.TypeAdapterRuntimeTypeWrapper.write(TypeAdapterRuntimeTypeWrapper.java:73)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$2.write(ReflectiveTypeAdapterFactory.java:247)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.write(ReflectiveTypeAdapterFactory.java:490)
    at com.google.gson.Gson.toJson(Gson.java:944)
    at com.google.gson.Gson.toJson(Gson.java:899)
    at com.google.gson.Gson.toJson(Gson.java:848)
    at com.google.gson.Gson.toJson(Gson.java:825)
    at ai.djl.awscurl.Result.print(Result.java:212)
    at ai.djl.awscurl.AwsCurl.run(AwsCurl.java:353)
    at ai.djl.awscurl.AwsCurl.main(AwsCurl.java:98)
The error does not occur if inference requests follow the TGI schema, or if awscurl 0.29.0 is used with the completions schema. The limitation of the TGI schema is that, in the current implementation, logprobs is hardcoded to 1, preventing users from leveraging this parameter.
Should I create a bug issue?
All the best!