fix: add websocket prefix to StreamingResponse and FileResponse denial (#3048)

Open · martin-joshy opened this issue 2 months ago · 2 comments

Hi, this is my first open-source contribution ever, so all kinds of feedback are welcome.

# Summary

`StreamingResponse` and `FileResponse` sent unprefixed `http.response.*` messages during `send_denial_response()`, but in a WebSocket context (the WebSocket Denial Response extension) these messages must carry the `websocket.http.response.*` prefix.

This caused:

```
RuntimeError: Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", but got 'http.response.start'
```
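To make the failure mode concrete, here is a small self-contained sketch that does not import Starlette. `make_checked_send` is a hypothetical, simplified stand-in for the state check in `WebSocket.send` (it accepts only the first messages listed in the error above), and `streamed_denial` is a hypothetical helper that mimics a streaming response emitting `http.response.*` messages, optionally with the `websocket.` prefix. `Message` and `Send` are local aliases, not imports from `starlette.types`.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

# Local stand-ins for starlette.types.Message / starlette.types.Send (assumption).
Message = Dict[str, Any]
Send = Callable[[Message], Awaitable[None]]

# First messages accepted while the socket is still connecting (taken from the error text).
ALLOWED_FIRST = {"websocket.accept", "websocket.close", "websocket.http.response.start"}


def make_checked_send(sent: List[Message]) -> Send:
    """Hypothetical, simplified stand-in for the state check in WebSocket.send."""

    async def send(message: Message) -> None:
        # Simplification: after the first message, only denial-response bodies are accepted.
        expected = ALLOWED_FIRST if not sent else {"websocket.http.response.body"}
        if message["type"] not in expected:
            raise RuntimeError(f"Unexpected ASGI message {message['type']!r}")
        sent.append(message)

    return send


async def streamed_denial(send: Send, prefix: str = "") -> None:
    """Hypothetical streamed 403 response; prefix "" reproduces the bug, "websocket." the fix."""
    await send({"type": prefix + "http.response.start", "status": 403, "headers": []})
    await send({"type": prefix + "http.response.body", "body": b"denied", "more_body": True})
    await send({"type": prefix + "http.response.body", "body": b"", "more_body": False})


async def demo() -> Tuple[str, List[str]]:
    sent: List[Message] = []
    error = ""
    try:
        await streamed_denial(make_checked_send(sent))  # unprefixed: rejected on the first message
    except RuntimeError as exc:
        error = str(exc)
    sent.clear()
    await streamed_denial(make_checked_send(sent), prefix="websocket.")  # prefixed: accepted
    return error, [m["type"] for m in sent]


print(asyncio.run(demo()))
```

Without the prefix the very first message is rejected, mirroring the `RuntimeError` reported above; with the prefix the whole start/body/body sequence goes through.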

# Checklist

- [x] I understand that this PR may be closed in case there was no previous discussion. (This doesn't apply to typos!) → Discussion already exists: #3048
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change → Added
- [x] I've updated the documentation accordingly. → No documentation changes needed; the behavior now matches what was intended.



martin-joshy · Oct 29 '25 18:10

I think what we want is a send wrapper that wraps the current `send` and checks whether we are in a WebSocket context.

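```python
from starlette.types import Message, Send


def _send_wrap(send: Send, is_websocket_denial: bool) -> Send:
    # Plain `def`: it only builds and returns the wrapped send callable.
    async def wrapped(message: Message) -> None:
        if is_websocket_denial:
            # e.g. "http.response.start" -> "websocket.http.response.start"
            message["type"] = "websocket." + message["type"]
        await send(message)

    return wrapped
```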

The idea is to avoid passing the additional parameter through every response class.
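A standalone sketch of this wrapper idea, under a few assumptions: `Message` and `Send` are local aliases standing in for `starlette.types`, and `send_wrap`, `plain_http_response` and `run` are hypothetical names. The wrapper is a plain `def` because it only builds and returns the inner coroutine function; it also copies the message instead of mutating it in place, which is a design choice rather than a requirement.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

# Local stand-ins for the aliases in starlette.types (assumption).
Message = Dict[str, Any]
Send = Callable[[Message], Awaitable[None]]


def send_wrap(send: Send, is_websocket_denial: bool) -> Send:
    """Hypothetical wrapper: prefix message types with "websocket." during a denial."""

    async def wrapped(message: Message) -> None:
        if is_websocket_denial:
            # Copy rather than mutate, so the caller's dict is left untouched.
            message = {**message, "type": "websocket." + message["type"]}
        await send(message)

    return wrapped


async def plain_http_response(send: Send) -> None:
    """Hypothetical response that only knows the plain HTTP message names."""
    await send({"type": "http.response.start", "status": 403, "headers": []})
    await send({"type": "http.response.body", "body": b"forbidden"})


async def run(scope_type: str) -> List[str]:
    seen: List[str] = []

    async def record(message: Message) -> None:
        seen.append(message["type"])

    # The flag is derived once from the scope type; the response code itself is unchanged.
    await plain_http_response(send_wrap(record, is_websocket_denial=scope_type == "websocket"))
    return seen


print(asyncio.run(run("websocket")))
print(asyncio.run(run("http")))
```

The response never sees the flag; only the `send` it is handed changes, which is what keeps the extra parameter out of the response classes themselves.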

Kludex · Oct 30 '25 09:10

Hi, thanks for the snippet; it made things a lot clearer. Below is the code I ran the tests with, but I am not getting 100% coverage, and I'm not really sure what I should be doing.

```python
    # Excerpt from the WebSocket class; Message and Send come from starlette.types.
    async def send_denial_response(self, response: Response) -> None:
        if "websocket.http.response" in self.scope.get("extensions", {}):
            # Wrap send so plain "http.response.*" messages get the "websocket." prefix.
            wrapped_send = self._send_wrap(self.send)
            await response(self.scope, self.receive, wrapped_send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")

    @staticmethod
    def _send_wrap(send: Send) -> Send:
        async def wrapped(message: Message) -> None:
            message_type = message["type"]
            # Messages that are already prefixed pass through unchanged.
            if message_type.startswith("http."):
                message["type"] = "websocket." + message_type

            await send(message)

        return wrapped
```
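On the coverage question, the sketch below does not diagnose which lines were actually missed. It shows one way to exercise both outcomes of the `startswith("http.")` check in isolation: an unprefixed message that gets rewritten and an already-prefixed message that passes through. `send_wrap` is a hypothetical module-level copy of the `_send_wrap` staticmethod above, and `collect_types` is a hypothetical helper; Starlette's own test suite would more likely drive this through its `TestClient`.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

# Local stand-ins for starlette.types.Message / Send (assumption).
Message = Dict[str, Any]
Send = Callable[[Message], Awaitable[None]]


def send_wrap(send: Send) -> Send:
    """Hypothetical module-level copy of the `_send_wrap` staticmethod, for isolated testing."""

    async def wrapped(message: Message) -> None:
        message_type = message["type"]
        if message_type.startswith("http."):
            message["type"] = "websocket." + message_type
        await send(message)

    return wrapped


async def collect_types(messages: List[Message]) -> List[str]:
    """Hypothetical helper: send each message through the wrapper, record what arrives."""
    seen: List[str] = []

    async def record(message: Message) -> None:
        seen.append(message["type"])

    wrapped = send_wrap(record)
    for message in messages:
        await wrapped(message)
    return seen


def test_send_wrap_covers_both_branches() -> None:
    types = asyncio.run(
        collect_types(
            [
                {"type": "http.response.start", "status": 403, "headers": []},  # prefix added
                {"type": "websocket.http.response.body", "body": b""},  # left as-is
            ]
        )
    )
    assert types == ["websocket.http.response.start", "websocket.http.response.body"]
```

Under pytest, `test_send_wrap_covers_both_branches` is collected automatically; running with branch coverage enabled (for example `coverage run --branch -m pytest`) would show whether both paths of the `if` are hit.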

martin-joshy · Oct 30 '25 19:10