starlette
Shield send "http.response.start" from cancellation
Fixes #1634
- Discussion #1527
- Caused by #1157
"RuntimeError: No response returned." is raised in BaseHTTPMiddleware if the request is disconnected, due to task_group.cancel_scope.cancel() in StreamingResponse.__call__.<locals>.wrap and the cancellation check in await checkpoint() of MemoryObjectSendStream.send.
Let's fix this behaviour change caused by anyio integration in 0.15.0.
I managed to make this error reproducible in 0.14.2 by partially emulating 0.15.0 logic: https://github.com/acjh/starlette/commit/37dd8ace69993bf437712bc54a2fdb93116ab918
starlette/concurrency.py:
 async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
+    async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+        await func()
+        # (starlette 0.15.0) starlette.concurrency.run_until_first_complete `task_group.cancel_scope.cancel()`
+        for task in tasks:
+            if not task.done() and task != asyncio.current_task():
+                task.cancel()
+
-    tasks = [create_task(handler(**kwargs)) for handler, kwargs in args]
+    tasks = [create_task(run(functools.partial(handler, **kwargs))) for handler, kwargs in args]
     (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
     [task.cancel() for task in pending]
-    [task.result() for task in done]
+    for task in done:
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            pass
starlette/middleware/base.py:
 class BaseHTTPMiddleware:
     ...
     async def call_next(self, request: Request) -> Response:
         ...
-        send = queue.put
+
+        async def send(item: typing.Any) -> None:
+            await asyncio.sleep(0)  # anyio.streams.memory.MemoryObjectSendStream.send `await checkpoint()`
+            await queue.put(item)
         ...
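Why the asyncio.sleep(0) matters: in 0.14.2, send was queue.put on an asyncio.Queue() created without a maxsize, and put() on a queue that is not full returns without suspending, so a cancellation requested while the app is sending only lands at the next suspension point. The emulated checkpoint is such a point. A minimal sketch of the difference, assuming a self-cancelling task as a stand-in for the task.cancel() issued by the emulated run_until_first_complete (app_send and queued_messages are hypothetical names):

```python
import asyncio


async def app_send(queue: asyncio.Queue, with_checkpoint: bool) -> None:
    task = asyncio.current_task()
    assert task is not None
    # Request cancellation; asyncio only delivers it at the next suspension point.
    task.cancel()
    if with_checkpoint:
        # Like the emulated `await checkpoint()`: suspends, so CancelledError is raised here.
        await asyncio.sleep(0)
    # An unbounded queue is never full, so put() returns without suspending.
    await queue.put({"type": "http.response.start"})


async def queued_messages(with_checkpoint: bool) -> int:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(app_send(queue, with_checkpoint))
    try:
        await task
    except asyncio.CancelledError:
        pass  # the task ends up cancelled in both cases
    return queue.qsize()


if __name__ == "__main__":
    print(asyncio.run(queued_messages(False)))  # 1: queued before the cancellation lands
    print(asyncio.run(queued_messages(True)))   # 0: cancelled at the checkpoint
```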
I think this can be fixed without shielding. This test fails on master but passes with this patch. Can you also try this test on your branch to see if it's a good test / if your branch fixes it?
diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index 49a5e3e..5210d2d 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -4,7 +4,7 @@ import anyio
 from starlette.requests import Request
 from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
 DispatchFunction = typing.Callable[
@@ -12,6 +12,10 @@ DispatchFunction = typing.Callable[
 ]
+class _ClientDisconnected(Exception):
+    pass
+
+
 class BaseHTTPMiddleware:
     def __init__(
         self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
@@ -28,12 +32,18 @@ class BaseHTTPMiddleware:
             app_exc: typing.Optional[Exception] = None
             send_stream, recv_stream = anyio.create_memory_object_stream()
+            async def recv() -> Message:
+                message = await request.receive()
+                if message["type"] == "http.disconnect":
+                    raise _ClientDisconnected
+                return message
+
             async def coro() -> None:
                 nonlocal app_exc
                 async with send_stream:
                     try:
-                        await self.app(scope, request.receive, send_stream.send)
+                        await self.app(scope, recv, send_stream.send)
                     except Exception as exc:
                         app_exc = exc
@@ -69,7 +79,10 @@ class BaseHTTPMiddleware:
         async with anyio.create_task_group() as task_group:
             request = Request(scope, receive=receive)
-            response = await self.dispatch_func(request, call_next)
+            try:
+                response = await self.dispatch_func(request, call_next)
+            except _ClientDisconnected:
+                return
             await response(scope, receive, send)
             task_group.cancel_scope.cancel()
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 976d77b..92826bc 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -1,13 +1,17 @@
 import contextvars
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Awaitable, Callable
+import anyio
 import pytest
 from starlette.applications import Starlette
 from starlette.middleware import Middleware
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import PlainTextResponse, StreamingResponse
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response, StreamingResponse
 from starlette.routing import Route, WebSocketRoute
-from starlette.types import ASGIApp, Receive, Scope, Send
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
 class CustomMiddleware(BaseHTTPMiddleware):
@@ -206,3 +210,41 @@ def test_contextvars(test_client_factory, middleware_cls: type):
     client = test_client_factory(app)
     response = client.get("/")
     assert response.status_code == 200, response.content
+
+
+@pytest.mark.anyio
+async def test_client_disconnects_before_response_is_sent() -> None:
+    # test for https://github.com/encode/starlette/issues/1527
+    app: ASGIApp
+
+    async def homepage(request: Request):
+        await anyio.sleep(5)
+        return PlainTextResponse("hi!")
+
+    async def dispatch(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        return await call_next(request)
+
+    app = BaseHTTPMiddleware(Route("/", homepage), dispatch=dispatch)
+    app = BaseHTTPMiddleware(app, dispatch=dispatch)
+
+    async def recv_gen() -> AsyncGenerator[Message, None]:
+        yield {"type": "http.request"}
+        yield {"type": "http.disconnect"}
+
+    async def send_gen() -> AsyncGenerator[None, Message]:
+        msg = yield
+        assert msg["type"] == "http.response.start"
+        msg = yield
+        raise AssertionError("Should not be called")
+
+    scope = {"type": "http", "method": "GET", "path": "/"}
+
+    async with AsyncExitStack() as stack:
+        recv = recv_gen()
+        stack.push_async_callback(recv.aclose)
+        send = send_gen()
+        stack.push_async_callback(send.aclose)
+        await send.__anext__()
+        await app(scope, recv.__aiter__().__anext__, send.asend)
My fix addresses the behaviour change in StreamingResponse caused by the anyio integration. Your fix pre-empts the behaviour of await recv_stream.receive() for client disconnection in BaseHTTPMiddleware itself. That behaviour of StreamingResponse is not publicly documented; your fix is sufficient for issues arising from BaseHTTPMiddleware usage.
That test passes on my branch if recv_gen() yields "http.disconnect" twice; otherwise it raises StopAsyncIteration in listen_for_disconnect() for the second middleware. The uvicorn ASGI server will keep yielding "http.disconnect".
     async def recv_gen() -> AsyncGenerator[Message, None]:
         yield {"type": "http.request"}
         yield {"type": "http.disconnect"}
+        yield {"type": "http.disconnect"}
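To mirror the behaviour described for uvicorn in a test without counting disconnect messages, one option is a receive callable that replays queued messages and then keeps returning "http.disconnect" instead of raising StopAsyncIteration. A minimal sketch; make_receive and demo are hypothetical test helpers, not part of Starlette or uvicorn:

```python
import asyncio
import typing

Message = typing.Dict[str, typing.Any]


def make_receive(
    messages: typing.List[Message],
) -> typing.Callable[[], typing.Awaitable[Message]]:
    """Build an ASGI receive callable for tests (hypothetical helper).

    It returns `messages` in order, then repeats {"type": "http.disconnect"}
    on every later call, so a second listener never hits StopAsyncIteration.
    """
    pending = list(messages)

    async def receive() -> Message:
        if pending:
            return pending.pop(0)
        return {"type": "http.disconnect"}

    return receive


async def demo() -> typing.List[str]:
    receive = make_receive([{"type": "http.request"}, {"type": "http.disconnect"}])
    # Four calls: the two queued messages, then a repeated disconnect.
    return [(await receive())["type"] for _ in range(4)]


if __name__ == "__main__":
    print(asyncio.run(demo()))
    # ['http.request', 'http.disconnect', 'http.disconnect', 'http.disconnect']
```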
Apologies. I've been looking at BaseHTTPMiddleware a lot lately and got hung up on that 😅.
I asked in asgiref for confirmation of the expected behavior of ASGI servers w.r.t. sending the disconnect message multiple times. I think it would be a good idea to adapt that test (or just write a new one, up to you) to the specific situation this is supposed to fix. I think a test will be required before merging this.
I've added a test for the specific situation this is supposed to fix.
Actually, I don't think StreamingResponse should do this, since it's an intended feature of MemoryObjectSendStream. Usages of StreamingResponse with MemoryObjectSendStream can wrap send_stream.send in their own shielded send if desired:
async def send(msg):
    with anyio.CancelScope(shield=True):
        await send_stream.send(msg)
- await self.app(scope, request.receive, send_stream.send)
+ await self.app(scope, request.receive, send)
I think it is probably preferable to do this in BaseHTTPMiddleware than in StreamingResponse.
I am also happy to close this PR in favour of the fix that you proposed in BaseHTTPMiddleware.
If that's all that's required in BaseHTTPMiddleware to fix this, I like that a lot, 1 LOC 🥳.
Also, the barrier for doing something like this in BaseHTTPMiddleware is lower: the fix is close to the source of the issue, and BaseHTTPMiddleware is already dealing with streams, tasks, cancellation and such, so adding some shielding doesn't move the needle much on complexity.
Well, it's not actually 1 LOC 😅
I have submitted PR #1710 to shield send "http.response.start" from cancellation in BaseHTTPMiddleware.
- Closed by #1715