RetrievalQA and AzureOpenAIEmbeddings lead to TypeError: cannot pickle '_thread.lock' object
System Info
python = "^3.10"
openai = "^1.3.8"
langchain = "^0.0.349"
Who can help?
No response
Information
- [ ] The official example notebooks/scripts
- [ ] My own modified scripts
Related Components
- [ ] LLMs/Chat Models
- [X] Embedding Models
- [ ] Prompts / Prompt Templates / Prompt Selectors
- [ ] Output Parsers
- [ ] Document Loaders
- [ ] Vector Stores / Retrievers
- [ ] Memory
- [ ] Agents / Agent Executors
- [ ] Tools / Toolkits
- [ ] Chains
- [ ] Callbacks/Tracing
- [ ] Async
Reproduction
import chromadb
from langchain.embeddings import AzureOpenAIEmbeddings
from langchain.vectorstores.chroma import Chroma

client = chromadb.HttpClient(
    host=CHROMA_SERVER_HOST,
    port=CHROMA_SERVER_HTTP_PORT,
)

embeddings = AzureOpenAIEmbeddings(
    openai_api_type=AZURE_OPENAI_API_TYPE,
    azure_endpoint=AZURE_OPENAI_API_BASE,
    api_key=AZURE_OPENAI_API_KEY,
    openai_api_version=AZURE_OPENAI_API_VERSION,
    azure_deployment=AZURE_EMBEDDING_DEPLOYMENT_NAME,
)

vectordb = Chroma(
    client=client,
    collection_name=CHROMA_COLLECTION_NAME_FBIG_1000,
    embedding_function=embeddings,
)
Expected behavior
TypeError: cannot pickle '_thread.RLock' object
When I use openai == 0.28.1, the above error does not occur.
🤖
Hello @duongvinhduc! 🖐️
I'm Dosu, a friendly bot that's here to help you navigate the world of LangChain. From bug hunting to answering questions and even becoming a contributor, I'm here to assist you. While we wait for a human maintainer to join us, feel free to reach out to me. Let's make your experience with LangChain as smooth as possible! 😊
Based on the information you've provided, it seems like you're encountering a TypeError: cannot pickle '_thread.RLock' object when using RetrievalQA and AzureOpenAIEmbeddings in LangChain. This error occurs when you try to pickle (serialize) an object that is not serializable. In Python, not every object can be pickled: objects such as locks, file handles, and database connections cannot be serialized.
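The failure is easy to reproduce in isolation. A minimal sketch, independent of LangChain, showing that a bare lock is not picklable:

import pickle
import threading

# Any object holding a lock fails the same way when pickled:
pickle.dumps(threading.RLock())  # TypeError: cannot pickle '_thread.RLock' object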
In your case, the error is likely occurring because you're trying to pickle the AzureOpenAIEmbeddings object, which includes a client object that is not serializable. The client object is created in the validate_environment method:
values["client"] = openai.AzureOpenAI(**client_params).embeddings
values["async_client"] = openai.AsyncAzureOpenAI(**client_params).embeddings
These client objects are instances of the openai.AzureOpenAI and openai.AsyncAzureOpenAI classes, which likely contain non-serializable objects (like locks or open network connections).
To resolve this issue, you might need to refactor your code so that the AzureOpenAIEmbeddings object is never pickled, or remove the client objects before pickling. If you need the client objects after unpickling, you can reinitialize them.
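As a rough illustration of that second option (a sketch only, not a LangChain API; the client and async_client attribute names come from the snippet above):

import pickle

# Sketch: drop the non-picklable client attributes before serializing.
state = embeddings.__dict__.copy()
state.pop("client", None)
state.pop("async_client", None)

with open("embeddings_state.pkl", "wb") as f:
    pickle.dump(state, f)

# After loading, construct a fresh AzureOpenAIEmbeddings from your saved
# parameters instead of restoring the raw client objects.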
There was a similar issue reported in the LangChain repository: RetrievalQA and HuggingFacePipeline lead to TypeError: cannot pickle '_thread.lock' object. The temporary fix suggested was to change a line of code in the text_generation.py file of the transformers library. However, this change might have unknown downsides. Another user confirmed that the issue seems to be resolved in version transformers==4.31.0.
I hope this helps! If you have any other questions or need further clarification, feel free to ask.
Best, Dosu
Sources
- RetrievalQA and HuggingFacePipeline lead to TypeError: cannot pickle '_thread.lock' object
- libs/langchain/langchain/embeddings/azure_openai.py
This response is meant to be useful and save you time. It is not meant to be a precise solution, but rather a starting point for your own research.
Help me be more useful! Please leave a 👍 if this is helpful and 👎 if it is irrelevant.
If you want to continue the conversation, start your reply with @dosu-bot.
@duongvinhduc Can you please use a code block to format the Python code you shared? That would be helpful.
@umair313 It's true that I had to re-initialize AzureOpenAIEmbeddings. Is there a way to initialize it only once?
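(For what it's worth, one generic way to initialize only once per process is a cached factory; this is just a sketch, and the AZURE_* names are the constants from the reproduction snippet above:)

from functools import lru_cache
from langchain.embeddings import AzureOpenAIEmbeddings

@lru_cache(maxsize=1)
def get_embeddings() -> AzureOpenAIEmbeddings:
    # Built once on the first call, then reused for the process lifetime.
    return AzureOpenAIEmbeddings(
        azure_endpoint=AZURE_OPENAI_API_BASE,
        api_key=AZURE_OPENAI_API_KEY,
        openai_api_version=AZURE_OPENAI_API_VERSION,
        azure_deployment=AZURE_EMBEDDING_DEPLOYMENT_NAME,
    )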
@dosu-bot,
I am also facing the same issue. Since OpenAI has updated its API, I need to use openai==1.7.1 and langchain==0.1.0, and because of this I am not able to pickle to local disk.
One more question: do we also have an issue with store = FAISS.from_texts(docs, embeddings, metadatas=metadatas)? When I access the store directly using chain = RetrievalQAWithSourcesChain.from_chain_type(llm=llm, retriever=store.as_retriever()), it gives answers of only 5 to 8 words. Why? This happens in the latest version only. Any idea?
Same issue here, any solutions?
refer: https://github.com/langchain-ai/langchain/issues/16267
same issue
refer: https://github.com/langchain-ai/langchain/issues/16267
Add the max_tokens parameter in ChatOpenAI.
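(A minimal sketch of that suggestion; the model name is a placeholder:)

from langchain.chat_models import ChatOpenAI

# Cap the completion length as suggested above.
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, max_tokens=256)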
For those still having this specific issue (ChromaDB + HuggingFace models), I extended the ChromadbRM class with a "fake" deepcopy method to avoid deep-copying the client:

class ChromadbRMC(dspy.Retrieve):
    def __deepcopy__(self, memo):
        # Return a shallow copy so the underlying client object is shared
        # instead of deep-copied.
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    # other methods are still the same
    def __init__(...

The DSPy optimizer works now. Hope this helps.
Adding the max_tokens parameter worked for me:

llm_config = {
    # "request_timeout": 600,
    "config_list": [
        {"model": selected_model, "api_key": selected_key},
    ],
    "seed": 42,  # seed for reproducibility
    "temperature": 0,  # temperature of 0 means deterministic output
    "max_tokens": 256,  # max number of tokens to generate; solves the bug
}
Here is a hands-on (meaning 'sort of quick and dirty', or pragmatic) solution for "serializing" the AzureOpenAIEmbeddings class in libs/community/langchain_community/embeddings/azure_openai.py, tested with the tag v0.1.13 code.
If I find the time to get familiar with the contribution guidelines I might turn it into a PR, but who knows whether they would want to include this in the first place?
Until then, here are the additional lines of code that make the AzureOpenAIEmbeddings class serializable and therefore picklable.
I included a couple of additional lines before and after for easier orientation, BUT it's mainly the __getstate__ and __setstate__ methods that do the "magic":
"""Azure OpenAI embeddings wrapper."""
from __future__ import annotations

import os
from typing import Callable, Dict, Optional, Union

import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_openai.embeddings.base import OpenAIEmbeddings


class AzureOpenAIEmbeddings(OpenAIEmbeddings):
    """`Azure OpenAI` Embeddings API.

    To use, you should have the environment variable ``AZURE_OPENAI_API_KEY`` set
    with your API key or pass it as a named parameter to the constructor.

    Example:
        from langchain_openai import AzureOpenAIEmbeddings

        openai = AzureOpenAIEmbeddings(model="text-embedding-3-large")
    """

    azure_endpoint: Union[str, None] = None
    deployment: Optional[str] = Field(default=None, alias="azure_deployment")
    openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    azure_ad_token: Optional[SecretStr] = None
    azure_ad_token_provider: Union[Callable[[], str], None] = None
    openai_api_version: Optional[str] = Field(default=None, alias="api_version")
    validate_base_url: bool = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        # Existing validation logic...
        # Your existing client_params logic
        if not values.get("client"):
            sync_specific = {"http_client": values["http_client"]}
            values["client"] = openai.AzureOpenAI(
                **client_params, **sync_specific
            ).embeddings
        if not values.get("async_client"):
            async_specific = {"http_client": values["http_async_client"]}
            values["async_client"] = openai.AsyncAzureOpenAI(
                **client_params, **async_specific
            ).embeddings
        return values

    def __getstate__(self):
        # Drop the non-picklable OpenAI client objects from the state.
        state = self.__dict__.copy()
        state.pop('client', None)
        state.pop('async_client', None)
        return state

    def __setstate__(self, state):
        # Restore the plain attributes, then rebuild the clients.
        self.__dict__.update(state)
        client_params = {
            "api_version": self.openai_api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.deployment,
            "api_key": self.openai_api_key.get_secret_value()
            if self.openai_api_key
            else None,
            "azure_ad_token": self.azure_ad_token.get_secret_value()
            if self.azure_ad_token
            else None,
            "azure_ad_token_provider": self.azure_ad_token_provider,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        self.client = openai.AzureOpenAI(**client_params).embeddings
        self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings

    @property
    def _llm_type(self) -> str:
        return "azure-openai-chat"

    # Add here any additional methods or properties that were in the original
    # AzureOpenAIEmbeddings class.
If you import this class instead of the "original" one, it should fix pickle errors, e.g., in chains with RetrievalQA.
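(A quick round-trip sketch of what the patched class enables; all constructor arguments here are placeholders:)

import pickle

# __getstate__ drops the clients before dumping; __setstate__ rebuilds
# them on load.
emb = AzureOpenAIEmbeddings(
    azure_endpoint="https://example.openai.azure.com/",  # placeholder
    azure_deployment="my-embedding-deployment",  # placeholder
    openai_api_version="2023-05-15",  # placeholder
)
restored = pickle.loads(pickle.dumps(emb))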
And while I am at it, here is the same type of fix for the AzureChatOpenAI class at libs/community/langchain_community/chat_models/azure_openai.py, which imho will cause the same type of pickling "issue"; also tested on the tag v0.1.13 code.
from __future__ import annotations

import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

import openai
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_openai.llms.base import BaseOpenAI

logger = logging.getLogger(__name__)


class AzureOpenAI(BaseOpenAI):
    """Azure-specific OpenAI large language models.
    ...
    """

    azure_endpoint: Union[str, None] = None
    deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment")
    openai_api_version: str = Field(default="", alias="api_version")
    openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    azure_ad_token: Optional[SecretStr] = None
    azure_ad_token_provider: Union[Callable[[], str], None] = None
    openai_api_type: str = ""
    validate_base_url: bool = True

    # ... existing class methods ...

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        # ... existing validation logic ...
        return values

    def __getstate__(self):
        # Drop the non-picklable OpenAI client objects from the state.
        state = self.__dict__.copy()
        state.pop('client', None)
        state.pop('async_client', None)
        return state

    def __setstate__(self, state):
        # Restore the plain attributes, then rebuild the clients.
        self.__dict__.update(state)
        client_params = {
            "api_version": self.openai_api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.deployment_name,
            "api_key": self.openai_api_key.get_secret_value()
            if self.openai_api_key
            else None,
            "azure_ad_token": self.azure_ad_token.get_secret_value()
            if self.azure_ad_token
            else None,
            "azure_ad_token_provider": self.azure_ad_token_provider,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        self.client = openai.AzureOpenAI(**client_params).completions
        self.async_client = openai.AsyncAzureOpenAI(**client_params).completions

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        ...  # existing method

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        ...  # existing method

    @property
    def _llm_type(self) -> str:
        ...  # existing method

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        ...  # existing method
Same instructions and comments as in the above comment apply here as well.
Sadly, the code above results in missing-attribute __fields_set__ errors for me when un-pickling the AzureOpenAIEmbeddings object (this was done within an MLflow chain) :-(
I guess the pydantic import is somehow to blame for that -> see my comment below on how to solve that (as well).
@duongvinhduc | @pranay518 | @GuillermoGarciaF | @Abhijeetcode911
The only way that I found to solve this issue is by lowering your langchain version to
- langchain==0.0.350
- openai==0.27.6
The import syntax also changes:
- from langchain_openai import OpenAIEmbeddings
becomes
- from langchain_community.embeddings import OpenAIEmbeddings
It worked like this for me!!! After that, no issue pickling your embedding file.
import os
import pickle
from dotenv import load_dotenv, find_dotenv
import openai
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.document_loaders import TextLoader
def generate_embeded_pickle(model="text-embedding-3-large"):
    embeddings_model = OpenAIEmbeddings(model=model)
    # load_split is the author's own loader/splitter helper (not shown).
    docs = load_split(file_path="src/data/test.pdf")
    embeddings_file = "src/pickle_store/documentToParse.pkl"
    if not os.path.exists(embeddings_file):
        embeddings = FAISS.from_documents(docs, embeddings_model)
        with open(embeddings_file, "wb") as file:
            pickle.dump(embeddings, file)
        print("Pickle file saved successfully at:", embeddings_file)
    else:
        print("Pickle file already exists!")
I managed to (additionally) "fix" the __fields_set__-related issue I mentioned running into above by extending the modifications to the class a little further.
I.e., the further additions to the AzureOpenAIEmbeddings class were a __fields_set__ = set() attribute, whose details/definition I looked up in the BaseModel super-superclass, plus some logic to include it in the __getstate__ and __setstate__ methods. But I think there are already some tickets related to 'serialization' issues in the works.
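(Roughly, the extra handling looks like this sketch; it assumes pydantic v1's BaseModel, which keeps __fields_set__ in a slot outside __dict__, so the exact details may vary:)

def __getstate__(self):
    state = self.__dict__.copy()
    state.pop("client", None)
    state.pop("async_client", None)
    # __fields_set__ lives outside __dict__, so carry it along explicitly.
    state["__fields_set__"] = self.__fields_set__
    return state

def __setstate__(self, state):
    object.__setattr__(self, "__fields_set__", state.pop("__fields_set__", set()))
    self.__dict__.update(state)
    # ... then rebuild client / async_client exactly as in the earlier comment ...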
Still need a solution for langchain >= 0.2.