
Default rag-conversation package does not generate follow-up answers in "chat" playground_type.


Hi all,

I've taken the default rag-conversation example (see source code here) and modified the retriever slightly to use Azure AI Search. The vector store contains a synthetic dataset with data about disturbances in a production factory.

The code is shown below.

server.py:

```python
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from rag_conversation import chain as rag_conversation_chain

app = FastAPI()


@app.get("/")
async def redirect_root_to_docs():
    return RedirectResponse("/docs")


add_routes(app, rag_conversation_chain, path="/rag-conversation", playground_type="chat")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
```

chain.py:

```python
import os
from operator import itemgetter
from typing import List, Tuple

from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    format_document,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)

from src.utils import AzureAISearchConfig, AzureOpenAIConfig

# Load configurations
azure_ais_conf = AzureAISearchConfig.from_yaml("/some/config/location")
azure_oai_conf = AzureOpenAIConfig.from_yaml("/some/config/location")

embeddings = AzureOpenAIEmbeddings(
    azure_deployment=azure_oai_conf.embedding_model,
    openai_api_version=azure_oai_conf.api_version,
    azure_endpoint=azure_oai_conf.endpoint,
)
llm = AzureChatOpenAI(
    azure_deployment=azure_oai_conf.chat_model,
    openai_api_version=azure_oai_conf.api_version,
    azure_endpoint=azure_oai_conf.endpoint,
)
vectorstore = AzureSearch(
    azure_search_endpoint=azure_ais_conf.endpoint,
    azure_search_key=os.environ["AZURE_SEARCH_KEY"],
    index_name="langchain-vector-dummy",
    embedding_function=embeddings.embed_query,
)
retriever = vectorstore.as_retriever()

# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# RAG answer synthesis prompt
template = """Answer the question based only on the following context:
<context>
{context}
</context>"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{question}"),
    ]
)

# Conversational Retrieval Chain
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer


# User input
class ChatHistory(BaseModel):
    chat_history: List[Tuple[str, str]] = Field(..., extra={"widget": {"type": "chat"}})
    question: str


_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(chat_history=lambda x: _format_chat_history(x["chat_history"]))
        | CONDENSE_QUESTION_PROMPT
        | llm
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(itemgetter("question")),
)

_inputs = RunnableParallel(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: _format_chat_history(x["chat_history"]),
        "context": _search_query | retriever | _combine_documents,
    }
).with_types(input_type=ChatHistory)

chain = _inputs | ANSWER_PROMPT | llm | StrOutputParser()
```
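
For reference, this is the input shape the chain expects when I test it directly in Python (a minimal sketch; the question and answer strings are made up and only illustrate the `(human, ai)` tuple format of `chat_history`):

```python
# Direct invocation of the chain, bypassing LangServe, to check the
# ChatHistory input format (hypothetical questions for illustration).
result = chain.invoke(
    {
        "question": "Which machine was involved in that disturbance?",
        "chat_history": [
            (
                "What was the longest disturbance last week?",
                "The longest disturbance was a 4-hour stop on line 2.",
            ),
        ],
    }
)
print(result)
```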

When I run the code with `playground_type="default"` and define the chat history as shown, I get the following output: [screenshot: default playground]

When I run the code with `playground_type="chat"`, I get no output after the first question: [screenshot: chat playground]
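
As a cross-check outside the playground UI, the route can also be called directly (a sketch, assuming the server is running locally on port 8000; the question and history are again made up):

```python
from langserve import RemoteRunnable

# Call the /rag-conversation route directly, sending the
# ChatHistory-shaped input the chain declares.
rag = RemoteRunnable("http://localhost:8000/rag-conversation")
answer = rag.invoke(
    {
        "question": "How long did that disturbance last?",
        "chat_history": [
            ("What was the most recent disturbance?", "A conveyor belt stopped on line 2."),
        ],
    }
)
print(answer)
```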

To analyse this further, I inspected the network calls, and they look slightly different. The network calls for the default playground: [screenshot]. The network calls for the chat playground: [screenshot].
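
If it helps, the raw input that LangServe hands to the chain can also be logged for each playground by prepending a pass-through lambda (a debugging sketch, not part of the template):

```python
from langchain_core.runnables import RunnableLambda


def _log_input(x):
    # Print the raw payload the chain receives so the two playgrounds
    # can also be compared on the server side.
    print("chain input:", x)
    return x


# Keep the declared input type so the playgrounds still render a form.
debug_chain = (RunnableLambda(_log_input) | chain).with_types(input_type=ChatHistory)
```

Serving `debug_chain` instead of `chain` via `add_routes` should then show whether the chat playground delivers the input in the shape `ChatHistory` expects.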

What could be going on here?

jo3p · Apr 03 '24, 14:04