
Shield send "http.response.start" from cancellation

Open acjh opened this issue 2 years ago • 6 comments

Fixes #1634

  • Discussion #1527
  • Caused by #1157

"RuntimeError: No response returned." is raised in BaseHTTPMiddleware if the request is disconnected. This is caused by task_group.cancel_scope.cancel() in StreamingResponse.__call__.<locals>.wrap together with the cancellation check in await checkpoint() of MemoryObjectSendStream.send, so sending "http.response.start" can be cancelled before the message is delivered.
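The mechanism can be reproduced without Starlette. Below is a pure-asyncio emulation in the same spirit as the 0.14.2 emulation in this PR. Every name in it is hypothetical: checkpointed_send stands in for MemoryObjectSendStream.send (a checkpoint before the message is handed over), race stands in for the task group in StreamingResponse.__call__ (whichever coroutine finishes first cancels the others), and call_next is a simplified stand-in for BaseHTTPMiddleware.call_next.

```python
import asyncio
import functools


async def checkpointed_send(queue: asyncio.Queue, message: dict) -> None:
    # Emulates MemoryObjectSendStream.send: a cancellation checkpoint
    # runs *before* the message is handed to the receiving side.
    await asyncio.sleep(0)
    await queue.put(message)


async def receive() -> dict:
    # The client disconnect arrives after one scheduling round.
    await asyncio.sleep(0)
    return {"type": "http.disconnect"}


async def listen_for_disconnect() -> None:
    while True:
        message = await receive()
        if message["type"] == "http.disconnect":
            return


async def stream_response(send) -> None:
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"hi!"})


async def race(*funcs) -> None:
    # Hypothetical helper: whichever coroutine finishes first cancels the
    # others, like task_group.cancel_scope.cancel() in StreamingResponse's wrap.
    async def wrap(func) -> None:
        await func()
        for task in tasks:
            if task is not asyncio.current_task():
                task.cancel()

    tasks = [asyncio.create_task(wrap(func)) for func in funcs]
    await asyncio.gather(*tasks, return_exceptions=True)


async def call_next() -> list:
    # Hypothetical stand-in for BaseHTTPMiddleware.call_next, which reads the
    # downstream response's messages from a stream.
    queue: asyncio.Queue = asyncio.Queue()
    send = functools.partial(checkpointed_send, queue)
    await race(listen_for_disconnect, functools.partial(stream_response, send))
    if queue.empty():
        # Nothing, not even "http.response.start", made it through.
        raise RuntimeError("No response returned.")
    return [queue.get_nowait()["type"] for _ in range(queue.qsize())]
```

Here asyncio.run(call_next()) raises RuntimeError("No response returned."). Under asyncio's FIFO scheduling, the disconnect listener finishes while stream_response is parked at the checkpoint inside its first send, so the cancellation lands before "http.response.start" is enqueued.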

Let's fix this behaviour change caused by the anyio integration in 0.15.0.


I managed to make this error reproducible in 0.14.2 by partially emulating 0.15.0 logic: https://github.com/acjh/starlette/commit/37dd8ace69993bf437712bc54a2fdb93116ab918

starlette/concurrency.py:

  async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
+     async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+         await func()
+         # (starlette 0.15.0) starlette.concurrency.run_until_first_complete `task_group.cancel_scope.cancel()`
+         for task in tasks:
+             if not task.done() and task != asyncio.current_task():
+                 task.cancel()
+
-     tasks = [create_task(handler(**kwargs)) for handler, kwargs in args]
+     tasks = [create_task(run(functools.partial(handler, **kwargs))) for handler, kwargs in args]
      (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
      [task.cancel() for task in pending]
-     [task.result() for task in done]
+     for task in done:
+         try:
+             task.result()
+         except asyncio.CancelledError:
+             pass

starlette/middleware/base.py:

  class BaseHTTPMiddleware:
      ...

      async def call_next(self, request: Request) -> Response:
          ...
-         send = queue.put
+
+         async def send(item: typing.Any) -> None:
+             await asyncio.sleep(0)  # anyio.streams.memory.MemoryObjectSendStream.send `await checkpoint()`
+             await queue.put(item)

          ...

acjh, Jun 25 '22

I think this can be fixed without shielding. This test fails on master but passes with this patch. Can you also try this test on your branch to check whether it's a good test and whether your branch fixes it?

diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index 49a5e3e..5210d2d 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -4,7 +4,7 @@ import anyio
 
 from starlette.requests import Request
 from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
 DispatchFunction = typing.Callable[
@@ -12,6 +12,10 @@ DispatchFunction = typing.Callable[
 ]
 
 
+class _ClientDisconnected(Exception):
+    pass
+
+
 class BaseHTTPMiddleware:
     def __init__(
         self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
@@ -28,12 +32,18 @@ class BaseHTTPMiddleware:
             app_exc: typing.Optional[Exception] = None
             send_stream, recv_stream = anyio.create_memory_object_stream()
 
+            async def recv() -> Message:
+                message = await request.receive()
+                if message["type"] == "http.disconnect":
+                    raise _ClientDisconnected
+                return message
+
             async def coro() -> None:
                 nonlocal app_exc
 
                 async with send_stream:
                     try:
-                        await self.app(scope, request.receive, send_stream.send)
+                        await self.app(scope, recv, send_stream.send)
                     except Exception as exc:
                         app_exc = exc
 
@@ -69,7 +79,10 @@ class BaseHTTPMiddleware:
 
         async with anyio.create_task_group() as task_group:
             request = Request(scope, receive=receive)
-            response = await self.dispatch_func(request, call_next)
+            try:
+                response = await self.dispatch_func(request, call_next)
+            except _ClientDisconnected:
+                return
             await response(scope, receive, send)
             task_group.cancel_scope.cancel()
 
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 976d77b..92826bc 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -1,13 +1,17 @@
 import contextvars
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Awaitable, Callable
 
+import anyio
 import pytest
 
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import PlainTextResponse, StreamingResponse
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 
 
 class CustomMiddleware(BaseHTTPMiddleware):
@@ -206,3 +210,41 @@ def test_contextvars(test_client_factory, middleware_cls: type):
     client = test_client_factory(app)
     response = client.get("/")
     assert response.status_code == 200, response.content
+
+
+@pytest.mark.anyio
+async def test_client_disconnects_before_response_is_sent() -> None:
+    # test for https://github.com/encode/starlette/issues/1527
+    app: ASGIApp
+
+    async def homepage(request: Request):
+        await anyio.sleep(5)
+        return PlainTextResponse("hi!")
+
+    async def dispatch(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        return await call_next(request)
+
+    app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch)
+    app = BaseHTTPMiddleware(app, dispatch=dispatch)
+
+    async def recv_gen() -> AsyncGenerator[Message, None]:
+        yield {"type": "http.request"}
+        yield {"type": "http.disconnect"}
+
+    async def send_gen() -> AsyncGenerator[None, Message]:
+        msg = yield
+        assert msg["type"] == "http.response.start"
+        msg = yield
+        raise AssertionError("Should not be called")
+
+    scope = {"type": "http", "method": "GET", "path": "/"}
+
+    async with AsyncExitStack() as stack:
+        recv = recv_gen()
+        stack.push_async_callback(recv.aclose)
+        send = send_gen()
+        stack.push_async_callback(send.aclose)
+        await send.__anext__()
+        await app(scope, recv.__aiter__().__anext__, send.asend)
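The patch above turns a disconnect into an exception that the inner app raises, call_next re-raises, and the middleware catches, so the middleware returns quietly instead of reporting "No response returned.". A stripped-down asyncio sketch of that control flow follows. ClientDisconnected, make_receive, wrap_receive, app, call_next and middleware are hypothetical names (ClientDisconnected mirrors _ClientDisconnected in the patch), and streams and task groups are left out.

```python
import asyncio


class ClientDisconnected(Exception):
    """Hypothetical marker, mirroring _ClientDisconnected in the patch."""


def make_receive(messages):
    # Hypothetical receive callable that replays a fixed list of messages.
    it = iter(messages)

    async def receive() -> dict:
        return next(it)

    return receive


def wrap_receive(receive):
    # Like recv() in the patch: surface a disconnect as an exception.
    async def recv() -> dict:
        message = await receive()
        if message["type"] == "http.disconnect":
            raise ClientDisconnected
        return message

    return recv


async def app(receive, send) -> None:
    # Hypothetical downstream app: reads until it sees a disconnect and
    # returns without ever sending a response.
    while True:
        message = await receive()
        if message["type"] == "http.disconnect":
            return


async def call_next(receive, wrap: bool) -> list:
    sent: list = []

    async def send(message: dict) -> None:
        sent.append(message)

    app_exc = None
    try:
        await app(wrap_receive(receive) if wrap else receive, send)
    except Exception as exc:  # kept and re-raised below, as call_next does
        app_exc = exc
    if not sent:
        if app_exc is not None:
            raise app_exc
        raise RuntimeError("No response returned.")
    return sent


async def middleware(receive, wrap: bool = True) -> str:
    try:
        await call_next(receive, wrap)
    except ClientDisconnected:
        return "client disconnected"  # nothing to send; the client is gone
    return "response sent"
```

In this sketch, the wrapped receive makes middleware return "client disconnected", while wrap=False falls through to the generic "No response returned." error.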

adriangb, Jun 25 '22

My fix addresses the behaviour change in StreamingResponse caused by the anyio integration. Your fix pre-empts the behaviour of await recv_stream.receive() on client disconnection, within BaseHTTPMiddleware itself.

That behaviour of StreamingResponse is not publicly documented, so your fix is sufficient for issues arising from BaseHTTPMiddleware usage.

That test passes on my branch if recv_gen() yields "http.disconnect" twice; otherwise, listen_for_disconnect() in the second middleware raises StopAsyncIteration. The uvicorn ASGI server keeps returning "http.disconnect" on subsequent receive calls.

 async def recv_gen() -> AsyncGenerator[Message, None]:
     yield {"type": "http.request"}
     yield {"type": "http.disconnect"}
+    yield {"type": "http.disconnect"}
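The difference comes down to how the test builds receive. Calling __anext__ on an exhausted async generator raises StopAsyncIteration, whereas a server that keeps returning "http.disconnect", as described above for uvicorn, never runs dry. A small sketch contrasting the two; sticky_receive and collect are hypothetical helpers, not part of the test.

```python
import asyncio
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List

Message = Dict[str, Any]


async def recv_gen() -> AsyncGenerator[Message, None]:
    # The test's original generator: one request, then one disconnect.
    yield {"type": "http.request"}
    yield {"type": "http.disconnect"}


def sticky_receive(messages: List[Message]) -> Callable[[], Awaitable[Message]]:
    # Hypothetical helper: replays `messages`, then keeps returning
    # "http.disconnect" on every further call.
    remaining = list(messages)

    async def receive() -> Message:
        if remaining:
            return remaining.pop(0)
        return {"type": "http.disconnect"}

    return receive


async def collect(receive: Callable[[], Awaitable[Message]], calls: int) -> List[str]:
    # Calls receive() `calls` times, recording each message type, or the
    # exception name once the source is exhausted.
    seen: List[str] = []
    for _ in range(calls):
        try:
            seen.append((await receive())["type"])
        except StopAsyncIteration:
            seen.append("StopAsyncIteration")
    return seen
```

Three calls on recv_gen().__anext__ give "http.request", "http.disconnect", then StopAsyncIteration; three calls on sticky_receive with the same two messages end with a second "http.disconnect" instead.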

acjh, Jun 26 '22

Apologies. I've been looking at BaseHTTPMiddleware a lot lately and got hung up on that 😅.

I asked in asgiref for confirmation of the expected behavior of ASGI servers w.r.t. sending the disconnect message multiple times. I think it would be a good idea to adapt that test (or just write a new one, up to you) to the specific situation this is supposed to fix. I think a test will be required before merging this.

adriangb, Jun 26 '22

I've added a test for the specific situation this is supposed to fix.

Actually, I don't think StreamingResponse should do this shielding, since the cancellation check is an intended feature of MemoryObjectSendStream. Code that uses StreamingResponse with a MemoryObjectSendStream can wrap the stream's send in a shielded send function if desired:

async def send(msg):
    with anyio.CancelScope(shield=True):
        await send_stream.send(msg)
- await self.app(scope, request.receive, send_stream.send)
+ await self.app(scope, request.receive, send)
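The shielded send can be sketched without anyio. The standard library has no CancelScope, so this version uses asyncio.shield as a stand-in, which is an approximation rather than the PR's code. With asyncio.shield the awaiting task is cancelled right away while the inner send keeps running, whereas anyio's shielded scope defers the cancellation until the scope exits; either way the message is delivered. checkpointed_send and deliver are hypothetical names.

```python
import asyncio


async def checkpointed_send(queue: asyncio.Queue, message: dict) -> None:
    # Emulates MemoryObjectSendStream.send: checkpoint first, then enqueue.
    await asyncio.sleep(0)
    await queue.put(message)


async def deliver(shielded: bool) -> int:
    # Returns how many messages reach the receiving side when the sending
    # task is cancelled while its send is in flight.
    queue: asyncio.Queue = asyncio.Queue()

    async def send(message: dict) -> None:
        if shielded:
            # Stand-in for `with anyio.CancelScope(shield=True)`: the inner
            # send keeps running even if this task is cancelled.
            await asyncio.shield(checkpointed_send(queue, message))
        else:
            await checkpointed_send(queue, message)

    task = asyncio.create_task(send({"type": "http.response.start"}))
    await asyncio.sleep(0)  # let send() reach the checkpoint (or the shield)
    task.cancel()  # emulates the task group being cancelled on disconnect
    try:
        await task
    except asyncio.CancelledError:
        pass
    for _ in range(3):
        await asyncio.sleep(0)  # give a shielded inner send time to finish
    return queue.qsize()
```

deliver(False) returns 0 and deliver(True) returns 1: only the shielded send gets "http.response.start" through.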

I think it is probably preferable to do this in BaseHTTPMiddleware than in StreamingResponse.

I am also happy to close this PR in favour of the fix that you proposed in BaseHTTPMiddleware.

acjh, Jun 27 '22

If that's all that's required in BaseHTTPMiddleware to fix this, I like that a lot, 1 LOC 🥳.

Also, the barrier for doing something like this in BaseHTTPMiddleware is lower: the fix is close to the source of the issue, and BaseHTTPMiddleware is already dealing with streams, tasks, cancellation and such, so adding some shielding doesn't move the needle much on complexity.

adriangb, Jun 27 '22

Well, it's not actually 1 LOC 😅

I have submitted PR #1710 to shield send "http.response.start" from cancellation in BaseHTTPMiddleware.

acjh, Jun 27 '22

  • Closed by #1715

Kludex, Oct 01 '22