Consume request body in middleware is problematic

Open yihuang opened this issue 6 years ago • 30 comments

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.middleware.base import BaseHTTPMiddleware


class SampleMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        _ = await request.form()
        return await call_next(request)


app = Starlette()


@app.route('/test', methods=['POST'])
async def test(request):
    _ = await request.form()  # blocked, because middleware already consumed request body
    return PlainTextResponse('Hello, world!')


app.add_middleware(SampleMiddleware)

$ uvicorn test:app --reload
$ curl -d "a=1" http://127.0.0.1:8000/test
# request is blocked

yihuang avatar Apr 30 '19 02:04 yihuang

Where is it blocking?

Is there a second request instance involved? (otherwise it should just return, shouldn't it?) https://github.com/encode/starlette/blob/7eb43757307d4702ee6a1f2739388242c703e47e/starlette/requests.py#L178-L193

blueyed avatar May 01 '19 03:05 blueyed

Looks like it: https://github.com/encode/starlette/blob/7eb43757307d4702ee6a1f2739388242c703e47e/starlette/middleware/base.py#L26 and https://github.com/encode/starlette/blob/7eb43757307d4702ee6a1f2739388242c703e47e/starlette/routing.py#L38.

Can you write a test for starlette itself?

(just for reference: https://github.com/encode/starlette/pull/498)
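
The links above suggest that the middleware and the router each wrap the same `receive` callable in their own request object. Here is a minimal sketch of why that matters, using a hypothetical `MiniRequest` stand-in (not Starlette's class) and a hypothetical `make_receive` helper. The body cache lives on the instance, so the second instance cannot see what the first one already read. A real server's `receive` may wait at that point; the sketch returns a disconnect message instead so that it terminates.

```python
import asyncio


class MiniRequest:
    """Hypothetical stand-in for a request wrapper; not Starlette's class."""

    def __init__(self, receive):
        self._receive = receive

    async def body(self) -> bytes:
        # The cache is stored on this instance only.
        if not hasattr(self, "_body"):
            message = await self._receive()
            self._body = message.get("body", b"")
        return self._body


def make_receive(payload: bytes):
    """Build a one-shot receive that records how often it is called."""
    calls = []

    async def receive():
        calls.append(1)
        # Only the first call carries the payload; later calls signal a disconnect.
        if len(calls) == 1:
            return {"type": "http.request", "body": payload, "more_body": False}
        return {"type": "http.disconnect"}

    return receive, calls


async def demo():
    receive, calls = make_receive(b"a=1")
    first = MiniRequest(receive)   # e.g. created by the middleware
    second = MiniRequest(receive)  # e.g. created by the router
    body_first = await first.body()
    body_again = await first.body()    # served from the instance cache
    body_second = await second.body()  # cache miss: calls receive again
    return body_first, body_again, body_second, len(calls)
```

Calling `asyncio.run(demo())` shows the first instance returning the payload twice from one `receive` call, while the second instance triggers another `receive` call and gets no body.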

blueyed avatar May 01 '19 03:05 blueyed

"Consume request body in middleware is problematic"

Indeed. Consuming request data in middleware is problematic. Not just to Starlette, but generally, everywhere.

On the whole you should avoid doing so if at all possible.

There's some work we could do to make it work better, but there will always be cases where it can't work (eg. if you stream the request data, then it's just not going to be available anymore later down the line).

lovelydinosaur avatar May 20 '19 14:05 lovelydinosaur

Coming from FastAPI issue referenced above.

@tomchristie I don't understand the issue with consuming the request in a middleware; could you explain this point? In fact, I need (as is common where I work) to log every request received by my production server. Is there a better place than a middleware to do this and avoid duplicating code in every endpoint? (I would like something like this: https://github.com/Rhumbix/django-request-logging)

For now, I found this workaround (but that's not very pretty):

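from starlette.requests import Request
from starlette.types import Message


def set_body(request: Request, body: bytes) -> None:
    # Replace the private receive callable so that later reads get the body again.
    async def receive() -> Message:
        return {"type": "http.request", "body": body}

    request._receive = receive


async def get_body(request: Request) -> bytes:
    body = await request.body()
    set_body(request, body)
    return body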
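
The same replay idea can also be sketched one level lower, as a pure ASGI middleware that never touches private attributes. This is only an illustration under stated assumptions: `ReplayBodyMiddleware` and its `on_body` callback are hypothetical names, and the whole body is buffered in memory, which defeats streaming for large uploads.

```python
class ReplayBodyMiddleware:
    """Hypothetical pure-ASGI middleware: reads the body, then replays it."""

    def __init__(self, app, on_body=None):
        self.app = app
        self.on_body = on_body  # optional callback, e.g. a logger

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Drain the incoming body, following the ASGI "more_body" flag.
        chunks = []
        more_body = True
        while more_body:
            message = await receive()
            if message["type"] != "http.request":
                break
            chunks.append(message.get("body", b""))
            more_body = message.get("more_body", False)
        body = b"".join(chunks)

        if self.on_body is not None:
            self.on_body(body)

        replayed = False

        async def replay_receive():
            nonlocal replayed
            if not replayed:
                replayed = True
                return {"type": "http.request", "body": body, "more_body": False}
            # After the replay, defer to the real receive (e.g. for disconnects).
            return await receive()

        await self.app(scope, replay_receive, send)
```

The downstream app sees a single `http.request` message containing the full body, regardless of how many chunks the server delivered.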

"but there will always be cases where it can't work (eg. if you stream the request data, then it's just not going to be available anymore later down the line)"

I kind of disagree with your example. Streamed data is not stored by default, but stream metadata (whether the stream is closed) is; an understandable error will be raised if someone tries to stream twice, and that is enough imho. That's why, if the body is cached, the stream consumption has to be cached too.
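
To make that point concrete, here is a minimal sketch of a reader that keeps only the "consumed" flag unless the body was explicitly cached. `StreamOnce` is a hypothetical class written for this comment, loosely modelled on the behaviour described in this thread; it is not Starlette's implementation.

```python
import asyncio


class StreamOnce:
    """Hypothetical body reader that remembers whether the stream was used."""

    def __init__(self, receive):
        self._receive = receive
        self._stream_consumed = False

    async def stream(self):
        if hasattr(self, "_body"):
            # A cached body can be replayed as often as needed.
            yield self._body
            return
        if self._stream_consumed:
            raise RuntimeError("Stream consumed")
        self._stream_consumed = True
        while True:
            message = await self._receive()
            if message["type"] != "http.request":
                break
            if message.get("body"):
                yield message["body"]
            if not message.get("more_body", False):
                break

    async def body(self) -> bytes:
        # Reading the full body goes through stream() once and caches the result.
        if not hasattr(self, "_body"):
            self._body = b"".join([chunk async for chunk in self.stream()])
        return self._body
```

Streaming twice raises a clear error, while calling `body()` first makes every later `stream()` or `body()` call replay the cached bytes.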

wyfo avatar Jul 19 '19 08:07 wyfo

There are plenty of use cases like the one @wyfo mentions. In my case I'm using a JWT signature to check the integrity of all the data, so I decode the token and then compare the decoded result with the body + query + path params of the request. I don't know any better way of doing this.

lvalladares avatar Aug 19 '19 03:08 lvalladares

I'm working on a WTForms plugin for Starlette and I'm running into a similar issue. What is the recommended way to consume request.form() in middleware?

amorey avatar Jan 16 '20 09:01 amorey

Currently, this is one of the better options: https://fastapi.tiangolo.com/advanced/custom-request-and-route/#accessing-the-request-body-in-an-exception-handler but unfortunately there still isn't a great way to do this as far as I'm aware.

dmontagu avatar Jan 16 '20 19:01 dmontagu

Thanks! Is there an equivalent to before_request() in Flask? https://flask.palletsprojects.com/en/1.1.x/api/#flask.Flask.before_request

amorey avatar Jan 17 '20 14:01 amorey

@amorey Depending on exactly what you want to do, you can either create a custom middleware that only does things before the request, or you can create a dependency that does whatever setup you need, and include it in the router or endpoint decorator if you don't need access to its return value and don't want to have an unused injected argument.

I think that's the closest you'll get to a direct translation of the before_request function, but I'm not very knowledgeable about flask.

dmontagu avatar Jan 18 '20 00:01 dmontagu

@dmontagu I am using FastAPI and trying to log only the 500 Internal Server errors, by using the ServerErrorMiddleware from Starlette via add_middleware in FastAPI. Is there a way to get the request JSON body in this case? It appears that I could consume the JSON body in HTTPException and RequestValidationError with add_exception_handler, but not with ServerErrorMiddleware. ServerErrorMiddleware does have the request info, but not the JSON body.

Conversation from Gitter, but lmk if anybody has feedback.

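from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import PlainTextResponse, Response


async def exception_handler(request, exc) -> Response:
    print(exc)
    return PlainTextResponse('Yes! it works!')

app.add_middleware(ServerErrorMiddleware, handler=exception_handler)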

Okay, I found a way to use starlette-context to retrieve my payload only when I need it for logging: https://github.com/tomwojcik/starlette-context/blob/e222f739f113b74c2dad772d417d7fcc6f82f0ae/examples/example_with_logger/logger.py

I am writing every request into the context using an incoming middleware from starlette-context and retrieving it through context.data in HTTPException and ServerErrorMiddleware. This is not required for ValidationError, as FastAPI supports it natively through exc.body. All this is a roundabout way that is only needed when you have to log those failed responses.

You don't need a middleware for logging incoming requests if you don't want one; simply use Depends in your router:

application.include_router(api_router, prefix=API_V1_STR, dependencies=[Depends(log_json)])

JHBalaji avatar Jun 08 '20 22:06 JHBalaji

Is there an update on this? It seems like this is something a lot of people need for very legitimate use cases (mainly logging, it seems).

Is there a plan to allow consuming the body in a middleware? I see the body is cached on the request object as _body; is it possible to cache it on the scope so it is accessible from everywhere after it is read?

Any other solution would also be OK, but I do feel this is needed.

@JHBalaji How are you logging the request body to the context? Are you not running into the same issue when trying to access the body in the incoming request middleware?

talarari avatar Mar 04 '21 11:03 talarari

Another use case for this: decompressing a POST request's body.

Example (which does not work, but would be amazing if it did):

@app.middleware("http")
async def decompress_if_gzip(request: Request, call_next):
    if request.headers.get("content-encoding", "") == "gzip":
        body = await request.body()
        dec = gzip.decompress(body)
        request._body = dec
    response = await call_next(request)
    return response

Any better place to do this than in the middleware?
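
One possible place, sketched here under assumptions, is a pure ASGI middleware that wraps `receive` instead of assigning to `request._body`. `GunzipRequestMiddleware` is a hypothetical name, the sketch buffers the full compressed body in memory, and it sets no size limit, so a real deployment would need protection against oversized or malicious payloads.

```python
import gzip


class GunzipRequestMiddleware:
    """Hypothetical pure-ASGI middleware that gunzips the request body."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        headers = dict(scope.get("headers", []))
        if scope["type"] != "http" or headers.get(b"content-encoding") != b"gzip":
            await self.app(scope, receive, send)
            return

        # Buffer the compressed body; it must be complete before decompressing.
        compressed = b""
        while True:
            message = await receive()
            if message["type"] != "http.request":
                break
            compressed += message.get("body", b"")
            if not message.get("more_body", False):
                break
        body = gzip.decompress(compressed)

        # Present the request downstream as if it had never been compressed.
        new_headers = [
            (k, v) for k, v in scope["headers"]
            if k not in (b"content-encoding", b"content-length")
        ]
        new_headers.append((b"content-length", str(len(body)).encode()))
        new_scope = dict(scope, headers=new_headers)

        delivered = False

        async def receive_decompressed():
            nonlocal delivered
            if not delivered:
                delivered = True
                return {"type": "http.request", "body": body, "more_body": False}
            return await receive()

        await self.app(new_scope, receive_decompressed, send)
```

Because the wrapped `receive` is what every downstream request object reads from, the endpoint gets the decompressed bytes no matter how many request instances are created along the way.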

JivanRoquet avatar Mar 25 '21 22:03 JivanRoquet

@talarari I use ContextMiddleware from starlette-context:

from starlette.requests import Request
from starlette_context.middleware import ContextMiddleware


class ContextFromMiddleware(ContextMiddleware):
    """
    This class helps in setting a Context for a Request-Response which is missing in Starlette default implementation.
    Initialize an empty dict to write data
    """

    async def set_context(self, request: Request) -> dict:
        return {}

You would add the middleware to the app; then you can access the data as context.data.

@JivanRoquet There is a gzip middleware. You could import as

from fastapi.middleware.gzip import GZipMiddleware

JHBalaji avatar Mar 25 '21 22:03 JHBalaji

@JHBalaji unless I'm mistaken, this GZipMiddleware is to gzip the response's body — not unzipping the request's body

JivanRoquet avatar Mar 25 '21 23:03 JivanRoquet

@JivanRoquet yes but the FastAPI documentation does have what you might be looking for.

JHBalaji avatar Mar 26 '21 20:03 JHBalaji

@JHBalaji yes I've seen that since then, thank you for following up.

JivanRoquet avatar Mar 31 '21 09:03 JivanRoquet

I had the same problem but I also needed to consume the response. If anyone has this problem here is a solution I used:

from typing import Callable, Awaitable, Tuple, Dict, List

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import Scope, Message


class RequestWithBody(Request):
    def __init__(self, scope: Scope, body: bytes) -> None:
        super().__init__(scope, self._receive)
        self._body = body
        self._body_returned = False

    async def _receive(self) -> Message:
        if self._body_returned:
            return {"type": "http.disconnect"}
        else:
            self._body_returned = True
            return {"type": "http.request", "body": self._body, "more_body": False}


class CustomMiddleware(BaseHTTPMiddleware):
    async def dispatch(  # type: ignore
        self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
    ) -> Response:
        request_body_bytes = await request.body()

        # use request values

        request_with_body = RequestWithBody(request.scope, request_body_bytes)
        response = await call_next(request_with_body)
        response_content_bytes, response_headers, response_status = await self._get_response_params(response)

        # use response values

        return Response(response_content_bytes, response_status, response_headers)

    async def _get_response_params(self, response: StreamingResponse) -> Tuple[bytes, Dict[str, str], int]:
        response_byte_chunks: List[bytes] = []
        response_status: List[int] = []
        response_headers: List[Dict[str, str]] = []

        async def send(message: Message) -> None:
            if message["type"] == "http.response.start":
                response_status.append(message["status"])
                response_headers.append({k.decode("utf8"): v.decode("utf8") for k, v in message["headers"]})
            else:
                response_byte_chunks.append(message["body"])

        await response.stream_response(send)
        content = b"".join(response_byte_chunks)
        return content, response_headers[0], response_status[0]
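
For the response side alone, a lighter alternative can be sketched as a pure ASGI middleware that wraps `send` and records messages while forwarding them. `CaptureResponseMiddleware` and its `on_response` callback are hypothetical names used only for illustration; the sketch keeps every body chunk in memory until the response finishes.

```python
class CaptureResponseMiddleware:
    """Hypothetical pure-ASGI middleware that records the response as it passes."""

    def __init__(self, app, on_response):
        self.app = app
        self.on_response = on_response  # called with (status, headers, body)

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        captured = {"status": None, "headers": [], "chunks": []}

        async def capturing_send(message):
            if message["type"] == "http.response.start":
                captured["status"] = message["status"]
                captured["headers"] = list(message.get("headers", []))
            elif message["type"] == "http.response.body":
                captured["chunks"].append(message.get("body", b""))
            # Forward every message unchanged so streaming still works.
            await send(message)

        await self.app(scope, receive, capturing_send)
        self.on_response(
            captured["status"], captured["headers"], b"".join(captured["chunks"])
        )
```

Unlike rebuilding a `Response` from the collected bytes, this forwards each chunk as it arrives, so the client does not wait for the whole body to be collected.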

kovalevvlad avatar Jun 29 '21 16:06 kovalevvlad

After reading through this, it seems the issue isn't with consuming the entire body; Starlette will actually pump that async generator and save the result on the request for subsequent calls: https://github.com/encode/starlette/blob/e45c5793611bfec606cd9c5d56887ddc67f77c38/starlette/requests.py#L225

The issue comes when getting form data. That function uses the underlying stream, not the body that may or may not be cached already: https://github.com/encode/starlette/blob/e45c5793611bfec606cd9c5d56887ddc67f77c38/starlette/requests.py#L246

A call to json() properly uses the cached data, but form() doesn't.

Give this a shot in a controller or middleware:

await request.body()
await request.body()
await request.json()

and I bet that works fine. Toss in an await request.form() and it will give you that "Stream consumed" error.
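
Since the cached body stays readable, one simple way around this for URL-encoded forms is to parse the cached bytes directly. This is only a sketch: `parse_urlencoded_form` is a hypothetical helper that returns a plain dict of lists rather than Starlette's FormData, and it does not handle multipart/form-data.

```python
from urllib.parse import parse_qsl


def parse_urlencoded_form(body: bytes, encoding: str = "utf-8") -> dict:
    """Parse an application/x-www-form-urlencoded body that is already in memory."""
    form: dict = {}
    # keep_blank_values preserves fields that were submitted empty.
    for key, value in parse_qsl(body.decode(encoding), keep_blank_values=True):
        form.setdefault(key, []).append(value)
    return form
```

In a middleware or endpoint this would be used as `parse_urlencoded_form(await request.body())`, which can be repeated safely because it never touches the underlying stream.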

My work around that I'm using in conjunction with FastAPI:

from collections.abc import Callable
from io import BytesIO

from fastapi.routing import APIRoute
from starlette.datastructures import FormData
from starlette.formparsers import MultiPartParser, FormParser
from starlette.requests import Request
from starlette.responses import Response

try:
    from multipart.multipart import parse_options_header
except ImportError:  # pragma: nocover
    parse_options_header = None


class AsyncIteratorWrapper:
    """
    Small helper to turn BytesIO into async-able iterator
    """

    def __init__(self, bytes_: bytes):
        super().__init__()
        self._it = BytesIO(bytes_)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            value = next(self._it)
        except StopIteration:
            raise StopAsyncIteration
        return value


class PreProcessedFormRequest(Request):
    async def form(self) -> FormData:
        if not hasattr(self, "_form"):
            assert (
                parse_options_header is not None
            ), "The `python-multipart` library must be installed to use form parsing."
            content_type_header = self.headers.get("Content-Type")
            content_type, options = parse_options_header(content_type_header)

            body_iter = AsyncIteratorWrapper(await self.body())
            if content_type == b"multipart/form-data":
                multipart_parser = MultiPartParser(self.headers, body_iter)
                self._form = await multipart_parser.parse()
            elif content_type == b"application/x-www-form-urlencoded":
                form_parser = FormParser(self.headers, body_iter)
                self._form = await form_parser.parse()
            else:
                self._form = FormData()
        return self._form


class PreProcessedFormRoute(APIRoute):
    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            request = PreProcessedFormRequest(request.scope, request.receive)
            return await original_route_handler(request)

        return custom_route_handler

Then installed into FastAPI like:

from fastapi import FastAPI

app = FastAPI()
app.router.route_class = PreProcessedFormRoute

four43 avatar Jul 26 '21 17:07 four43

This PR seems like it might fix this actually: https://github.com/encode/starlette/pull/944

four43 avatar Jul 26 '21 17:07 four43

@kovalevvlad It looks like your solution above does not work when there are multiple middlewares.

kigawas avatar Oct 28 '21 05:10 kigawas

@kigawas I am using this in production with multiple middlewares before and after this one. Perhaps you are doing something that I am not and hence hitting this problem

Edit: trying to reuse this solution later in another project causes problems. Could be FastAPI version dependent.

kovalevvlad avatar Oct 28 '21 22:10 kovalevvlad

Any updates? Getting the request body in the middleware feels like something that should work...

jacksbox avatar Nov 30 '21 11:11 jacksbox

I think we can do it in this way:

  1. When you call Request(scope, receive, send), it stores the instance itself in the scope.
  2. The next time you instantiate a request object, it gets the existing instance from the scope (achievable by implementing __new__). This means we will always have the same Request object.

alex-oleshkevich avatar Jan 12 '22 10:01 alex-oleshkevich

Does Starlette.exception_handler count as registering a middleware function? I can't seem to consume the (or access the cached) body in custom exception handlers either.

zevisert avatar Feb 15 '22 23:02 zevisert

Same issue, yes.

lovelydinosaur avatar Feb 16 '22 08:02 lovelydinosaur

Thanks for the confirmation. I ended up adapting an idea from FastAPI. Essentially my whole app runs as one middleware, and I pulled the same handler lookup from ExceptionMiddleware, so that I can pass the same starlette.Request to all of my custom handlers. I should be safe this way, since I'm only catching my application's base Exception; everything else will unwind to the built-in exception handlers.

zevisert avatar Feb 16 '22 21:02 zevisert

uvicorn -> app(receive, send) -> receive = queue.get(). If you log the message by calling receive, the message is consumed, and the request will block because there is nothing left in queue.get().
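
A small stdlib-only sketch of that blocking behaviour, assuming a simplified model in which `receive` just awaits an `asyncio.Queue` (the function name `demo_blocking_receive` and the timeout are illustrative, not taken from uvicorn):

```python
import asyncio


async def demo_blocking_receive(timeout: float = 0.05):
    """Show that a second read waits once the only queued message is gone."""
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put({"type": "http.request", "body": b"a=1", "more_body": False})

    async def receive():
        # Simplified model of a server-side receive: wait for the next event.
        return await queue.get()

    first = await receive()  # e.g. a logging middleware reads the body
    try:
        # e.g. the endpoint reads again; nothing is queued, so this waits.
        await asyncio.wait_for(receive(), timeout=timeout)
        second_read_blocked = False
    except asyncio.TimeoutError:
        second_read_blocked = True
    return first["body"], second_read_blocked
```

Without the `wait_for` timeout the second read would simply hang, which matches the "request is blocked" symptom from the original report.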

heyfavour avatar Apr 18 '22 13:04 heyfavour

@alex-oleshkevich You mean this? This definitely works and it's pretty smart 😄

diff --git a/starlette/requests.py b/starlette/requests.py
index 66c510c..69cad4c 100644
--- a/starlette/requests.py
+++ b/starlette/requests.py
@@ -188,11 +188,22 @@ async def empty_send(message: Message) -> typing.NoReturn:
 
 
 class Request(HTTPConnection):
+    def __new__(
+        cls, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
+    ):
+        if "request" in scope:
+            return scope["request"]
+
+        obj = object.__new__(cls)
+        scope["request"] = obj
+        return obj
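
The same trick can be tried outside Starlette with a small standalone class. `ScopedRequest` below is a hypothetical illustration, not the patched `Request`. One Python detail worth keeping in mind: when `__new__` returns an instance of the class, `__init__` still runs on every call, so the sketch guards against re-initialising the cached state.

```python
class ScopedRequest:
    """Hypothetical class showing the "one instance per scope" idea."""

    def __new__(cls, scope: dict, receive=None):
        # Reuse the instance already stored in this scope, if there is one.
        if "request" in scope:
            return scope["request"]
        obj = super().__new__(cls)
        scope["request"] = obj
        return obj

    def __init__(self, scope: dict, receive=None):
        # __init__ runs on every ScopedRequest(...) call, even when __new__
        # returned an existing instance, so initialise state only once.
        if getattr(self, "_initialised", False):
            return
        self._initialised = True
        self.scope = scope
        self.receive = receive
        self.cache = {}
```

Two `ScopedRequest(scope)` calls with the same scope dict return the same object, so anything cached by the middleware stays visible to the endpoint.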

kigawas avatar Jun 23 '22 10:06 kigawas

@kigawas So this modification wasn't merged into master, right?

csrgxtu avatar Sep 15 '22 11:09 csrgxtu

@csrgxtu Check this instead: https://github.com/encode/starlette/issues/1702

kigawas avatar Sep 16 '22 08:09 kigawas