
[Serve] Model multiplexing and batching do not work together

Open · patches11 opened this issue 3 months ago · 13 comments

What happened + What you expected to happen

When using multiplexing without batching, it works fine. However, when batching is added, the RequestContext appears to be incorrect, and whichever model is loaded first is used for subsequent requests.

See reproduction script attached.

Start it locally:

serve run multiplexing_issue:app

Correct output with non_batched:

(venv) patches@computer ray-serve-debris % curl -X POST -H "Content-Type: application/octet-stream" "http://localhost:8000/predict?model_id=aaa&arg=1&kind=non_batched"
"Response from model_obj_for_aaa 1"

(venv) patches@computer ray-serve-debris % curl -X POST -H "Content-Type: application/octet-stream" "http://localhost:8000/predict?model_id=bbb&arg=1&kind=non_batched"
"Response from model_obj_for_bbb 1"

Correct log output:

INFO 2025-09-17 14:12:00,330 serve 17151 -- Application 'default' is ready at http://127.0.0.1:8000/.
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:12,514 default_MultiplexedModel anlir3bn 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- Request for model_id: aaa
(ServeReplica:default:MultiplexedModel pid=17189) _RequestContext(route='/predict', request_id='52c45e21-3731-45c1-900c-fd3d2dc7d7e1', _internal_request_id='0748d457-6a2b-40ca-9a53-daa1a1117fab', app_name='default', multiplexed_model_id='aaa', grpc_context=None, is_http_request=False, cancel_on_parent_request_cancel=False)
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:12,514 default_MultiplexedModel anlir3bn 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- Loading model 'aaa'.
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:12,514 default_MultiplexedModel anlir3bn 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- Loading model: aaa
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:12,515 default_MultiplexedModel anlir3bn 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- Successfully loaded model 'aaa' in 0.1ms.
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:12,526 default_MultiplexedModel anlir3bn 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- CALL /predict OK 13.0ms
(ServeReplica:default:APIIngress pid=17179) INFO 2025-09-17 14:12:12,496 default_APIIngress 9irqpcfm 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- model_id aaa
(ServeReplica:default:APIIngress pid=17179) INFO 2025-09-17 14:12:12,504 default_APIIngress 9irqpcfm 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x11914a590>.
(ServeReplica:default:APIIngress pid=17179) INFO 2025-09-17 14:12:12,527 default_APIIngress 9irqpcfm 52c45e21-3731-45c1-900c-fd3d2dc7d7e1 -- POST /predict 200 32.8ms
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:19,221 default_MultiplexedModel anlir3bn abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- Request for model_id: bbb
(ServeReplica:default:MultiplexedModel pid=17189) _RequestContext(route='/predict', request_id='abfa01e8-c9c5-42a9-8bc3-70d50be329b7', _internal_request_id='2479658b-eb9a-437f-a214-0d35f28ce254', app_name='default', multiplexed_model_id='bbb', grpc_context=None, is_http_request=False, cancel_on_parent_request_cancel=False)
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:19,221 default_MultiplexedModel anlir3bn abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- Loading model 'bbb'.
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:19,221 default_MultiplexedModel anlir3bn abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- Loading model: bbb
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:19,221 default_MultiplexedModel anlir3bn abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- Successfully loaded model 'bbb' in 0.2ms.
(ServeReplica:default:MultiplexedModel pid=17189) INFO 2025-09-17 14:12:19,221 default_MultiplexedModel anlir3bn abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- CALL /predict OK 1.3ms
(ServeReplica:default:APIIngress pid=17179) INFO 2025-09-17 14:12:19,218 default_APIIngress 9irqpcfm abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- model_id bbb
(ServeReplica:default:APIIngress pid=17179) INFO 2025-09-17 14:12:19,222 default_APIIngress 9irqpcfm abfa01e8-c9c5-42a9-8bc3-70d50be329b7 -- POST /predict 200 6.0m

Incorrect output with batched:

(venv) patches@computer ray-serve-debris % curl -X POST -H "Content-Type: application/octet-stream" "http://localhost:8000/predict?model_id=aaa&arg=1&kind=batched"
"Response from model_obj_for_aaa 1"

(venv) patches@computer ray-serve-debris % curl -X POST -H "Content-Type: application/octet-stream" "http://localhost:8000/predict?model_id=bbb&arg=1&kind=batched"
"Response from model_obj_for_aaa 1"  

Logging for incorrect output:

INFO 2025-09-17 14:14:47,709 serve 18282 -- Application 'default' is ready at http://127.0.0.1:8000/.
(ServeReplica:default:APIIngress pid=18311) INFO 2025-09-17 14:14:51,193 default_APIIngress 19714ahf 7ec54120-889e-4be3-96ea-2ec8558141f1 -- model_id aaa
(ServeReplica:default:APIIngress pid=18311) INFO 2025-09-17 14:14:51,208 default_APIIngress 19714ahf 7ec54120-889e-4be3-96ea-2ec8558141f1 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x11b427f10>.
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:51,271 default_MultiplexedModel 9r086hrb 7ec54120-889e-4be3-96ea-2ec8558141f1 -- Request for model_id: aaa
(ServeReplica:default:MultiplexedModel pid=18307) _RequestContext(route='/predict', request_id='7ec54120-889e-4be3-96ea-2ec8558141f1', _internal_request_id='dd19a6df-9bf3-4d8e-b5a7-8ef23c826444', app_name='default', multiplexed_model_id='aaa', grpc_context=None, is_http_request=False, cancel_on_parent_request_cancel=False)
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:51,272 default_MultiplexedModel 9r086hrb 7ec54120-889e-4be3-96ea-2ec8558141f1 -- Loading model 'aaa'.
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:51,272 default_MultiplexedModel 9r086hrb 7ec54120-889e-4be3-96ea-2ec8558141f1 -- Loading model: aaa
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:51,272 default_MultiplexedModel 9r086hrb 7ec54120-889e-4be3-96ea-2ec8558141f1 -- Successfully loaded model 'aaa' in 0.1ms.
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:51,277 default_MultiplexedModel 9r086hrb 7ec54120-889e-4be3-96ea-2ec8558141f1 -- CALL /predict OK 57.5ms
(ServeReplica:default:APIIngress pid=18311) INFO 2025-09-17 14:14:51,278 default_APIIngress 19714ahf 7ec54120-889e-4be3-96ea-2ec8558141f1 -- POST /predict 200 86.0ms
(ServeReplica:default:APIIngress pid=18311) INFO 2025-09-17 14:14:53,303 default_APIIngress 19714ahf 0d229b95-331c-47b0-8ab6-78a9bb81d5cf -- model_id bbb
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:53,356 default_MultiplexedModel 9r086hrb 7ec54120-889e-4be3-96ea-2ec8558141f1 -- Request for model_id: aaa
(ServeReplica:default:MultiplexedModel pid=18307) _RequestContext(route='/predict', request_id='7ec54120-889e-4be3-96ea-2ec8558141f1', _internal_request_id='dd19a6df-9bf3-4d8e-b5a7-8ef23c826444', app_name='default', multiplexed_model_id='aaa', grpc_context=None, is_http_request=False, cancel_on_parent_request_cancel=False)
(ServeReplica:default:MultiplexedModel pid=18307) INFO 2025-09-17 14:14:53,356 default_MultiplexedModel 9r086hrb 0d229b95-331c-47b0-8ab6-78a9bb81d5cf -- CALL /predict OK 51.5ms
(ServeReplica:default:APIIngress pid=18311) INFO 2025-09-17 14:14:53,357 default_APIIngress 19714ahf 0d229b95-331c-47b0-8ab6-78a9bb81d5cf -- POST /predict 200 55.5ms

I am not a Ray expert, but in this output the RequestContext is the same for both requests, including request_id and _internal_request_id, which seems odd to me.

Versions / Dependencies

Ray 2.48.0, Python 3.11.6

Reproduction script

from ray import serve
from fastapi import FastAPI
import logging

logger = logging.getLogger("ray.serve")

app = FastAPI()

@serve.deployment
class MultiplexedModel:
    @serve.multiplexed(max_num_models_per_replica=2)
    async def get_model(self, model_id: str):
        logger.info(f"Loading model: {model_id}")
        return f"model_obj_for_{model_id}"

    @serve.batch(max_batch_size=2, batch_wait_timeout_s=0.05, max_concurrent_batches=1)
    async def batched(self, lst):
        model_id = serve.get_multiplexed_model_id()
        logger.info(f"Request for model_id: {model_id}\n{serve.context._get_serve_request_context()}")
        model = await self.get_model(model_id)

        result = []
        for item in lst:
            result.append(f"Response from {model} {item}")
        return result

    async def non_batched(self, item):
        model_id = serve.get_multiplexed_model_id()
        logger.info(f"Request for model_id: {model_id}\n{serve.context._get_serve_request_context()}")
        model = await self.get_model(model_id)

        return f"Response from {model} {item}"

@serve.deployment
@serve.ingress(app)
class APIIngress:
    def __init__(self, model_handle) -> None:
        self.model_handle = model_handle

    @app.post("/predict")
    async def predict(self, model_id: str, arg: str, kind: str):
        logger.info(f"model_id {model_id}")
        if kind == "batched":
            return await self.model_handle.options(multiplexed_model_id=model_id).batched.remote(arg)
        else:
            return await self.model_handle.options(multiplexed_model_id=model_id).non_batched.remote(arg)

app = APIIngress.bind(MultiplexedModel.bind())

Issue Severity

Medium: It is a significant difficulty but I can work around it.

patches11 · Sep 17 '25 20:09

Hey @landscapepainter hope you don't mind me tagging you here, but I saw you've been working on some stuff related to this issue with this PR: https://github.com/ray-project/ray/pull/56344

I think that fixes some of what I'm seeing here, but from my initial review it seems more work is needed to get batching and multiplexing working well together: get_multiplexed_model_id() still uses _get_serve_request_context() and doesn't appear to use your new _get_serve_batch_request_context().

I'm not super familiar with this codebase, but I imagine what would still need to happen is something like:

  • Update some logic so batches are intelligently routed to the correct replica, and all have the same model_id
  • Update get_multiplexed_model_id to use _get_serve_batch_request_context when appropriate (see the sketch after this list)
  • Add some tests to verify these work together
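
For the second bullet, here's a rough sketch of the kind of fallback I'm imagining; the context accessors and module-level variables are just stand-ins for illustration, not the actual serve.context internals:

from typing import List, Optional

# Stand-ins for Serve's per-request and batch-level request contexts;
# in Ray these would come from serve.context, not module globals.
class RequestContextStub:
    def __init__(self, multiplexed_model_id: str = ""):
        self.multiplexed_model_id = multiplexed_model_id

_request_context: Optional[RequestContextStub] = None
_batch_request_contexts: List[RequestContextStub] = []

def get_multiplexed_model_id() -> str:
    # Single-request path: behaves as today.
    if _request_context is not None:
        return _request_context.multiplexed_model_id
    # Batch path: only unambiguous if every request in the batch
    # targets the same model.
    model_ids = {ctx.multiplexed_model_id for ctx in _batch_request_contexts}
    if len(model_ids) == 1:
        return model_ids.pop()
    raise RuntimeError(
        "Batch contains requests for multiple multiplexed models; "
        "read the per-request contexts instead."
    )

That would only be safe if the batching layer also guarantees a batch never mixes model IDs, which is the first bullet.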

Anything else I'm missing here, or am I totally off base?

If you can point me in the right direction of where these changes need to occur, I could take a stab at fixing this.

Thanks!

patches11 · Sep 22 '25 16:09

Hi @patches11, @landscapepainter, @abrarsheikh,

I have come across this same issue. Is there any active development happening to fix it?

manickavela29 · Oct 24 '25 17:10

@patches11 @manickavela29 Thanks for reporting the issue. I'll be looking into this in the coming week and keep this thread updated.

landscapepainter · Oct 25 '25 00:10

@patches11 @manickavela29 Thanks for reporting the issue. I'll be looking into this in the coming week and keep this thread updated.

Hi @landscapepainter, I would love to contribute here so we can get this done quickly, as this feature is highly relevant for us. Let me know if you have started working on it; I can ping you on Slack.

manickavela29 · Oct 27 '25 12:10

@manickavela29 feel free to start a PR.

abrarsheikh · Oct 27 '25 18:10

The multiplexing and batching code currently seems quite independent and isolated by design. I will create a new decorator named batch_multiplex that handles this feature specifically, without interfering with existing batching or multiplexing, to maintain backward compatibility.

Hope that is agreeable.

manickavela29 · Oct 28 '25 06:10

The problem is that when using @serve.batch, you can't call serve.get_multiplexed_model_id() directly because:

  1. A batch contains multiple requests that may each have different multiplexed model IDs
  2. The batching mechanism intentionally unsets the single-request context to avoid confusion (this is the behavior in the latest version; in older versions it probably did not fail this way)
  3. You need to use serve.context._get_serve_batch_request_context() to access the individual request contexts

Here's the corrected code:

from ray import serve
from fastapi import FastAPI
import logging

logger = logging.getLogger("ray.serve")
app = FastAPI()

@serve.deployment
class MultiplexedModel:
    @serve.multiplexed(max_num_models_per_replica=2)
    async def get_model(self, model_id: str):
        logger.info(f"Loading model: {model_id}")
        return f"model_obj_for_{model_id}"

    @serve.batch(max_batch_size=2, batch_wait_timeout_s=0.05, max_concurrent_batches=1)
    async def batched(self, lst):
        # Get all request contexts in the batch
        batch_request_contexts = serve.context._get_serve_batch_request_context()
        
        logger.info(f"Batch size: {len(batch_request_contexts)}")
        
        result = []
        # Process each request with its corresponding model ID
        for item, request_context in zip(lst, batch_request_contexts):
            model_id = request_context.multiplexed_model_id
            logger.info(f"Request for model_id: {model_id}")
            
            # Load the model for this specific request
            model = await self.get_model(model_id)
            result.append(f"Response from {model} {item}")
        
        return result

    async def non_batched(self, item):
        # For non-batched requests, you CAN use get_multiplexed_model_id()
        model_id = serve.get_multiplexed_model_id()
        logger.info(f"Request for model_id: {model_id}")
        model = await self.get_model(model_id)
        return f"Response from {model} {item}"

@serve.deployment
@serve.ingress(app)
class APIIngress:
    def __init__(self, model_handle) -> None:
        self.model_handle = model_handle

    @app.post("/predict")
    async def predict(self, model_id: str, arg: str, kind: str):
        logger.info(f"model_id {model_id}")
        if kind == "batched":
            return await self.model_handle.options(multiplexed_model_id=model_id).batched.remote(arg)
        else:
            return await self.model_handle.options(multiplexed_model_id=model_id).non_batched.remote(arg)

app = APIIngress.bind(MultiplexedModel.bind())

The fundamental issue was that batching groups multiple requests together, and each request can have a different multiplexed_model_id. You need to handle them individually within the batch.

❯ curl -X POST -H "Content-Type: application/octet-stream" "http://localhost:8000/predict?model_id=aaa&arg=1&kind=batched"

"Response from model_obj_for_aaa 1"      
❯ curl -X POST -H "Content-Type: application/octet-stream" "http://localhost:8000/predict?model_id=bbb&arg=1&kind=batched"

"Response from model_obj_for_bbb 1"

abrarsheikh · Nov 29 '25 02:11

Closing this issue. Let me know if further action is needed.

abrarsheikh · Dec 09 '25 21:12

Thanks for the response; one question and one comment:

  1. The docs state: "Internally, the Serve router uses the model ID in the request header to route traffic to a corresponding replica. If all replicas that have the model are over-subscribed, Ray Serve routes the request to a new replica, which then loads and caches the model from the S3 bucket.". Is this still applicable for the above code example, even with batching used?
  2. This seems to ignore the whole point of batching, which is to leverage a batch to improve the throughput of a single model. In this case a batch may contain requests for the same model, and I can see how those could be batched together, but it says nothing about how or when that will happen. In an ideal world, batching + multiplexing would ensure that requests for the same model go to a replica that has the model loaded and are routed to the same replica, to maximize batching. Am I missing something here?

Thanks

patches11 · Dec 09 '25 21:12

So, the router makes a best effort to route requests to a replica with the model preloaded, as you described above. But there can still be cases where two requests for different models land on the same replica. I think the right thing to do here is to ensure that a single batch does not contain requests for different model_ids.

Mind contributing a change for this? I expect the changes to be in python/ray/serve/batching.py -> _BatchQueue -> _process_batches.
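
Roughly, the idea is to partition the queued requests by multiplexed model ID before invoking the user function, so each invocation sees exactly one model. A minimal standalone sketch of that grouping step; the request object and its multiplexed_model_id attribute here are illustrative, not the actual batching.py internals:

from collections import defaultdict
from typing import Any, Dict, List

def split_batch_by_model_id(requests: List[Any]) -> List[List[Any]]:
    # Group queued requests so each sub-batch targets a single model.
    groups: Dict[str, List[Any]] = defaultdict(list)
    for request in requests:
        model_id = getattr(request, "multiplexed_model_id", "")
        groups[model_id].append(request)
    return list(groups.values())

The batch loop would then call the wrapped function once per sub-batch, which keeps serve.get_multiplexed_model_id() well-defined inside a batch.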

abrarsheikh · Dec 09 '25 23:12

@patches11 PTAL https://github.com/ray-project/ray/pull/59334

abrarsheikh · Dec 10 '25 04:12

So, the router makes a best effort to route requests to a replica with the model preloaded, as you described above. But there can still be cases where two requests for different models land on the same replica. I think the right thing to do here is to ensure that a single batch does not contain requests for different model_ids.

Mind contributing a change for this? I expect the changes to be in python/ray/serve/batching.py -> _BatchQueue -> _process_batches.

Can the issue of two requests landing on the same replica be addressed with custom routing, as suggested here? https://github.com/ray-project/ray/issues/58187

manickavela-uni · Dec 10 '25 05:12

Can the issue of two requests landing on the same replica be addressed with custom routing, as suggested here? https://github.com/ray-project/ray/issues/58187

I don't see how it solves the problem. Are you assuming that one replica is only responsible for one model?

abrarsheikh · Dec 10 '25 18:12

PTAL

That is great, I will take a look.

patches11 · Dec 11 '25 23:12