
enhancement: make streaming more convenient

Open tstadel opened this issue 7 months ago • 2 comments

Is your feature request related to a problem? Please describe. Implementing streaming in an API-facing app is not straightforward. Although the streaming callback provides a very versatile interface, building a standard SSE streaming app requires some heavy lifting:

  • API frameworks like FastAPI expect a generator for StreamingResponse, whereas Haystack does not offer one (already discussed in https://github.com/deepset-ai/haystack/issues/8742)
  • rendering tool calls is error-prone and depends heavily on the model provider's API
  • since pipelines can have multiple streaming components, determining which component produced a chunk/message requires additional custom code

Describe the solution you'd like A simpler way of writing FastAPI-based LLM apps that support streaming.
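For illustration only, a convenience API along these lines would remove most of the boilerplate shown below; run_as_stream and include_component_name are made-up names to sketch the idea, not an existing Haystack API:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/query")
async def query(request: QueryStreamRequest) -> StreamingResponse:
    # Hypothetical: the pipeline exposes an async generator of SSE-ready
    # strings, with each chunk tagged with the producing component's name.
    stream = pipeline.run_as_stream(
        data=to_pipeline_input(request.model_dump()),
        include_component_name=True,
    )
    return StreamingResponse(stream, media_type="text/event-stream")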

Additional context If you want to use FastAPI's StreamingResponse, you need several tweaks to get streaming to work. E.g. look at this snippet, which uses a Queue as an adapter between Haystack's streaming_callback approach and FastAPI's generator approach. Also note that adding additional metadata and parsing tool calls relies on a decorator pattern for streaming callbacks. Ideally there would be a simpler way to do this. Any means that facilitate this custom logic would be highly appreciated:

from asyncio import Queue
from collections.abc import AsyncGenerator, Callable
from time import time
from typing import Any
from uuid import UUID, uuid4

import asyncer
from fastapi import BackgroundTasks
from haystack.core.pipeline.base import PipelineBase
from haystack.dataclasses import StreamingChunk

# Project-specific helpers and types (QueryStreamRequest, to_pipeline_input,
# get_outputs, _invoke_pipeline, from_pipeline, PIPELINE_MAPPING_OUTPUT,
# _validate_query_response, WrongQueryResponseFormatError,
# TimeToFirstTokenDecorator, to_stream_message, error_to_stream, logger)
# are assumed to be defined elsewhere in the project.


async def run_pipeline_streaming(
    pipeline: PipelineBase,
    request: QueryStreamRequest,
    streaming_generators: list[str],
    background_tasks: BackgroundTasks,
) -> AsyncGenerator[str, None]:
    start = time()
    component_names = {name for name, _ in pipeline.walk()}
    pipeline_input = to_pipeline_input(request.model_dump())
    outputs = get_outputs(pipeline, request)
    query_id = uuid4()
    streaming_callback = CustomStreamingCallback(query_id)

    for streaming_generator in streaming_generators:
        streaming_generator_params = pipeline_input.setdefault(streaming_generator, {})
        decorated_callback = TimeToFirstTokenDecorator(
            ToolCallRenderingCallbackDecorator(
                ComponentCallbackDecorator(streaming_callback, component=streaming_generator)
            )
        )
        streaming_generator_params["streaming_callback"] = decorated_callback

    async def pipeline_run_task() -> dict[str, Any]:
        result = await _invoke_pipeline(pipeline, pipeline_input, outputs)
        streaming_callback.mark_as_done()
        return result

    # Start pipeline execution in background task group
    # Pipeline.run will be called in a separate thread while AsyncPipeline.run will be called on the main event loop
    try:
        async with asyncer.create_task_group() as task_group:
            task = task_group.soonify(pipeline_run_task)()
            async for chunk in streaming_callback.get_chunks():
                yield chunk
    except ExceptionGroup as exc_group:
        for error in exc_group.exceptions:
            # Make sure exc_info is set to the exception and not the ExceptionGroup
            # This can be removed once ExceptionGroups are supported in structlog or dc-unified-logging
            # See https://github.com/hynek/structlog/issues/676
            logger.error(f"Pipeline run failed. {error!s}", exc_info=error)
            yield error_to_stream(query_id, str(error))
        return

    try:
        haystack_result = task.value
        result = from_pipeline(haystack_result, mapping=PIPELINE_MAPPING_OUTPUT)
        query_response = _validate_query_response(result, request, haystack_result, query_id)
    except WrongQueryResponseFormatError as error:
        logger.exception(
            "Pipeline returned wrong format. Failed parsing the response. Please check your pipelines return values.",
            result=error.result,
        )
        yield error_to_stream(
            query_id,
            f"Pipeline returned wrong format. Failed parsing the response: {error!s}\n"
            f"Please check your pipelines return values: {error.result}",
        )
        return

    if request.include_result:
        payload = {"query_id": str(query_id), "result": query_response.serialize(), "type": "result"}
        yield to_stream_message(payload)


class CustomStreamingCallback:
    """
    A custom streaming callback that stores the tokens in a queue and provides a method to get the tokens as chunks.
    """

    DONE_MARKER = StreamingChunk("[DONE]", meta={"is_done": True})

    def __init__(self, query_id: UUID) -> None:
        self.query_id = query_id
        # The queue bridges the callback (producer) and the async SSE generator (consumer)
        self.queue: Queue = Queue()

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        self.queue.put_nowait(chunk_received)

    def mark_as_done(self) -> None:
        self.queue.put_nowait(self.DONE_MARKER)

    async def get_chunks(self) -> AsyncGenerator[str, None]:
        while True:
            next_chunk: StreamingChunk = await self.queue.get()
            if next_chunk == self.DONE_MARKER:
                break
            if next_chunk.content:
                payload = {
                    "query_id": str(self.query_id),
                    "delta": {"text": next_chunk.content, "meta": next_chunk.meta},
                    "type": "delta",
                }
                yield to_stream_message(payload)
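
# Illustrative wire format, assuming to_stream_message emits standard SSE
# "data:" lines for the payload built above:
#   data: {"query_id": "7f3c...", "delta": {"text": "Hello", "meta": {...}}, "type": "delta"}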


class ComponentCallbackDecorator:
    """
    Decorator to augment the StreamingChunk's meta with the component's name.
    """

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None], component: str) -> None:
        self.streaming_callback = streaming_callback
        self.component = component

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        chunk_received.meta["deepset_cloud"] = {"component": self.component}
        self.streaming_callback(chunk_received)


class ToolCallRenderingCallbackDecorator:
    """
    Decorator to augment the StreamingChunk's content with the tool call data from meta.
    """

    TOOL_START = '\n\n**Tool Use:**\n```json\n{{\n  "name": "{tool_name}",\n  "arguments": '
    TOOL_END = "\n}\n```\n"

    def __init__(self, streaming_callback: Callable[[StreamingChunk], None]) -> None:
        self.streaming_callback = streaming_callback
        self._openai_tool_call_index = 0

    def __call__(self, chunk_received: StreamingChunk) -> None:
        """
        This callback method is called when a new chunk is received from the stream.

        :param chunk_received: The chunk received from the stream.
        """
        chunk_received = self._render_anthropic_tool_call(chunk_received)
        chunk_received = self._render_openai_tool_call(chunk_received)

        self.streaming_callback(chunk_received)

    def _render_openai_tool_call(self, chunk_received: StreamingChunk) -> StreamingChunk:
        tool_calls = chunk_received.meta.get("tool_calls") or []
        for tool_call in tool_calls:
            if not tool_call.function:
                continue
            # multiple tool calls (distinguished by index) can be concatenated without a finish_reason in between
            if self._openai_tool_call_index < tool_call.index:
                chunk_received.content += self.TOOL_END
            self._openai_tool_call_index = tool_call.index
            if tool_name := tool_call.function.name:
                chunk_received.content += self.TOOL_START.format(tool_name=tool_name)
            if arguments := tool_call.function.arguments:
                chunk_received.content += arguments
        if chunk_received.meta.get("finish_reason") == "tool_calls":
            chunk_received.content += self.TOOL_END
        return chunk_received

    def _render_anthropic_tool_call(self, chunk_received: StreamingChunk) -> StreamingChunk:
        content_block = chunk_received.meta.get("content_block") or {}
        if content_block.get("type") == "tool_use":
            tool_name = content_block.get("name") or ""
            content = self.TOOL_START.format(tool_name=tool_name)
            chunk_received.content += content
        delta = chunk_received.meta.get("delta") or {}
        if delta.get("type") == "input_json_delta":
            partial_json = delta.get("partial_json") or ""
            chunk_received.content += partial_json
        if delta.get("stop_reason") == "tool_use":
            content = self.TOOL_END
            chunk_received.content += content
        return chunk_received
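
For completeness, here is roughly how the generator above plugs into a FastAPI endpoint. The route and the to_stream_message/error_to_stream helpers are illustrative sketches of the omitted project code, not the exact implementation, and pipeline is assumed to be built at startup:

import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def to_stream_message(payload: dict[str, Any]) -> str:
    # Serialize one payload as a Server-Sent Events "data:" line.
    return f"data: {json.dumps(payload)}\n\n"


def error_to_stream(query_id: UUID, message: str) -> str:
    return to_stream_message({"query_id": str(query_id), "error": message, "type": "error"})


@app.post("/stream")
async def stream_endpoint(request: QueryStreamRequest, background_tasks: BackgroundTasks) -> StreamingResponse:
    generator = run_pipeline_streaming(
        pipeline=pipeline,  # assumed to be built at startup
        request=request,
        streaming_generators=["llm"],  # names of the pipeline's streaming components
        background_tasks=background_tasks,
    )
    return StreamingResponse(generator, media_type="text/event-stream")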

tstadel, May 06 '25 09:05

cc @mpangrazzi relevant for Hayhooks

sjrl, May 06 '25 10:05

Is there any news on this one? I find pipeline-level streaming to be quite limiting.

BenceV, Sep 14 '25 12:09