
Add support for Request Batching and Parallel Workers for ML use cases.

Open tchaton opened this issue 2 years ago • 5 comments

First Check

  • [X] I added a very descriptive title to this issue.
  • [X] I used the GitHub search to find a similar issue and didn't find it.
  • [X] I searched the FastAPI documentation, with the integrated search.
  • [X] I already searched in Google "How to X in FastAPI" and didn't find any information.
  • [X] I already read and followed all the tutorial in the docs and didn't find an answer.
  • [X] I already checked if it is not related to FastAPI but to Pydantic.
  • [X] I already checked if it is not related to FastAPI but to Swagger UI.
  • [X] I already checked if it is not related to FastAPI but to ReDoc.

Commit to Help

  • [X] I commit to help with one of those options 👆

Example Code

NA

Description

Dear @tiangolo.

I was going through the MLServer framework from SeldonIO and found two interesting constructs that I believe could be upstreamed to FastAPI in some form.

The first one is AdaptiveBatcher, which enables requests to be batched, processed as one, and unbatched. I have been trying to do that with FastAPI for a very long time and was very glad to find this implementation. I believe it could fit as a middleware configurable per route.
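
Roughly, the batching idea could be sketched at the application level like this (a minimal illustration only; the `MicroBatcher` name, `predict_batch` callable, queue, and size/wait limits are placeholders, not MLServer's or FastAPI's actual API):

import asyncio
import typing

# Minimal micro-batching sketch: callers put (payload, future) pairs on a queue,
# a background task drains the queue into a batch, runs one batched call, and
# resolves each caller's future with its own result.
class MicroBatcher:
    def __init__(
        self,
        predict_batch: typing.Callable[[list], typing.Awaitable[list]],
        max_batch_size: int = 8,
        max_wait: float = 0.01,
    ) -> None:
        self.predict_batch = predict_batch
        self.max_batch_size = max_batch_size
        self.max_wait = max_wait
        self.queue: asyncio.Queue = asyncio.Queue()

    async def submit(self, item):
        # Called from an endpoint: enqueue the payload and wait for its result
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def worker(self) -> None:
        # Run as a background task (e.g. started on application startup)
        while True:
            item, future = await self.queue.get()
            items, futures = [item], [future]
            deadline = asyncio.get_running_loop().time() + self.max_wait
            # Collect more items until the batch is full or the wait budget is spent
            while len(items) < self.max_batch_size:
                timeout = deadline - asyncio.get_running_loop().time()
                if timeout <= 0:
                    break
                try:
                    item, future = await asyncio.wait_for(self.queue.get(), timeout)
                except asyncio.TimeoutError:
                    break
                items.append(item)
                futures.append(future)
            results = await self.predict_batch(items)  # one call for the whole batch
            for fut, result in zip(futures, results):
                fut.set_result(result)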

The second one is InferencePool, which routes requests to a pool of workers that run the inference. I don't believe FastAPI has a clear API to enable this natively yet, but my knowledge there is still limited. I think the InferencePool could still be launched on the side, but I don't know how it would interact with uvicorn workers.
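
For the worker-pool part, something along these lines already works today by pushing the heavy call into a ProcessPoolExecutor next to the event loop (a rough sketch; the `predict` function and pool size are placeholders, and each uvicorn worker would get its own pool):

import asyncio
import typing
from concurrent.futures import ProcessPoolExecutor

from fastapi import FastAPI

app = FastAPI()
executor: typing.Optional[ProcessPoolExecutor] = None

def predict(payload: dict) -> dict:
    # Placeholder for a CPU-bound model call; runs inside a worker process
    return {"echo": payload}

@app.on_event("startup")
async def startup() -> None:
    global executor
    executor = ProcessPoolExecutor(max_workers=4)  # size of the inference pool

@app.on_event("shutdown")
async def shutdown() -> None:
    if executor is not None:
        executor.shutdown()

@app.post("/predict")
async def predict_endpoint(payload: dict) -> dict:
    # The event loop stays responsive while the inference runs in the pool
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, predict, payload)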

I hope these features find their way into FastAPI to better support ML serving.

Wanted Solution

NA

Wanted Code

NA

Alternatives

No response

Operating System

Linux

Operating System Details

NA

FastAPI Version

0.1.0

Python Version

NA

Additional Context

No response

tchaton avatar Jul 04 '22 08:07 tchaton

Having a Celery wrapper would fit both of your use cases
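
A minimal sketch of what that could look like with Celery and a Redis broker (the broker URL, module layout, and task/endpoint names here are just placeholders):

# tasks.py - the Celery side; workers run the model outside the web process
from celery import Celery

celery_app = Celery(
    "ml",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/1",
)

@celery_app.task
def run_inference(payload: dict) -> dict:
    # Placeholder for the actual model call
    return {"prediction": payload}


# api.py - the FastAPI side: enqueue the work, then poll for the result
from fastapi import FastAPI
from tasks import celery_app, run_inference

app = FastAPI()

@app.post("/predict")
def submit(payload: dict):
    task = run_inference.delay(payload)  # hand the request to a Celery worker
    return {"task_id": task.id}

@app.get("/predict/{task_id}")
def result(task_id: str):
    task = celery_app.AsyncResult(task_id)
    return {"ready": task.ready(), "result": task.result if task.ready() else None}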

iudeen avatar Jul 07 '22 22:07 iudeen

FastAPI is a web framework built on Starlette (a simpler web framework) and Pydantic.

And to run FastAPI you need an ASGI-compatible web server.

So FastAPI is not designed to batch web requests; you need to use a queue system for that, like Celery / RabbitMQ / Kafka...

raphaelauv avatar Jul 10 '22 22:07 raphaelauv

Hey @raphaelauv @iudeen Thanks for answering!

Could you share some Celery examples showing how to do this?

IMO, batching would fit very nicely as a Middleware if the API is worked on: you would get a list of Requests inside your endpoint functions.

Personally, I think this provides value for FastAPI adoption in ML use cases, but that is left to the maintainers to decide.

Best, Thomas Chaton. Lightning.ai Tech Lead

tchaton avatar Jul 18 '22 14:07 tchaton

@tchaton you can try this blog post

iudeen avatar Jul 18 '22 14:07 iudeen

Hi @tchaton ,

For another question on StackOverflow (here), I created an example of how batching of requests could work using middleware. However, I would recommend taking a look at Celery (see @iudeen's comment) and combining that with the code below.

from fastapi import FastAPI, Request
import typing
import asyncio
import time 
import logging 

Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]
Receive = typing.Callable[[], typing.Awaitable[Message]]
Send = typing.Callable[[Message], typing.Awaitable[None]]
RequestTuple = typing.Tuple[Scope, Receive, Send]

logger = logging.getLogger("uvicorn")

async def very_heavy_lifting(requests: dict[int, RequestTuple], batch_no: int) -> dict[int, RequestTuple]:
    # This mimics a heavy-lifting function: it takes a whole 3 seconds to process the batch
    logger.info(f"Heavy lifting for batch {batch_no} with {len(requests)} requests")
    await asyncio.sleep(3)
    processed_requests: dict[int, RequestTuple] = {}
    for id, request in requests.items():
        request[0]["heavy_lifting_result"] = f"result of request {id} in batch {batch_no}"
        processed_requests[id] = (request[0], request[1], request[2])
    return processed_requests

class Batcher():
    def __init__(self, batch_max_size: int = 5, batch_max_seconds: int = 3) -> None:
        self.batch_max_size = batch_max_size
        self.batch_max_seconds = batch_max_seconds
        self.to_process: dict[int, RequestTuple] = {}
        self.processing: dict[int, RequestTuple] = {}
        self.processed: dict[int, RequestTuple] = {}
        self.batch_no = 1

    def start_batcher(self):
        # Must be called while the event loop is running (e.g. from a startup event)
        self.batcher_task = asyncio.create_task(self._batcher())

    async def _batcher(self):
        while True:
            time_out = time.time() + self.batch_max_seconds
            while time.time() < time_out:
                if len(self.to_process) >= self.batch_max_size:
                    logger.info(f"Batch {self.batch_no} is full \
                        (requests: {len(self.to_process)}, max allowed: {self.batch_max_size})")
                    # Process the current batch first, then move on to the next number
                    await self.process_requests(self.batch_no)
                    self.batch_no += 1
                    break
                await asyncio.sleep(0)
            else:
                if len(self.to_process) > 0:
                    logger.info(f"Batch {self.batch_no} is over the time limit (requests: {len(self.to_process)})")
                    await self.process_requests(self.batch_no)
                    self.batch_no += 1
            await asyncio.sleep(0)

    async def process_requests(self, batch_no: int):
        logger.info(f"Start of processing batch {batch_no}...")
        for id, request in self.to_process.items():
            self.processing[id] = request
        self.to_process = {}
        processed_requests = await very_heavy_lifting(self.processing, batch_no)
        # Merge instead of overwrite, so results from a previous batch that have
        # not been picked up by the middleware yet are not lost
        self.processed.update(processed_requests)
        self.processing = {}
        logger.info(f"Finished processing batch {batch_no}")

batcher = Batcher() 

class InterceptorMiddleware():
    def __init__(self, app) -> None:
        self.app = app
        self.request_id: int = 0

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "http":  # pragma: no cover
            await self.app(scope, receive, send)
            return

        self.request_id += 1
        current_id = self.request_id
        batcher.to_process[self.request_id] = (scope, receive, send)
        logger.info(f"Added request {current_id} to batch {batcher.batch_no}.")
        while True:
            request = batcher.processed.get(current_id, None)
            if not request:
                await asyncio.sleep(0.5)
            else:
                logger.info(f"Request {current_id} was processed, forwarding to FastAPI endpoint..")
                batcher.processed.pop(current_id)
                await self.app(request[0], request[1], request[2])
                # Done: return so this ASGI call completes instead of polling forever
                return

app = FastAPI()

@app.on_event("startup")
async def startup_event():
    batcher.start_batcher()
    return

app.add_middleware(InterceptorMiddleware)

@app.get("/")
async def root(request: Request):
    return {"Return value": request["heavy_lifting_result"]}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
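
To see the batching happen, you could fire a handful of concurrent requests at it, for example with httpx (the port matches the uvicorn call above):

import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient() as client:
        # 7 concurrent requests: 5 should be processed as a full batch,
        # the remaining 2 once the batcher's time limit expires
        responses = await asyncio.gather(
            *[client.get("http://localhost:8000/", timeout=30) for _ in range(7)]
        )
    for response in responses:
        print(response.json())

asyncio.run(main())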

Just figured that it might help you out in the short term :)

JarroVGIT avatar Jul 30 '22 11:07 JarroVGIT