Inference streaming support
This PR adds streaming support to MLServer by allowing the user to implement a predict_stream method in the runtime, which takes an async generator of requests as input and outputs an async generator of responses.
from typing import AsyncIterator

from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse

class MyModel(MLModel):

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
        pass

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        pass
While the input and output types for predict remain the same, the predict_stream implementation can handle a stream of inputs and a stream of outputs. This design choice is quite general and covers many input-output scenarios (see the sketch after the list below):
- unary input - unary output (handled by predict)
- unary input - stream output (handled by predict_stream)
- stream input - unary output (handled by predict_stream)
- stream input - stream output (handled by predict_stream)
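As an illustration of the stream input - stream output case, here is a minimal sketch of a predict_stream implementation. The EchoStreamModel name and the echo behaviour are made up for the example; a real runtime would run its own model here instead.

from typing import AsyncIterator

from mlserver import MLModel
from mlserver.types import InferenceRequest, InferenceResponse
from mlserver.codecs import StringCodec

class EchoStreamModel(MLModel):
    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        # Consume the request stream and emit one response per request.
        # A unary-input model would read a single request from the
        # iterator instead; a unary-output model would yield only once.
        async for payload in payloads:
            text = StringCodec.decode_input(payload.inputs[0])
            yield InferenceResponse(
                model_name=self.name,
                outputs=[StringCodec.encode_output("output", text)],
            )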
Although streamed input is not really a thing for REST and is currently not supported there, it is quite natural for gRPC. Users who would like to use streamed inputs will therefore have to use gRPC.
Exposed endpoints
We expose the following endpoints (plus their versioned counterparts) to the user:
/v2/models/{model_name}/infer
/v2/models/{model_name}/infer_stream
/v2/models/{model_name}/generate
/v2/models/{model_name}/generate_stream
The first two are general-purpose endpoints, while the latter two are LLM specific (see the open inference protocol here). Note that the infer and generate endpoints route to the predict implementation, while infer_stream and generate_stream route to the predict_stream implementation defined above.
Client calls
REST non-streaming
import os

import requests
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# POST the request to the (non-streaming) generate endpoint and decode
# the single response
api_url = "http://localhost:8080/v2/models/text-model/generate"
response = requests.post(api_url, json=inference_request.dict())
response = types.InferenceResponse.parse_raw(response.text)
print(StringCodec.decode_output(response.outputs[0]))
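The contents of generate-request.json are not shown in this PR; as a rough sketch, an equivalent request could be built programmatically with the StringCodec (the input name "prompt" matches the gRPC examples below, while the prompt text is made up):

from mlserver.types import InferenceRequest
from mlserver.codecs import StringCodec

# Illustrative stand-in for generate-request.json: a single string
# input named "prompt" (use_bytes=False keeps the payload JSON-friendly)
inference_request = InferenceRequest(
    inputs=[
        StringCodec.encode_input("prompt", ["What is MLServer?"], use_bytes=False),
    ]
)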
REST streaming
import os

import httpx
from httpx_sse import connect_sse
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

with httpx.Client() as client:
    # The streamed responses come back as Server-Sent Events (SSE)
    with connect_sse(
        client,
        "POST",
        "http://localhost:8080/v2/models/text-model/generate_stream",
        json=inference_request.dict(),
    ) as event_source:
        for sse in event_source.iter_sse():
            response = types.InferenceResponse.parse_raw(sse.data)
            print(StringCodec.decode_output(response.outputs[0]))
gRPC non-streaming
import os

import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# gRPC expects a bytes-encoded input, so re-encode the string payload
inference_request.inputs[0] = StringCodec.encode_input(
    "prompt", inference_request.inputs[0].data.__root__
)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)

grpc_channel = grpc.insecure_channel("localhost:8081")
grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)

response = grpc_stub.ModelInfer(inference_request_g)
response = ModelInferResponseConverter.to_types(response)
print(StringCodec.decode_output(response.outputs[0]))
gRPC streaming
import asyncio
import os

import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# gRPC expects a bytes-encoded input, so re-encode the string payload
inference_request.inputs[0] = StringCodec.encode_input(
    "prompt", inference_request.inputs[0].data.__root__
)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)

async def get_inference_request_stream(inference_request):
    # A single-request input stream for this example
    yield inference_request

async def main():
    async with grpc.aio.insecure_channel("localhost:8081") as grpc_channel:
        grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
        inference_request_stream = get_inference_request_stream(inference_request_g)
        async for response in grpc_stub.ModelStreamInfer(inference_request_stream):
            response = ModelInferResponseConverter.to_types(response)
            print(StringCodec.decode_output(response.outputs[0]))

asyncio.run(main())
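Note that the request stream above contains a single request (the unary input - stream output case). Yielding multiple requests from get_inference_request_stream would exercise the stream-input scenarios listed earlier.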