Add proper synchronisation to WebSocketTestSession

Open Olocool17 opened this issue 1 year ago • 4 comments

Summary

We've been using Starlette's WebSocketTestSession in order to test some sockets on our FastAPI application, and it has worked very well. On our Windows development machines, these tests are practically instant (<0.3s), but we quickly found out that the tests could take up to 10-15 minutes on our Linux CI/CD server.

The extreme variability tipped me off that this could be a scheduling issue in the async code, which led me to pinpoint the cause: Starlette's WebSocketTestSession, specifically the looping anyio.sleep(0) in _asgi_receive. Depending on the event loop implementation, anyio.sleep(0) is not guaranteed to yield to another task in a timely manner, so our tests would remain stuck in that loop for minutes at a time.
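For illustration, the polling pattern described above looks roughly like the sketch below. It uses plain asyncio and a thread instead of anyio and Starlette's portal, and the names poll_receive and run_polling_demo are hypothetical, so treat it as an approximation of the idea and not as Starlette's code. The counter only shows that the consumer spins through the event loop for as long as the queue stays empty.

```python
import asyncio
import queue
import threading
import time


async def poll_receive(receive_queue: queue.Queue, spins: list) -> dict:
    # Busy-poll: yield to the event loop, then look at the queue again.
    while receive_queue.empty():
        spins[0] += 1
        await asyncio.sleep(0)
    return receive_queue.get()


def run_polling_demo(delay: float = 0.05):
    receive_queue: queue.Queue = queue.Queue()
    spins = [0]

    def producer() -> None:
        # Plays the role of the test thread calling send().
        time.sleep(delay)
        receive_queue.put({"type": "websocket.receive", "text": "hello"})

    async def main() -> dict:
        # Start the producer only once the loop is running, then poll.
        threading.Thread(target=producer, daemon=True).start()
        return await poll_receive(receive_queue, spins)

    message = asyncio.run(main())
    return message, spins[0]


if __name__ == "__main__":
    message, spins = run_polling_demo()
    print(message, "received after", spins, "polling iterations")
```

How quickly each sleep(0) hands control back depends on the event loop and on what else is scheduled, which is where the variability comes from.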

My solution uses anyio.Events in order to alleviate this problem and implement proper synchronisation on the _receive_queue.

Implementing the changes described in this PR took our test suite from 10-15 minutes down to <0.5s. This does not change WebSocketTestSession's interface in any way: behaviorally, everything remains the same, up to and including allowing the use of send even before entering the context via with ws_session:.

Edit: Discussion #2570 is directly relevant to this PR.

Checklist

  • [x] I understand that this PR may be closed in case there was no previous discussion. (This doesn't apply to typos!)
  • [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
  • [x] I've updated the documentation accordingly.

Olocool17 avatar May 20 '24 21:05 Olocool17

I've taken the liberty of fixing the mypy errors introduced by Jinja 3.1.4 so I can see the checks on the test runners, though this somewhat breaks the atomicity of this PR.

Olocool17 avatar May 21 '24 02:05 Olocool17

@Olocool17 Can you check if the following patch also solves your issue?

diff --git a/starlette/testclient.py b/starlette/testclient.py
index bf928d2..73cebdd 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -17,6 +17,7 @@ from urllib.parse import unquote, urljoin
 import anyio
 import anyio.abc
 import anyio.from_thread
+from anyio import create_memory_object_stream
 from anyio.abc import ObjectReceiveStream, ObjectSendStream
 from anyio.streams.stapled import StapledObjectStream
 
@@ -99,7 +100,11 @@ class WebSocketTestSession:
         self.scope = scope
         self.accepted_subprotocol = None
         self.portal_factory = portal_factory
-        self._receive_queue: queue.Queue[Message] = queue.Queue()
+
+        send_stream, receive_stream = create_memory_object_stream[Message](math.inf)
+        self._asgi_receive = receive_stream.receive
+        self.send = send_stream.send_nowait
+
         self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
         self.extra_headers = None
 
@@ -158,12 +163,6 @@ class WebSocketTestSession:
             await self.should_close.wait()
             tg.cancel_scope.cancel()
 
-    async def _asgi_receive(self) -> Message:
-        while self._receive_queue.empty():
-            self._queue_event = anyio.Event()
-            await self._queue_event.wait()
-        return self._receive_queue.get()
-
     async def _asgi_send(self, message: Message) -> None:
         self._send_queue.put(message)
 
@@ -188,11 +187,6 @@ class WebSocketTestSession:
                 content=b"".join(body),
             )
 
-    def send(self, message: Message) -> None:
-        self._receive_queue.put(message)
-        if hasattr(self, "_queue_event"):
-            self.portal.start_task_soon(self._queue_event.set)
-
     def send_text(self, data: str) -> None:
         self.send({"type": "websocket.receive", "text": data})
 

Kludex avatar Jun 01 '24 13:06 Kludex

@Kludex This does not work. Memory object streams use simple anyio.Event signaling to synchronise the two ends of the stream, but calling set() from synchronous code does not actually signal properly: a task waiting on the anyio.Event will never be woken. This is what happens under the hood when calling send_stream.send_nowait. The fact that it is a synchronous function is quite misleading, since it probably won't behave the way you expect.

This is why I spin up a task to signal the event for every send: trying to do it straight from sync code will result in deadlock.
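Plain asyncio has an analogous constraint: asyncio.Event is documented as not thread-safe, and loop.call_soon_threadsafe is the supported way to schedule a callback from another thread. The sketch below shows that hand-off using asyncio only. It is not the anyio or Starlette code, and run_threadsafe_signal_demo is a hypothetical name. In the PR, portal.start_task_soon(self._queue_event.set) plays the role that call_soon_threadsafe plays here.

```python
import asyncio
import threading


def run_threadsafe_signal_demo(timeout: float = 2.0) -> bool:
    async def main() -> bool:
        loop = asyncio.get_running_loop()
        # The event exists before the producer thread starts, so there is
        # no window in which the producer could miss it.
        event = asyncio.Event()

        def producer() -> None:
            # Ask the loop thread to call set() instead of calling it here.
            loop.call_soon_threadsafe(event.set)

        threading.Thread(target=producer, daemon=True).start()
        # Raises a timeout error if the waiter is never woken.
        await asyncio.wait_for(event.wait(), timeout)
        return event.is_set()

    return asyncio.run(main())


if __name__ == "__main__":
    print("waiter woken:", run_threadsafe_signal_demo())
```

Scheduling set() on the loop thread means the wake-up is processed by the same thread that runs the waiting task.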

I had originally tried to use memory object streams too, but ran into this exact same issue.

Olocool17 avatar Jun 01 '24 14:06 Olocool17

I can confirm the proposed PR addresses the issue raised in discussion 2570 (returning test execution back to expected speeds) and that the alternative patch does not work.

NPrescott avatar Jun 15 '24 22:06 NPrescott

Unfortunately this patch introduced an annoying deadlock. Since producer (i.e. WebSocketTestSession::send) and consumer (i.e. WebSocketTestSession::_asgi_receive) are running in different threads, we may hit the following situation:

  • Consumer sees an empty queue, enters the wait loop, and is about to create an event object and wait on it.
  • Producer puts a message on the queue but doesn't set the event, because it hasn't been created yet.
  • Deadlock: Consumer waits for the event forever.
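The interleaving above can be replayed deterministically in a single thread. The sketch below is a simplified stand-in and not Starlette's class: Session, consumer_check, consumer_wait and replay_race are hypothetical names, threading.Event replaces anyio.Event, and a timeout is used so the lost wake-up shows up as a False result instead of hanging.

```python
import queue
import threading


class Session:
    """Hypothetical stand-in for the queue-plus-event pattern."""

    def __init__(self) -> None:
        self.receive_queue: queue.Queue = queue.Queue()

    def consumer_check(self) -> bool:
        # First half of the consumer: is there anything to receive?
        return self.receive_queue.empty()

    def consumer_wait(self, timeout: float) -> bool:
        # Second half of the consumer: create the event, then wait on it.
        self.queue_event = threading.Event()
        return self.queue_event.wait(timeout)

    def send(self, message: dict) -> None:
        # Producer: enqueue, then signal only if an event already exists.
        self.receive_queue.put(message)
        if hasattr(self, "queue_event"):
            self.queue_event.set()


def replay_race(timeout: float = 0.1):
    session = Session()
    queue_was_empty = session.consumer_check()  # step 1: consumer sees an empty queue
    session.send({"type": "websocket.receive", "text": "hi"})  # step 2: no event to set yet
    woke = session.consumer_wait(timeout)  # step 3: consumer waits on a fresh event
    return queue_was_empty, woke, session.receive_queue.qsize()


if __name__ == "__main__":
    print(replay_race())
```

The message is sitting in the queue, yet the consumer is never signalled, because the check and the event creation are not atomic with respect to send.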

We need to think of a better solution here. :thinking:

P.S. I can easily reproduce this issue on PyPy. On CPython, the producer thread is much faster on my machine, so the race never occurs, but I do think it can be reproduced there too.

ikalnytskyi avatar Nov 22 '24 22:11 ikalnytskyi