langserve
Default rag-conversation package does not generate follow-up answers in "chat" playground_type.
Hi all,
I've taken the default rag-conversation example (see source code here) and modified the retriever slightly to use Azure AI Search. The vector store contains a synthetic dataset with data about disturbances in a production factory.
The code is shown below.
server.py:
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from rag_conversation import chain as rag_conversation_chain
app = FastAPI()
@app.get("/")
async def redirect_root_to_docs():
    return RedirectResponse("/docs")

add_routes(app, rag_conversation_chain, path="/rag-conversation", playground_type="chat")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
chain.py:
import os
from operator import itemgetter
from typing import List, Tuple
from langchain_community.vectorstores.azuresearch import AzureSearch
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    format_document,
)
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from src.utils import AzureAISearchConfig, AzureOpenAIConfig
# Load configurations
azure_ais_conf = AzureAISearchConfig.from_yaml("/some/config/location")
azure_oai_conf = AzureOpenAIConfig.from_yaml("/some/config/location")
embeddings = AzureOpenAIEmbeddings(
    azure_deployment=azure_oai_conf.embedding_model,
    openai_api_version=azure_oai_conf.api_version,
    azure_endpoint=azure_oai_conf.endpoint,
)
llm = AzureChatOpenAI(
    azure_deployment=azure_oai_conf.chat_model,
    openai_api_version=azure_oai_conf.api_version,
    azure_endpoint=azure_oai_conf.endpoint,
)
vectorstore = AzureSearch(
    azure_search_endpoint=azure_ais_conf.endpoint,
    azure_search_key=os.environ["AZURE_SEARCH_KEY"],
    index_name="langchain-vector-dummy",
    embedding_function=embeddings.embed_query,
)
retriever = vectorstore.as_retriever()
# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:""" # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
# RAG answer synthesis prompt
template = """Answer the question based only on the following context:
<context>
{context}
</context>"""
ANSWER_PROMPT = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{question}"),
    ]
)
# Conversational Retrieval Chain
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer
# User input
class ChatHistory(BaseModel):
    chat_history: List[Tuple[str, str]] = Field(..., extra={"widget": {"type": "chat"}})
    question: str
_search_query = RunnableBranch(
    # If the input includes chat_history, condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),
        # Condense the follow-up question and chat history into a standalone question
        RunnablePassthrough.assign(chat_history=lambda x: _format_chat_history(x["chat_history"]))
        | CONDENSE_QUESTION_PROMPT
        | llm
        | StrOutputParser(),
    ),
    # Else, there is no chat history, so just pass through the question
    RunnableLambda(itemgetter("question")),
)
_inputs = RunnableParallel(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: _format_chat_history(x["chat_history"]),
        "context": _search_query | retriever | _combine_documents,
    }
).with_types(input_type=ChatHistory)
chain = _inputs | ANSWER_PROMPT | llm | StrOutputParser()
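For reference, the chain can also be exercised directly in Python with the ChatHistory input schema. This is only a minimal sketch; the question and history below are made-up placeholders, not taken from the actual dataset:

# Direct invocation of the chain, bypassing LangServe entirely
result = chain.invoke(
    {
        "question": "Which machine caused the most recent disturbance?",
        "chat_history": [
            ("What kinds of disturbances are recorded?", "Mostly conveyor stoppages and sensor faults."),
        ],
    }
)
print(result)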
When I run the code with playground_type="default" and define the chat history as follows, I get the following ouput:
When I run the code with playground_type="chat", I get no output after the first question.
To investigate further, I inspected the network calls, which look slightly different between the two playgrounds:
The network calls for the default playground look like this:
The network calls for the chat playground look like this:
What could be going on here?