generative_ai_with_langchain

chat_with_retrieval: errors occur when use_flare = True or use_moderation = True

mrchaos opened this issue 1 year ago • 1 comment

In chat_with_retrieval/app.py, an error occurs when use_flare = True, and another error occurs when use_moderation = True.

chat_with_documents.py: FlareChain and ConversationalRetrievalChain use different input and output keys. FlareChain uses user_input and response.

if use_flare:
    params = {
        "user_input": user_query,
    }
else:
    params = {
        "question": user_query,
        "chat_history": MEMORY.chat_memory.messages,
    }
response = CONV_CHAIN.run(params, callbacks=[stream_handler])

utils.py: In init_memory(), output_key='answer' should be output_key='response' when FlareChain is used.

return ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
        output_key='answer'
    )
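
For example, one way to fix this (just a sketch, not a tested patch) would be to let init_memory() take the output key as a parameter, so the caller can pass 'response' when FlareChain is used:

from langchain.memory import ConversationBufferMemory


def init_memory(output_key: str = 'answer'):
    """Initialize the memory for contextual conversation.

    The output key is configurable so that FlareChain can use 'response'
    while ConversationalRetrievalChain keeps 'answer'.
    """
    return ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
        output_key=output_key
    )


# e.g. MEMORY = init_memory(output_key='response') when use_flare is True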

mrchaos avatar Jan 03 '24 16:01 mrchaos

Hey @mrchaos! Thanks for finding that, that's cool! Do you feel like creating another PR?

benman1 avatar Jan 03 '24 18:01 benman1

Hi,

I've been fighting with this issue for a while (along with several more that appeared as I fixed the initial ones) and finally managed to code a fix. It is implemented with a later version of LangChain (I upgraded it during the course to avoid issues with OpenAI's deprecated models). I also made some changes to avoid deprecation warnings. I leave my changes below in case they help anyone.

The only issue I faced is with the moderation chain, which seems to have been deprecated by OpenAI, and there is no straightforward way to upgrade it in a Windows environment. The rest should run just fine. In terms of performance, the FLARE chain works much worse for me, as it tends to hallucinate more and be more "confused" about the documents in its context.
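
If anyone needs moderation in the meantime, a possible workaround (a rough sketch I have not wired into the SequentialChain; it assumes openai>=1.x and an OPENAI_API_KEY in the environment) is to call OpenAI's moderation endpoint directly instead of going through OpenAIModerationChain:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def is_flagged(text: str) -> bool:
    """Return True if OpenAI's moderation endpoint flags the text."""
    result = client.moderations.create(input=text)
    return result.results[0].flagged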

I can do a PR if you like, but I have not tested this with earlier versions of LangChain. These are the versions I used:

  • langchain==0.1.10
  • langchain-community==0.0.25
  • langchain-core==0.1.28
  • langchain-decorators==0.5.4
  • langchain-experimental==0.0.53
  • langchain-openai==0.0.8
  • langchain-text-splitters==0.0.1

Main changes:

In utils.py:

  • Upgrade library imports to avoid deprecation warnings
  • Add a .lower() for the extension of documents, so we do not get errors with files named in caps (e.g. "FILE.PDF")
"""Utility functions and constants.

I am having some problems caching the memory and the retrieval. When
I decorate for caching, I get streamlit init errors.
"""
import logging
import pathlib
from typing import Any

from langchain.memory import ConversationBufferMemory
from langchain.schema import Document
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredWordDocumentLoader,
)


def init_memory():
    """Initialize the memory for contextual conversation.

    We are caching this, so it won't be deleted
    every time we restart the server.
    """
    return ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
        output_key='answer'
    )

MEMORY = init_memory()


class EpubReader(UnstructuredEPubLoader):
    def __init__(self, file_path: str | list[str], **unstructured_kwargs: Any):
        super().__init__(file_path, **unstructured_kwargs, mode="elements", strategy="fast")


class DocumentLoaderException(Exception):
    pass


class DocumentLoader(object):
    """Loads in a document with a supported extension."""
    supported_extensions = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".epub": EpubReader,
        ".docx": UnstructuredWordDocumentLoader,
        ".doc": UnstructuredWordDocumentLoader,
    }


def load_document(temp_filepath: str) -> list[Document]:
    """Load a file and return it as a list of documents.

    Doesn't handle a lot of errors at the moment.
    """
    ext = pathlib.Path(temp_filepath).suffix.lower()
    loader = DocumentLoader.supported_extensions.get(ext)
    if not loader:
        raise DocumentLoaderException(
            f"Invalid extension type {ext}, cannot load this type of file"
        )

    loaded = loader(temp_filepath)
    docs = loaded.load()
    logging.info(docs)
    return docs

In chat_with_documents.py:

  • Upgrade libraries to avoid deprecation warnings
  • Replace SimpleSequentialChain with SequentialChain due to errors about the number of inputs required by the former
  • Parameterize the output_key and input variables so that they are aligned with the chain (FLARE vs normal)
"""Chat with retrieval and embeddings."""
import logging
import os
import tempfile

from langchain.chains import (
    ConversationalRetrievalChain,
    FlareChain,
    OpenAIModerationChain,
    SequentialChain,
)
from langchain.chains.base import Chain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings

from chat_with_retrieval.utils import MEMORY, load_document
from config import set_environment

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()
set_environment()

# Setup LLM and QA chain; set temperature low to keep hallucinations in check
LLM = ChatOpenAI(
    model_name="gpt-3.5-turbo", temperature=0, streaming=True
)


def configure_retriever(
        docs: list[Document],
        use_compression: bool = False
) -> BaseRetriever:
    """Retriever to use."""
    # Split each document into chunks:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create embeddings and store in vectordb:
    embeddings = OpenAIEmbeddings()
    # alternatively: HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # Create vectordb with single call to embedding model for texts:
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
    retriever = vectordb.as_retriever(
        search_type="mmr", search_kwargs={
            "k": 5,
            "fetch_k": 7,
            "include_metadata": True
        },
    )
    if not use_compression:
        return retriever

    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=0.2
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=retriever,
    )


def configure_chain(retriever: BaseRetriever, use_flare: bool = True) -> Chain:
    """Configure chain with a retriever.

    Passing in a max_tokens_limit amount automatically
    truncates the tokens when prompting your llm!
    """
    output_key = 'response' if use_flare else 'answer'
    MEMORY.output_key = output_key
    params = dict(
        llm=LLM,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
        max_tokens_limit=4000,
    )
    if use_flare:
        # different set of parameters and init
        # unfortunately, have to use "protected" class
        return FlareChain.from_llm(
            **params
        )
    return ConversationalRetrievalChain.from_llm(
        **params
    )


def configure_retrieval_chain(
        uploaded_files,
        use_compression: bool = False,
        use_flare: bool = False,
        use_moderation: bool = False
) -> Chain:
    """Read documents, configure retriever, and the chain."""
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        docs.extend(load_document(temp_filepath))

    retriever = configure_retriever(docs=docs, use_compression=use_compression)
    chain = configure_chain(retriever=retriever, use_flare=use_flare)
    if not use_moderation:
        return chain

    input_variables = ["user_input"] if use_flare else ["chat_history", "question"]
    moderation_input = "response" if use_flare else "answer"
    moderation_chain = OpenAIModerationChain(input_key=moderation_input)
    return SequentialChain(chains=[chain, moderation_chain], 
                           input_variables=input_variables)

 

In app.py:

  • Parameterize the params provided to the chain depending on use_flare
  • Change .run() to .invoke() (and the respective inputs) to avoid deprecation warnings
  • Extract the actual answer from the response of the chain (.invoke() returns an object with question, context and answer)
"""Document loading functionality.

Run like this:
> PYTHONPATH=. streamlit run chat_with_retrieval/chat_with_documents.py
"""
import logging

import streamlit as st
from streamlit.external.langchain import StreamlitCallbackHandler

from chat_with_retrieval.chat_with_documents import configure_retrieval_chain
from chat_with_retrieval.utils import MEMORY, DocumentLoader

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()

st.set_page_config(page_title="LangChain: Chat with Documents", page_icon="🦜")
st.title("🦜 LangChain: Chat with Documents")


uploaded_files = st.sidebar.file_uploader(
    label="Upload files",
    type=list(DocumentLoader.supported_extensions.keys()),
    accept_multiple_files=True
)
if not uploaded_files:
    st.info("Please upload documents to continue.")
    st.stop()

# optional features (all off by default):
use_compression = st.checkbox("compression", value=False)
use_flare = st.checkbox("flare", value=False)
use_moderation = st.checkbox("moderation", value=False)

CONV_CHAIN = configure_retrieval_chain(
    uploaded_files,
    use_compression=use_compression,
    use_flare=use_flare,
    use_moderation=use_moderation
)

if st.sidebar.button("Clear message history"):
    MEMORY.chat_memory.clear()

avatars = {"human": "user", "ai": "assistant"}

if len(MEMORY.chat_memory.messages) == 0:
    st.chat_message("assistant").markdown("Ask me anything!")

for msg in MEMORY.chat_memory.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)

assistant = st.chat_message("assistant")
if user_query := st.chat_input(placeholder="Give me 3 keywords for what you have right now"):
    st.chat_message("user").write(user_query)
    container = st.empty()
    stream_handler = StreamlitCallbackHandler(container)
    with st.chat_message("assistant"):
        params = {
            "question": user_query,
            "chat_history": MEMORY.chat_memory.messages
        }
        if use_flare:
            params = {"user_input": user_query}
        config = {'callbacks': [stream_handler]}
        response = CONV_CHAIN.invoke(input=params, config=config)
        output_key = 'response' if use_flare else 'answer'

        # Display the response from the chatbot
        if response:
            container.markdown(response[output_key])

CSalle avatar Mar 13 '24 13:03 CSalle

@CSalle that's really cool! (Sorry for only getting back to you now.) I'll test this on langchain 0.0.284 soon.

I've started another branch that's called softupdate and runs on a newer version of LangChain (0.1.13). I've only updated a few notebooks so far (chapter 3), but if you like you could create a PR against that branch.

benman1 avatar Apr 04 '24 15:04 benman1

Done!

CSalle avatar Apr 04 '24 18:04 CSalle

Thanks, @CSalle and @mrchaos! Closing this now.

benman1 avatar Apr 10 '24 12:04 benman1

@mrchaos and @CSalle it took me a while, but I've finally pushed these changes to main, starting with 2b886bf029f9a64486b6b461de208ef5564d227f. It's also on the softupdate branch. Thanks, both of you, that was really helpful! Please let me know if anything else needs to change.

benman1 avatar Apr 11 '24 08:04 benman1