Filtering retrieval with ConversationalRetrievalChain
Hi everyone,
I'm trying to do something, and I haven't found enough information online to make it work properly with LangChain. Here it is:
I want to develop a QA chat that uses markdown documents as its knowledge source, retrieving only the documents that belong to the documentation version the user chooses with a select box. To achieve that:
- I've built a FAISS vector store from documents located in two different folders, representing the documentation's versions. The folder structure looks like this:
.
├── 4.14.2
│   ├── folder1
│   │   └── file1.md
│   └── folder2
│       └── file2.md
└── 4.18.1
    ├── folder1
    │   └── file3.md
    └── folder2
        └── file4.md
- Each document's metadata looks something like this:
{'source': 'app/docs-versions/4.14.2/folder1/file1.md'}
- With all this, I'm using a ConversationalRetrievalChain to retrieve info from the vector store and an LLM to answer questions entered via the prompt:
memory = st.session_state.memory = ConversationBufferMemory(
    memory_key="chat_history", return_messages=True, output_key="answer"
)
source_filter = f'app/docs-versions/{version}/'
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=store.as_retriever(
        search_kwargs={'filter': {'source': source_filter}}
    ),
    memory=memory,
    verbose=False,
    return_source_documents=True,
)
In summary, my goal is to restrict retrieval to the documents contained in a given directory, which represents the documentation version.
Does anyone know how I can achieve this? The approach I've tried doesn't do what I want: the retrieved documents come from both folders.
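One likely cause (a sketch of assumed behaviour, not LangChain's actual internals): a metadata filter that tests exact equality will never match a path prefix, so a filter value like app/docs-versions/4.14.2/ never equals any full source path.

```python
# Sketch: an exact-equality metadata filter, as many vector stores apply it.
def equality_filter_match(metadata: dict, flt: dict) -> bool:
    # Every filter key must equal the document's metadata value exactly.
    return all(metadata.get(key) == value for key, value in flt.items())

meta = {'source': 'app/docs-versions/4.14.2/folder1/file1.md'}

equality_filter_match(meta, {'source': 'app/docs-versions/4.14.2/'})
# False: the prefix is not equal to the full path, so the filter never
# selects the intended documents and results leak across versions.
```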
Answer generated by a 🤖
Answer
I understand that you're trying to develop a QA chat system using markdown documents as a knowledge source. You want to filter the documents retrieved based on the version of the documentation chosen by the user. However, you're having trouble achieving this with the ConversationalRetrievalChain class in LangChain.
From what I can see, the ConversationalRetrievalChain class doesn't directly support filtering of documents based on their source path. The filter argument you're trying to use in search_kwargs isn't a supported feature of the as_retriever method or the underlying retrieval system.
However, I can suggest a workaround. You can create a custom retriever that wraps around the original retriever and applies the filtering. Here's how you can do it:
class FilteredRetriever:
    def __init__(self, retriever, filter_prefix):
        self.retriever = retriever
        self.filter_prefix = filter_prefix

    def retrieve(self, *args, **kwargs):
        results = self.retriever.retrieve(*args, **kwargs)
        return [doc for doc in results if doc.metadata['source'].startswith(self.filter_prefix)]
filtered_retriever = FilteredRetriever(store.as_retriever(), source_filter)
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=filtered_retriever,
    memory=memory,
    verbose=False,
    return_source_documents=True,
)
In this code, FilteredRetriever is a simple wrapper that delegates retrieval to the original retriever and then filters the results based on the source path. You can use this FilteredRetriever in place of the original retriever when creating the ConversationalRetrievalChain.
Please note that this is a workaround and might not be the most efficient solution if you have a large number of documents. It retrieves all documents and then filters them, instead of filtering at the retrieval stage. For a more efficient solution, you might need to modify the retrieval system itself to support filtering, which would require changes in the underlying code of LangChain.
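One more efficient direction (hedged: the helper below is mine, and FAISS filter support varies by LangChain version): store an explicit version key in each document's metadata at indexing time, so the retriever can filter on an exact match instead of a path prefix.

```python
# Hypothetical helper: derive a 'version' metadata key from the source path,
# e.g. 'app/docs-versions/4.14.2/folder1/file1.md' -> '4.14.2'.
def with_version(metadata: dict) -> dict:
    parts = metadata['source'].split('/')
    version = parts[parts.index('docs-versions') + 1]
    # Return a copy with the extra key; the original dict is untouched.
    return {**metadata, 'version': version}

# Exact-match filtering then becomes possible (assumed API shape, check your version):
# retriever = store.as_retriever(search_kwargs={'filter': {'version': '4.14.2'}})
```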
If you find this solution helpful and believe it could benefit other users, I encourage you to make a pull request to update the LangChain documentation. This way, others facing the same issue can also benefit from your experience.
Sources
This response is meant to be useful, save you time, and share context. It is not meant to be a precise solution, but rather a starting point for your own research.
Help me be more useful! Please leave a 👍 if this is helpful and 👎 if it is irrelevant.
Just in case someone faces the same problem, here is the solution I found:
Extending the VectorStoreRetriever class like this:
from typing import List
from pydantic import Field
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever

class FilteredRetriever(VectorStoreRetriever):
    vectorstore: VectorStoreRetriever
    search_type: str = "similarity"
    search_kwargs: dict = Field(default_factory=dict)
    filter_prefix: str

    def get_relevant_documents(self, query: str) -> List[Document]:
        results = self.vectorstore.get_relevant_documents(query=query)
        return [doc for doc in results if doc.metadata['source'].startswith(self.filter_prefix)]
And using it like this:
filtered_retriever = FilteredRetriever(vectorstore=store.as_retriever(), filter_prefix=source_filter)
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=filtered_retriever,
    memory=st.session_state.memory,
    verbose=True,
    return_source_documents=True,
)
I hope this helps anyone who runs into the same question.
LangChain doesn't seem to provide a way to do this directly.
How can I achieve the same thing when there are multiple filter conditions?
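Regarding multiple conditions, one way (a sketch with plain dicts standing in for Document metadata; the names are mine, not LangChain's) is to replace the single filter_prefix with a mapping from metadata key to predicate, and keep only documents that satisfy all of them:

```python
# conditions maps a metadata key to a predicate on that key's value;
# a document passes only if every predicate holds.
def matches_all(metadata: dict, conditions: dict) -> bool:
    return all(pred(metadata.get(key)) for key, pred in conditions.items())

conditions = {
    'source': lambda s: s is not None and s.startswith('app/docs-versions/4.14.2/'),
    'lang': lambda v: v == 'en',
}

docs = [
    {'source': 'app/docs-versions/4.14.2/folder1/file1.md', 'lang': 'en'},
    {'source': 'app/docs-versions/4.18.1/folder1/file3.md', 'lang': 'en'},
]
[d for d in docs if matches_all(d, conditions)]
# keeps only the 4.14.2 document
```

Inside get_relevant_documents, the list comprehension would then test matches_all(doc.metadata, conditions) instead of the single startswith check.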
This approach gives us the error: instance of BaseRetriever expected (type=type_error.arbitrary_type; expected_arbitrary_type=BaseRetriever)
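For context on that error (a rough analogue of the type check pydantic performs on the chain's retriever field, not LangChain's actual code): the retriever is validated with an isinstance check against BaseRetriever, so a plain duck-typed wrapper class is rejected, while subclassing the retriever base class, as in the solution posted above, satisfies the check.

```python
# Illustrative only: stand-in classes mimicking the validation failure.
class BaseRetriever:
    pass

def validate_retriever(value):
    # Analogous to pydantic's arbitrary-type check on the retriever field.
    if not isinstance(value, BaseRetriever):
        raise TypeError("instance of BaseRetriever expected")
    return value

class PlainWrapper:                        # duck-typed, but not a BaseRetriever
    pass

class SubclassedRetriever(BaseRetriever):  # passes the isinstance check
    pass
```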
Hey, I tried this method for my use case but I am facing certain errors. Can you please share the dependencies you used?
def get_semantic_answer_lang_chain(self, question, chat_history):
    # filtered_retriever = FilteredRetriever(self.vector_store.as_retriever(), "[https://iplanrequirementsstr.blob.core.windows.net/documents/converted/20.pdf.txt]")
    question_generator = LLMChain(llm=self.llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=False)
    doc_chain = load_qa_with_sources_chain(self.llm, chain_type="stuff", verbose=False, prompt=self.prompt)
    chain = ConversationalRetrievalChain(
        retriever=self.vector_store.as_retriever(),
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
        return_source_documents=True,
        # top_k_docs_for_context=self.k
    )
    result = chain({"question": question, "chat_history": chat_history})
    sources = "\n".join(set(map(lambda x: x.metadata["source"], result['source_documents'])))
    container_sas = self.blob_client.get_container_sas()
    contextDict = {}
    for res in result['source_documents']:
        source_key = self.filter_sourcesLinks(res.metadata['source'].replace('_SAS_TOKEN_PLACEHOLDER_', container_sas)).replace('\n', '').replace(' ', '')
        if source_key not in contextDict:
            contextDict[source_key] = []
        myPageContent = self.clean_encoding(res.page_content)
        contextDict[source_key].append(myPageContent)
    result['answer'] = result['answer'].split('SOURCES:')[0].split('Sources:')[0].split('SOURCE:')[0].split('Source:')[0]
    result['answer'] = self.clean_encoding(result['answer'])
    sources = sources.replace('_SAS_TOKEN_PLACEHOLDER_', container_sas)
    sources = self.filter_sourcesLinks(sources)
    return question, result['answer'], contextDict, sources
This is the code, and we are using Azure vector search in our use case. Can you please suggest a workaround for the Azure Search vector store retriever?
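For Azure Cognitive Search, a different route may be available: the service itself supports OData filter expressions, so (assuming your LangChain version's AzureSearch retriever accepts a filters entry in search_kwargs, which you should verify) the filter can be pushed down to the service instead of post-filtering in Python. A hedged sketch of building such an expression:

```python
# Hypothetical helper: build an OData filter string from field/value pairs,
# combined with 'and'. The field names must exist as filterable fields in
# your Azure Search index (assumption).
def build_odata_filter(conditions: dict) -> str:
    return " and ".join(f"{field} eq '{value}'" for field, value in conditions.items())

build_odata_filter({'version': '4.14.2'})
# "version eq '4.14.2'"

# Assumed usage shape (check against your LangChain version):
# retriever = vector_store.as_retriever(
#     search_kwargs={'filters': build_odata_filter({'version': '4.14.2'})}
# )
```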