
Issue: Stream a response from LangChain's OpenAI with Python Flask API

Open zigax1 opened this issue 1 year ago · 15 comments

Issue you'd like to raise.

I am using a Python Flask app for chat over data. In the console I get a streamed response directly from OpenAI, since I can enable streaming with the flag streaming=True.

The problem is that I can't "forward" the stream, or "show" the stream, in my API call.

The code for processing the OpenAI call and the chain is:

def askQuestion(self, collection_id, question):
        collection_name = "collection-" + str(collection_id)
        self.llm = ChatOpenAI(model_name=self.model_name, temperature=self.temperature, openai_api_key=os.environ.get('OPENAI_API_KEY'), streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,  output_key='answer')
        
        chroma_Vectorstore = Chroma(collection_name=collection_name, embedding_function=self.embeddingsOpenAi, client=self.chroma_client)


        self.chain = ConversationalRetrievalChain.from_llm(self.llm, chroma_Vectorstore.as_retriever(similarity_search_with_score=True),
                                                            return_source_documents=True,verbose=VERBOSE, 
                                                            memory=self.memory)
        

        result = self.chain({"question": question})
        
        res_dict = {
            "answer": result["answer"],
        }

        res_dict["source_documents"] = []

        for source in result["source_documents"]:
            res_dict["source_documents"].append({
                "page_content": source.page_content,
                "metadata":  source.metadata
            })

        return res_dict

and the API route code:

def stream(collection_id, question):
    completion = document_thread.askQuestion(collection_id, question)
    for line in completion:
        yield 'data: %s\n\n' % line

@app.route("/collection/<int:collection_id>/ask_question", methods=["POST"])
@stream_with_context
def ask_question(collection_id):
    question = request.form["question"]
    # response_generator = document_thread.askQuestion(collection_id, question)
    # return jsonify(response_generator)

    def stream(question):
        completion = document_thread.askQuestion(collection_id, question)
        for line in completion['answer']:
            yield line

    return Response(stream(question), mimetype='text/event-stream')

I am testing my endpoint with curl, passing the -N flag, so I should get a streamed response if that is possible.

When I make the API call, the endpoint first waits for the data to be processed (I can see the streamed answer in my VS Code terminal), and only when it has finished do I get everything displayed in one go.

Thanks

Suggestion:

No response

zigax1 avatar May 18 '23 21:05 zigax1

You could use the stream_with_context function and pass in the stream generator: https://flask.palletsprojects.com/en/2.1.x/patterns/streaming/

@app.route("/collection/<int:collection_id>/ask_question", methods=["POST"])
def ask_question(collection_id):
    question = request.form["question"]
    # response_generator = document_thread.askQuestion(collection_id, question)
    # return jsonify(response_generator)

    def stream(question):
        completion = document_thread.askQuestion(collection_id, question)
        for line in completion['answer']:
            yield line

    return app.response_class(stream_with_context(stream(question)))

AvikantSrivastava avatar May 20 '23 08:05 AvikantSrivastava

You could use the stream_with_context function and pass in the stream generator: https://flask.palletsprojects.com/en/2.1.x/patterns/streaming/ [...]

Sadly it doesn't work, even though I did exactly as you suggested.

zigax1 avatar May 20 '23 12:05 zigax1

I'm also wondering how this is done. I tried stream_template and stream_with_context, but my server only sends the response once it has finished loading, not while it is streaming. I also tried different callback handlers, to no avail.

sunwooz avatar May 25 '23 02:05 sunwooz

@agola11 can you answer this?

I tried doing the same in FastAPI; it did not work. I raised an issue: https://github.com/hwchase17/langchain/issues/5296

AvikantSrivastava avatar May 26 '23 13:05 AvikantSrivastava

I am still playing around and trying to solve it, but without any success.

@agola11 @hwchase17 @AvikantSrivastava

For now, my code looks like this:


class MyCustomHandler(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        yield token

class DocumentThread:

    def askQuestion(self, collection_id, question):
        collection_name = "collection-" + str(collection_id)
        self.llm = ChatOpenAI(model_name=self.model_name, temperature=self.temperature, openai_api_key=os.environ.get('OPENAI_API_KEY'), streaming=True, callback_manager=CallbackManager([MyCustomHandler()]))
        self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True,  output_key='answer')
        
        chroma_Vectorstore = Chroma(collection_name=collection_name, embedding_function=self.embeddingsOpenAi, client=self.chroma_client)

   
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, chroma_Vectorstore.as_retriever(similarity_search_with_score=True),
                                                            return_source_documents=True,verbose=VERBOSE, 
                                                            memory=self.memory)
        
        result = self.chain({"question": question})
        
        res_dict = {
            "answer": result["answer"],
        }

        res_dict["source_documents"] = []

        for source in result["source_documents"]:
            res_dict["source_documents"].append({
                "page_content": source.page_content,
                "metadata":  source.metadata
            })

        return res_dict

and the endpoint definition:

@app.route("/collection/<int:collection_id>/ask_question", methods=["POST"])
def ask_question(collection_id):
    question = request.form["question"]

    def generate_tokens(question):  
        result = document_thread.askQuestion(collection_id, question)
        for token in result['answer']:
            yield token

    return Response(stream_with_context(generate_tokens(question)), mimetype='text/event-stream')


zigax1 avatar May 27 '23 18:05 zigax1

What you need is to override StreamingStdOutCallbackHandler's on_llm_new_token method. I realized that the method only prints the token to stdout but does nothing else with it. So I put the token into a Queue from one thread, then read it from the other thread. It works for me.

import os
import queue
import sys
from typing import Any, Dict, List

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import LLMResult

q = queue.Queue()
os.environ["OPENAI_API_KEY"] = "sk-your-key"
stop_item = "###finish###"

class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with q.mutex:
            q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        sys.stdout.write(token)
        sys.stdout.flush()
        q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        q.put(stop_item)


llm = ChatOpenAI(temperature=0.5, streaming=True, callbacks=[
                 StreamingStdOutCallbackHandlerYield()])
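
For the reading side, a minimal sketch might look like the following; the route path, the prompt call, and the use of the llm, q, and stop_item defined above are assumptions, not a tested implementation:

import threading

from flask import Flask, Response, request, stream_with_context
from langchain.schema import HumanMessage

app = Flask(__name__)

@app.route("/ask", methods=["POST"])
def ask():
    question = request.form["question"]

    def run_llm():
        # Runs in a worker thread; tokens arrive in `q` via the callback handler above.
        llm([HumanMessage(content=question)])

    def read_queue():
        # Drain the queue until the sentinel pushed by on_llm_end arrives.
        while True:
            token = q.get()
            if token == stop_item:
                break
            yield token

    threading.Thread(target=run_llm).start()
    return Response(stream_with_context(read_queue()), mimetype="text/event-stream")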

longmans avatar May 29 '23 11:05 longmans

Switched from Flask to FastAPI. Moved to: #5409

zigax1 avatar May 29 '23 15:05 zigax1

What you need is to override StreamingStdOutCallbackHandler's on_llm_new_token method. [...]

Working on a similar implementation but can't get it to work. Would you mind sharing how you're reading the queue from the other thread?

mziru avatar May 29 '23 18:05 mziru


Wait, never mind, got it to work! Thanks for the first answer.

mziru avatar May 29 '23 18:05 mziru

chain.apply doesn't return a generator for a synchronous call, which makes it hard to stream the output. Why not use the asyncio API aapply, which makes token-by-token output possible?

Be careful with AsyncIteratorCallbackHandler: it stops the iterator when the stream completes, so you need to account for the remaining tokens and return them as the last data event.
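
For illustration, a minimal async sketch along those lines (using apredict rather than aapply; the names are assumptions, not a tested implementation):

import asyncio

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI

async def stream_tokens(question: str):
    handler = AsyncIteratorCallbackHandler()
    llm = ChatOpenAI(streaming=True, callbacks=[handler])
    # Start the generation without awaiting it, then drain the handler's iterator.
    task = asyncio.create_task(llm.apredict(question))
    async for token in handler.aiter():
        yield token
    # Surface any exception raised by the model call once streaming is done.
    await task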

qixiang-mft avatar Jun 01 '23 05:06 qixiang-mft

Using threading and a callback, we can get a streaming response from a Flask API.

In the Flask API, you can create a queue to collect tokens through LangChain's callback.

class StreamingHandler(BaseCallbackHandler):
    def __init__(self, queue):
        # The queue that the Flask route reads tokens from (see StreamingHandler(q) below).
        self.queue = queue

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.queue.put(token)

You can then read tokens from the same queue in your Flask route.

import threading
from queue import Queue

from flask import Response, stream_with_context

@app.route(...)
def stream_output():
    q = Queue()

    def generate(rq: Queue):
        ...
        # add your logic to prevent the while loop
        # from running indefinitely
        while (...):
            yield rq.get()

    callback_fn = StreamingHandler(q)

    # start the worker thread, then stream whatever lands in the queue
    threading.Thread(target=askQuestion, args=(collection_id, question, callback_fn)).start()
    return Response(stream_with_context(generate(q)))

In your LangChain ChatOpenAI, add the custom StreamingHandler callback defined above.

self.llm = ChatOpenAI(
  model_name=self.model_name, 
  temperature=self.temperature, 
  openai_api_key=os.environ.get('OPENAI_API_KEY'), 
  streaming=True, 
  callbacks=[callback_fn]
)

For reference:

https://python.langchain.com/en/latest/modules/callbacks/getting_started.html#creating-a-custom-handler
https://flask.palletsprojects.com/en/2.3.x/patterns/streaming/#streaming-with-context

varunsinghal avatar Jun 03 '23 17:06 varunsinghal

@varunsinghal @longmans Nice work. I am building Flask-Langchain and want to include streaming functionality. Have you tested this approach with multiple concurrent requests?

It would be fantastic if one of you could open a PR to add an extension-based callback handler and route class (or decorator?) for handling streaming responses to the Flask-Langchain project; this probably isn't functionality that belongs in the main LangChain library, as it is Flask-specific.

francisjervis avatar Jun 04 '23 19:06 francisjervis

@varunsinghal Thank you for the great answer! Could you elaborate more on the implementation of your method? I couldn't get a working reproduction of it. Thanks in advance!

VionaWang avatar Jun 04 '23 23:06 VionaWang

Working on the same problem. No success at the moment... @varunsinghal I am not getting your solution tbh

riccardolinares avatar Jun 07 '23 21:06 riccardolinares

Hi @VionaWang @riccardolinares, can you please share your code samples so that I can make suggestions and help debug what could be going wrong?

varunsinghal avatar Jun 14 '23 02:06 varunsinghal

Using threading and a callback, we can get a streaming response from a Flask API.

I managed to get streaming to work, BUT with a ConversationalRetrievalChain it also prints the condensed question before the answer. I tried replacing BaseCallbackHandler with FinalStreamingStdOutCallbackHandler, but the result is the same.

manuel-84 avatar Jun 23 '23 11:06 manuel-84

Solved it in a very hacky way (it can of course be improved): if the prompt comes from the question condenser, the streamed tokens are discarded, so the final stream contains only the answer, without the condensed question.


import queue
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler


class QueueCallback(BaseCallbackHandler):
    def __init__(self, q: queue.Queue):
        self.q = q
        self.discard = False

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
        # The condense-question prompt contains "Standalone question";
        # discard its tokens so only the final answer is streamed.
        if 'Standalone question' in prompts[0]:
            self.discard = True
        else:
            self.discard = False

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if not self.discard:
            self.q.put(token)

    def on_llm_end(self, *args, **kwargs: Any) -> None:
        return self.q.empty()
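
As a possibly cleaner alternative to inspecting the prompt (assuming your langchain version exposes the condense_question_llm argument on ConversationalRetrievalChain.from_llm), you can give the question-condensing step its own non-streaming model so that only the answer model fires the streaming callback. A sketch, where the vectorstore, memory, and q are assumptions:

from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI

# Streaming model used only for the final answer; its callback feeds the queue.
answer_llm = ChatOpenAI(streaming=True, callbacks=[QueueCallback(q)])
# Plain, non-streaming model used to rewrite the follow-up into a standalone question.
condense_llm = ChatOpenAI(temperature=0)

chain = ConversationalRetrievalChain.from_llm(
    llm=answer_llm,
    retriever=vectorstore.as_retriever(),
    condense_question_llm=condense_llm,  # assumption: available in your langchain version
    memory=memory,
)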

manuel-84 avatar Jun 23 '23 13:06 manuel-84

@stream_with_context

How did you make it work? It's been bugging me. Also, where do you import LLMResult from?

JoAmps avatar Jun 28 '23 12:06 JoAmps

Using threading and a callback, we can get a streaming response from a Flask API. [...]

It would be great if you showed the whole code.

JoAmps avatar Jun 28 '23 12:06 JoAmps

Please, I can't see the code of the working solution. Can you please show it?

youssef595 avatar Aug 30 '23 20:08 youssef595

Here's a full minimal working example, taken from all of the answers above (with typings, modularity using Blueprints, and minimal error handling as a bonus):

To explain how it all works:

  1. The controller endpoint defines an ask_question function. This function is responsible for starting the generation process in a separate thread as soon as we hit the endpoint. Note how it uses a custom callback of type StreamingStdOutCallbackHandlerYield and sets streaming=True. It delegates all of its streaming behavior to the custom class and uses a q variable that I will talk about shortly.
  2. The return type of the controller is a Response that runs the generate function. This function is the one that's actually "listening" for the streamable LLM output and yielding the result back as a stream to the HTTP caller as soon as it gets it.
  3. The way it all works is thanks to the StreamingStdOutCallbackHandlerYield. It basically writes all LLM output as soon as it comes back from OpenAI. Note how it writes it back to a Queue object that's created at controller level.
  4. Finally, see how I stop the generate function as soon as I get a special literal named STOP_ITEM. This is returned from the custom callback when the on_llm_end is executed, or when we have an error (on_llm_error). In which case, I also return the error just before returning the STOP_ITEM.

routes/stream.py

import os
import threading
from queue import Queue
from typing import TypedDict

from flask import Blueprint, Response, request
from langchain.llms import OpenAI

from utils.streaming import StreamingStdOutCallbackHandlerYield, generate

page = Blueprint(os.path.splitext(os.path.basename(__file__))[0], __name__)

# Define the expected input type
class Input(TypedDict):
    prompt: str

@page.route("/", methods=["POST"])
def stream_text() -> Response:
    data: Input = request.get_json()

    prompt = data["prompt"]
    q = Queue()

    def ask_question(callback_fn: StreamingStdOutCallbackHandlerYield):
        # Note that a try/catch is not needed here. Callback takes care of all errors in `on_llm_error`
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm(prompt=prompt)

    callback_fn = StreamingStdOutCallbackHandlerYield(q)
    threading.Thread(target=ask_question, args=(callback_fn,)).start()
    return Response(generate(q), mimetype="text/event-stream")

utils/streaming.py

import queue
from typing import Any, Dict, List, Union

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult

STOP_ITEM = "[END]"
"""
This is a special item that is used to signal the end of the stream.
"""


class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    """
    This is a callback handler that yields the tokens as they are generated.
    For a usage example, see the :func:`generate` function below.
    """

    q: queue.Queue
    """
    The queue to write the tokens to as they are generated.
    """

    def __init__(self, q: queue.Queue) -> None:
        """
        Initialize the callback handler.
        q: The queue to write the tokens to as they are generated.
        """
        super().__init__()
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with self.q.mutex:
            self.q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Writes to stdout
        # sys.stdout.write(token)
        # sys.stdout.flush()
        # Pass the token to the generator
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.q.put(STOP_ITEM)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.q.put("%s: %s" % (type(error).__name__, str(error)))
        self.q.put(STOP_ITEM)


def generate(rq: queue.Queue):
    """
    This is a generator that yields the items in the queue until it reaches the stop item.

    Usage example:
    ```
    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm(prompt="Write a poem about a tree.")

    @app.route("/", methods=["GET"])
    def generate_output():
        q = Queue()
        callback_fn = StreamingStdOutCallbackHandlerYield(q)
        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
        return Response(generate(q), mimetype="text/event-stream")
    ```
    """
    while True:
        result: str = rq.get()
        if result == STOP_ITEM or result is None:
            break
        yield result

Complete folder structure

Here's the working tree, if you're wondering where the files are located:

.
├── README.md
├── requirements.txt
└── src
    ├── main.py
    ├── routes
    │   └── stream.py
    └── utils
        └── streaming.py

main.py:

from dotenv import load_dotenv
from flask import Flask
from flask_cors import CORS

from routes.stream import page as stream_route

# Load environment variables
load_dotenv(
    dotenv_path=".env",  # Relative to where the script is running from
)

app = Flask(__name__)
# See https://github.com/corydolphin/flask-cors/issues/257
app.url_map.strict_slashes = False

CORS(app)

app.register_blueprint(stream_route, url_prefix="/api/chat")

if __name__ == "__main__":
    app.run()

I will soon follow with a full repository (probably)

usersina avatar Sep 12 '23 16:09 usersina

My previous solution is a performance killer, so here's a better, more concise one:

import asyncio
import json

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.memory import ConversationSummaryBufferMemory
from langchain.chains import ConversationChain
from langchain.llms.openai import OpenAI

@page.route("/general", methods=["POST"])
async def general_chat():
    try:
        memory = ConversationSummaryBufferMemory(
            llm=OpenAI(), chat_memory=[]
        )
        handler = AsyncIteratorCallbackHandler()
        conversation = ConversationChain(
            llm=OpenAI(streaming=True, callbacks=[handler]), memory=memory
        )

        async def ask_question_async():
            asyncio.create_task(conversation.apredict(input="Hello, how are you?"))
            async for chunk in handler.aiter():
                yield f"data: {json.dumps({'content': chunk, 'tokens': 0})}\n\n"

        return ask_question_async(), {"Content-Type": "text/event-stream"}

    except Exception as e:
        return {"error": "{}: {}".format(type(e).__name__, str(e))}, 500

Note that AsyncIteratorCallbackHandler does not work with agents yet. See this issue for more details.

usersina avatar Nov 26 '23 20:11 usersina

asyncio.create_task(conversation.apredict(input="Hello, how are you?"))
What led you to choose conversation.apredict instead of the standard method of directly passing the user query to the created chain?

girithodu avatar Dec 19 '23 08:12 girithodu

What led you to choose conversation.apredict instead of the standard method of directly passing the user query to the created chain?

Because apredict is asynchronous. In fact, you might also be able to call arun directly, IIRC. In the end, all methods explicitly make a Chain.__call__ call. I cannot say much about performance without any benchmarking, though...
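
For illustration, roughly the same generator with arun instead of apredict, reusing the conversation, handler, asyncio, and json from the snippet above (whether it behaves identically here is an assumption):

async def ask_question_async():
    # arun accepts the single input directly, since ConversationChain has one input key.
    asyncio.create_task(conversation.arun("Hello, how are you?"))
    async for chunk in handler.aiter():
        yield f"data: {json.dumps({'content': chunk, 'tokens': 0})}\n\n"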

usersina avatar Dec 19 '23 09:12 usersina

@usersina

How about doing this using a retrieval chain? I'm trying to, but getting errors.

JoAmps avatar Jan 31 '24 08:01 JoAmps

@usersina

How about doing this using a retrieval chain? I'm trying to, but getting errors.

JoAmps avatar Jan 31 '24 09:01 JoAmps

@JoAmps I'm not too sure without seeing any code, but I really recommend switching over to LCEL; there's so much you can customize and implement that way, especially as you move closer to production.
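
For anyone landing here, a minimal LCEL streaming sketch; the module paths and the Flask app/route are assumptions, and no threads or queues are needed since .stream() already returns a generator:

from flask import Flask, Response, request
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

app = Flask(__name__)

prompt = ChatPromptTemplate.from_template("Answer briefly: {question}")
# The pipe operator composes prompt -> model -> parser into a single Runnable.
chain = prompt | ChatOpenAI(streaming=True) | StrOutputParser()

@app.route("/lcel", methods=["POST"])
def lcel_stream():
    question = request.form["question"]
    # chain.stream(...) yields incremental string chunks that Flask can forward directly.
    return Response(chain.stream({"question": question}), mimetype="text/event-stream")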

usersina avatar Feb 06 '24 10:02 usersina

@usersina thanks for providing your code. I've tried what you recommended in your comment, and it works except I do not get the final output from the agent. I get the chain thought process returned in my Flask app, but it stops short of returning the final answer. What am I missing?

streaming.py

import sys
import queue
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

STOP_ITEM = "[END]"
"""
This is a special item that is used to signal the end of the stream.
"""

class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
    """
    This is a callback handler that yields the tokens as they are generated.
    For a usage example, see the :func:`generate` function below.
    """

    q: queue.Queue
    """
    The queue to write the tokens to as they are generated.
    """

    def __init__(self, q: queue.Queue) -> None:
        """
        Initialize the callback handler.
        q: The queue to write the tokens to as they are generated.
        """
        super().__init__()
        self.q = q

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        with self.q.mutex:
            self.q.queue.clear()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Writes to stdout
        sys.stdout.write(token)
        sys.stdout.flush()
        # Pass the token to the generator
        self.q.put(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        sys.stdout.write("THE END!!!")
        self.q.put(response.output)
        self.q.put(STOP_ITEM)

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        sys.stdout.write(f"LLM Error: {error}\n")
        self.q.put("%s: %s" % (type(error).__name__, str(error)))
        self.q.put(STOP_ITEM)
    
    def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Print out that we are entering a chain."""
        self.q.put("Entering the chain...\n\n")

    def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
        sys.stdout.write(f"Tool: {serialized['name']}\n")
        self.q.put(f"Tool: {serialized['name']}\n")
    
    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        sys.stdout.write(f"{action.log}\n")
        self.q.put(f"{action.log}\n")

def generate(rq: queue.Queue):
    """
    This is a generator that yields the items in the queue until it reaches the stop item.

    Usage example:
    ```
    def askQuestion(callback_fn: StreamingStdOutCallbackHandlerYield):
        llm = OpenAI(streaming=True, callbacks=[callback_fn])
        return llm(prompt="Write a poem about a tree.")

    @app.route("/", methods=["GET"])
    def generate_output():
        q = Queue()
        callback_fn = StreamingStdOutCallbackHandlerYield(q)
        threading.Thread(target=askQuestion, args=(callback_fn,)).start()
        return Response(generate(q), mimetype="text/event-stream")
    ```
    """
    while True:
        result: str = rq.get()
        if result == STOP_ITEM or result is None:
            break
        yield result
            

routes.py

@app.route('/chat', methods=['POST'])
@auth.secured()
def chat():
    message = request.json['messages']
    
    chat_message_history = CustomChatMessageHistory(
    session_id=session['conversation_id'], connection_string="sqlite:///chat_history.db"
    )
    
    q = Queue()
    callback_fn = StreamingStdOutCallbackHandlerYield(q)

    def ask_question(callback_fn: StreamingStdOutCallbackHandlerYield):
        
        # Callback manager
        cb_manager = CallbackManager(handlers=[callback_fn])
        
        ## SQLDbAgent is a custom Tool class created to Q&A over a MS SQL Database
        sql_search = SQLSearchAgent(llm=llm, k=30, callback_manager=cb_manager, return_direct=True)

        ## ChatGPTTool is a custom Tool class created to talk to ChatGPT knowledge
        chatgpt_search = ChatGPTTool(llm=llm, callback_manager=cb_manager, return_direct=True)
        tools = [sql_search, chatgpt_search]

        agent = ConversationalChatAgent.from_llm_and_tools(llm=llm, tools=tools, system_message=CUSTOM_CHATBOT_PREFIX, human_message=CUSTOM_CHATBOT_SUFFIX)
        memory = ConversationBufferWindowMemory(memory_key="chat_history", return_messages=True, k=10, chat_memory=chat_message_history)
        brain_agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, memory=memory, handle_parsing_errors=True, streaming=True)
        return brain_agent_executor.run(message['content'])
   
    threading.Thread(target=ask_question, args=(callback_fn,)).start()
    return Response(generate(q), mimetype="text/event-stream")

ElderBlade avatar Feb 20 '24 00:02 ElderBlade

@mmoore7 - there might have been a change to the stop condition, or the tool/train-of-thought end event is getting called. I cannot say for sure, since I have long since moved from Flask and classic LangChain to LangChain Expression Language and FastAPI for better streaming.

usersina avatar Feb 20 '24 21:02 usersina

LangServe has a number of examples that get streaming working out of the box with FastAPI.

https://github.com/langchain-ai/langserve/tree/main?tab=readme-ov-file#examples

We strongly recommend using LCEL and, depending on what you're doing, either the astream API or the astream_events API.

I am marking this issue as closed, as there are enough examples and documentation for folks to solve this without much difficulty.

LangServe will provide streaming that will be available to the RemoteRunnable js client in just a few lines of code!
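
As a rough illustration of the astream API mentioned above (the chain object, an LCEL runnable, is an assumption):

async def stream_answer(question: str):
    # astream yields chunks asynchronously, which suits async frameworks such as FastAPI.
    async for chunk in chain.astream({"question": question}):
        yield chunk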

eyurtsev avatar Mar 07 '24 21:03 eyurtsev