
Filtering retrieval with ConversationalRetrievalChain

Open · jorrgme opened this issue 1 year ago • 1 comment

Hi everyone,

I'm trying to do something, and I haven't found enough information online to make it work properly with LangChain. Here it is:

I want to build a QA chat that uses markdown documents as its knowledge source and retrieves only the documents for the documentation version the user picks in a select box. To achieve that:

  1. I've built a FAISS vector store from documents located in two different folders, representing the documentation's versions. The folder structure looks like this:
.
├── 4.14.2
│   ├── folder1
│   │   └── file1.md
│   ├── folder2
│   │   └── file2.md
└── 4.18.1
    ├── folder1
    │   └── file3.md
    └── folder2
        └── file4.md
  2. Each document's metadata looks something like this: {'source': 'app/docs-versions/4.14.2/folder1/file1.md'}
  3. With all this, I'm using a ConversationalRetrievalChain to retrieve information from the vector store and an LLM to answer the questions entered in the prompt:
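import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

memory = st.session_state.memory = ConversationBufferMemory(
    memory_key="chat_history", return_messages=True, output_key="answer"
)

# Intended to restrict retrieval to one version folder, e.g. 'app/docs-versions/4.14.2/'
source_filter = f'app/docs-versions/{version}/'
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=store.as_retriever(
        search_kwargs={'filter': {'source': source_filter}}
    ),
    memory=memory,
    verbose=False,
    return_source_documents=True,
)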

In short, my goal is to restrict retrieval to the documents inside a given directory, which represents the chosen documentation version.

Does anyone know how I can achieve this? The approach I've tried doesn't work: the retrieved documents come from both folders.

jorrgme · Jul 10 '23
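A minimal pure-Python sketch of why a folder path does not work as a metadata filter value, assuming the vector store applies the filter as an exact match on metadata values (as far as I know, that is how LangChain's FAISS wrapper treats it). A folder prefix never equals a full file path, whereas a startswith check selects exactly the files under that folder. The helpers exact_match and prefix_match are hypothetical.

```python
# Hypothetical helpers illustrating exact-match vs. prefix filtering on metadata.
from typing import Dict, List

docs: List[Dict[str, str]] = [
    {"source": "app/docs-versions/4.14.2/folder1/file1.md"},
    {"source": "app/docs-versions/4.14.2/folder2/file2.md"},
    {"source": "app/docs-versions/4.18.1/folder1/file3.md"},
    {"source": "app/docs-versions/4.18.1/folder2/file4.md"},
]


def exact_match(metadata: Dict[str, str], flt: Dict[str, str]) -> bool:
    # Every filter key must equal the metadata value exactly.
    return all(metadata.get(key) == value for key, value in flt.items())


def prefix_match(metadata: Dict[str, str], key: str, prefix: str) -> bool:
    # True when the metadata value starts with the given prefix.
    return metadata.get(key, "").startswith(prefix)


source_filter = "app/docs-versions/4.14.2/"
exact = [d for d in docs if exact_match(d, {"source": source_filter})]
prefixed = [d for d in docs if prefix_match(d, "source", source_filter)]
print(len(exact), len(prefixed))  # 0 2
```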

Answer generated by a 🤖

I understand that you're trying to develop a QA chat system using markdown documents as a knowledge source. You want to filter the documents retrieved based on the version of the documentation chosen by the user. However, you're having trouble achieving this with the ConversationalRetrievalChain class in LangChain.

From what I can see, ConversationalRetrievalChain doesn't filter documents itself; it uses whatever its retriever returns. The filter you pass through search_kwargs is forwarded to the vector store, and LangChain's FAISS wrapper, as far as I know, compares metadata values for equality rather than by prefix, so 'app/docs-versions/4.14.2/' will not match a source such as 'app/docs-versions/4.14.2/folder1/file1.md'.

However, I can suggest a workaround. You can create a custom retriever that wraps around the original retriever and applies the filtering. Here's how you can do it:

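from typing import List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import BaseRetriever, Document

class FilteredRetriever(BaseRetriever):
    # Subclassing BaseRetriever lets ConversationalRetrievalChain accept the wrapper.
    retriever: BaseRetriever  # wrapped retriever, e.g. store.as_retriever()
    filter_prefix: str  # keep docs whose metadata['source'] starts with this

    def _keep(self, docs: List[Document]) -> List[Document]:
        return [doc for doc in docs if doc.metadata["source"].startswith(self.filter_prefix)]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return self._keep(self.retriever.get_relevant_documents(query))

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        return self._keep(await self.retriever.aget_relevant_documents(query))

filtered_retriever = FilteredRetriever(retriever=store.as_retriever(), filter_prefix=source_filter)

chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=filtered_retriever,
    memory=memory,
    verbose=False,
    return_source_documents=True,
)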

In this code, FilteredRetriever is a simple wrapper that delegates the retrieval to the original retriever, and then filters the results based on the source path. You can use this FilteredRetriever in place of the original retriever when creating the ConversationalRetrievalChain.
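Stripped of LangChain, the wrapper is plain delegation plus a filter. The sketch below only illustrates that pattern: Doc stands in for LangChain's Document, and InMemoryRetriever is a hypothetical inner retriever that does naive keyword matching instead of vector search. With LangChain itself, the wrapper must also subclass BaseRetriever.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Doc:
    # Stand-in for a LangChain Document: text plus a metadata dict.
    page_content: str
    metadata: Dict[str, str] = field(default_factory=dict)


class InMemoryRetriever:
    # Hypothetical inner retriever: returns docs whose text contains the query.
    def __init__(self, docs: List[Doc]) -> None:
        self.docs = docs

    def get_relevant_documents(self, query: str) -> List[Doc]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]


class FilteredRetriever:
    # Delegates retrieval to the wrapped retriever, then keeps docs under filter_prefix.
    def __init__(self, retriever: InMemoryRetriever, filter_prefix: str) -> None:
        self.retriever = retriever
        self.filter_prefix = filter_prefix

    def get_relevant_documents(self, query: str) -> List[Doc]:
        results = self.retriever.get_relevant_documents(query)
        return [d for d in results if d.metadata.get("source", "").startswith(self.filter_prefix)]


inner = InMemoryRetriever([
    Doc("Install guide", {"source": "app/docs-versions/4.14.2/folder1/file1.md"}),
    Doc("Install guide", {"source": "app/docs-versions/4.18.1/folder1/file3.md"}),
])
retriever = FilteredRetriever(inner, "app/docs-versions/4.18.1/")
print([d.metadata["source"] for d in retriever.get_relevant_documents("install")])
# ['app/docs-versions/4.18.1/folder1/file3.md']
```

Because the wrapper only calls get_relevant_documents on the inner object, the same idea applies to any retriever, regardless of which vector store backs it.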

Please note that this is a workaround and might not be the most efficient solution if you have a large number of documents. The wrapped retriever still returns its top k matches (4 by default for FAISS) across all versions, and only then are the documents outside the chosen folder discarded, so you may end up with fewer than k documents, or none. Filtering at the retrieval stage instead would require the vector store itself to support the kind of filter you need.
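A common mitigation for the "fewer than k documents" problem is to over-fetch candidates and truncate after filtering. The sketch below uses a hypothetical, pre-sorted list of (source, score) pairs in place of a real similarity search; with LangChain, the equivalent would be raising k in the inner retriever's search_kwargs and trimming the filtered list.

```python
from typing import List, Tuple

# Hypothetical (source, similarity) pairs, sorted best-first, standing in for
# what a vector store search would return (several chunks can share a source).
ranked: List[Tuple[str, float]] = [
    ("app/docs-versions/4.18.1/folder1/file3.md", 0.95),
    ("app/docs-versions/4.18.1/folder2/file4.md", 0.93),
    ("app/docs-versions/4.18.1/folder1/file3.md", 0.91),
    ("app/docs-versions/4.14.2/folder1/file1.md", 0.90),
    ("app/docs-versions/4.14.2/folder2/file2.md", 0.85),
]


def search(k: int) -> List[Tuple[str, float]]:
    # Top-k search over the toy ranking.
    return ranked[:k]


def filtered_search(prefix: str, k: int, fetch_multiplier: int = 1) -> List[Tuple[str, float]]:
    # Fetch k * fetch_multiplier candidates, keep those under prefix, return at most k.
    candidates = search(k * fetch_multiplier)
    return [c for c in candidates if c[0].startswith(prefix)][:k]


prefix = "app/docs-versions/4.14.2/"
print(len(filtered_search(prefix, k=2)))                      # 0: top 2 are both 4.18.1
print(len(filtered_search(prefix, k=2, fetch_multiplier=3)))  # 2
```

Over-fetching trades a larger search for a better chance of filling k slots; it still gives no guarantee when the chosen version is a small fraction of the index.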

If you find this solution helpful and believe it could benefit other users, I encourage you to make a pull request to update the LangChain documentation. This way, others facing the same issue can also benefit from your experience.

This response is meant to be useful, save you time, and share context. It is not meant to be a precise solution, but rather a starting point for your own research.


dosubot[bot] · Jul 10 '23

Just in case someone faces the same problem, here is the solution I found.

I extended the VectorStoreRetriever class like this:

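from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever


class FilteredRetriever(VectorStoreRetriever):
    # vectorstore, search_type and search_kwargs are inherited from VectorStoreRetriever.
    filter_prefix: str  # e.g. 'app/docs-versions/4.14.2/'

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # Run the normal vector search, then keep only docs from the chosen version folder.
        results = super()._get_relevant_documents(query, run_manager=run_manager)
        return [doc for doc in results if doc.metadata['source'].startswith(self.filter_prefix)]

And using it like this:

filtered_retriever = FilteredRetriever(vectorstore=store, filter_prefix=source_filter)

chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=filtered_retriever,
    memory=st.session_state.memory,
    verbose=True,
    return_source_documents=True,
)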

I hope this helps anyone who runs into the same question.

LangChain doesn't seem to provide a way to do this directly.

jorrgme · Jul 11 '23
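An alternative that avoids post-filtering, sketched under assumptions: store the version as its own metadata key when loading the documents, so that an exact-match filter such as search_kwargs={'filter': {'version': version}} can select it, provided your vector store honors exact-match metadata filters (LangChain's FAISS wrapper does, as far as I know). The helpers below are hypothetical and assume the app/docs-versions/<version>/... layout from the question.

```python
from typing import Dict, Optional

DOCS_ROOT = "app/docs-versions/"  # layout assumed from the question


def version_from_source(source: str, root: str = DOCS_ROOT) -> Optional[str]:
    # Return the first path segment after root (e.g. '4.14.2'), or None if outside root.
    if not source.startswith(root):
        return None
    rest = source[len(root):]
    return rest.split("/", 1)[0] or None


def add_version(metadata: Dict[str, str]) -> Dict[str, str]:
    # Copy metadata and add a 'version' key derived from 'source', when one can be found.
    version = version_from_source(metadata.get("source", ""))
    return {**metadata, "version": version} if version else dict(metadata)


print(add_version({"source": "app/docs-versions/4.14.2/folder1/file1.md"}))
# {'source': 'app/docs-versions/4.14.2/folder1/file1.md', 'version': '4.14.2'}
```

You would apply add_version to each document's metadata before building the index; documents indexed earlier without the key would need to be re-indexed.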

How can we achieve the same thing when there are multiple filter conditions?

samanta-souhardya · Jul 20 '23
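One way to generalize the single prefix check to several conditions is to keep a list of predicates over a document's metadata and accept a document only if all of them hold (swap all for any to get OR semantics). This is a plain-Python sketch; the startswith and equals helpers and the 'lang' metadata key are hypothetical.

```python
from typing import Callable, Dict, List

Metadata = Dict[str, str]
Predicate = Callable[[Metadata], bool]


def startswith(key: str, prefix: str) -> Predicate:
    # Predicate: metadata[key] starts with prefix.
    return lambda md: md.get(key, "").startswith(prefix)


def equals(key: str, value: str) -> Predicate:
    # Predicate: metadata[key] equals value.
    return lambda md: md.get(key) == value


def filter_docs(docs: List[Metadata], predicates: List[Predicate]) -> List[Metadata]:
    # Keep documents that satisfy every predicate (AND semantics).
    return [md for md in docs if all(p(md) for p in predicates)]


docs = [
    {"source": "app/docs-versions/4.14.2/folder1/file1.md", "lang": "en"},
    {"source": "app/docs-versions/4.14.2/folder2/file2.md", "lang": "es"},
    {"source": "app/docs-versions/4.18.1/folder1/file3.md", "lang": "en"},
]
conditions = [startswith("source", "app/docs-versions/4.14.2/"), equals("lang", "en")]
print([md["source"] for md in filter_docs(docs, conditions)])
# ['app/docs-versions/4.14.2/folder1/file1.md']
```

Inside a FilteredRetriever, this would mean storing a list of predicates instead of filter_prefix and keeping each doc for which all(p(doc.metadata) for p in predicates) is true.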


The bot's FilteredRetriever approach gives us this error: instance of BaseRetriever expected (type=type_error.arbitrary_type; expected_arbitrary_type=BaseRetriever)

samanta-souhardya · Jul 20 '23
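As far as I can tell, that message comes from pydantic validating the chain's retriever field: it is declared as BaseRetriever and checked with isinstance, so a plain class that merely has the right methods (duck typing) is rejected. A minimal pure-Python illustration, where BaseRetriever, DuckRetriever, SubclassRetriever and validate_retriever are hypothetical stand-ins rather than LangChain's own classes:

```python
from abc import ABC, abstractmethod
from typing import List


class BaseRetriever(ABC):
    # Hypothetical stand-in for LangChain's BaseRetriever.
    @abstractmethod
    def get_relevant_documents(self, query: str) -> List[str]: ...


class DuckRetriever:
    # Has the right method but does not inherit from BaseRetriever.
    def get_relevant_documents(self, query: str) -> List[str]:
        return [query]


class SubclassRetriever(BaseRetriever):
    def get_relevant_documents(self, query: str) -> List[str]:
        return [query]


def validate_retriever(obj: object) -> BaseRetriever:
    # Mimics an arbitrary-type field check: only instances of the declared class pass.
    if not isinstance(obj, BaseRetriever):
        raise TypeError("instance of BaseRetriever expected")
    return obj


validate_retriever(SubclassRetriever())  # passes
try:
    validate_retriever(DuckRetriever())
except TypeError as err:
    print(err)  # instance of BaseRetriever expected
```

This is why both working variants in the thread inherit from a LangChain retriever class rather than wrapping one in a plain object.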


Hey, I tried jorrgme's FilteredRetriever method for my use case, but I'm facing some errors. Could you please share the dependencies you used? Here is my code:

def get_semantic_answer_lang_chain(self, question, chat_history):
    # filtered_retriever = FilteredRetriever(self.vector_store.as_retriever(),"[https://iplanrequirementsstr.blob.core.windows.net/documents/converted/20.pdf.txt]")
    question_generator = LLMChain(llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=False)
    doc_chain = load_qa_with_sources_chain(self.llm, chain_type="stuff", verbose=False, prompt=self.prompt)
    chain = ConversationalRetrievalChain(
        retriever=self.vector_store.as_retriever(),
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        return_source_documents=True,
        # top_k_docs_for_context= self.k
    )
    result = chain({"question": question, "chat_history": chat_history})
    sources = "\n".join(set(map(lambda x: x.metadata["source"], result['source_documents'])))

    container_sas = self.blob_client.get_container_sas()
    contextDict = {}
    for res in result['source_documents']:
        source_key = self.filter_sourcesLinks(res.metadata['source'].replace('_SAS_TOKEN_PLACEHOLDER_', container_sas)).replace('\n', '').replace(' ', '')
        if source_key not in contextDict:
            contextDict[source_key] = []
        myPageContent = self.clean_encoding(res.page_content)
        contextDict[source_key].append(myPageContent)

    result['answer'] = result['answer'].split('SOURCES:')[0].split('Sources:')[0].split('SOURCE:')[0].split('Source:')[0]
    result['answer'] = self.clean_encoding(result['answer'])
    sources = sources.replace('_SAS_TOKEN_PLACEHOLDER_', container_sas)
    sources = self.filter_sourcesLinks(sources)

    return question, result['answer'], contextDict, sources

This is the code; we are using Azure vector search in our use case. Can you please suggest a workaround for the Azure Search vector store retriever?

Gauravmahapatrocdm · Apr 15 '24