
Add caching to BaseChatModel (issue #1644)

Open · UmerHA opened this pull request 1 year ago · 14 comments

Add caching to BaseChatModel

Fixes #1644

(Side note: while testing, I noticed we have multiple fake-LLM implementations in the test suite. I consolidated them; a rough sketch of the idea follows.)
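
For illustration, a consolidated fake chat model for deterministic cache tests might look roughly like this. This is a hypothetical sketch, not the PR's actual code; the class name, the response field, and the exact base-class hooks are assumptions based on the langchain layout of this era:

from typing import Any, List, Optional

from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


class FakeChatModel(BaseChatModel):
    """Returns a canned response, so cache hits and misses are deterministic."""

    response: str = "fake response"

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Any
    ) -> ChatResult:
        # Always answer with the canned string, ignoring the input messages.
        generation = ChatGeneration(message=AIMessage(content=self.response))
        return ChatResult(generations=[generation])

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs: Any
    ) -> ChatResult:
        # Async path simply reuses the sync implementation.
        return self._generate(messages, stop=stop, **kwargs)

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"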

Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested in Models:

  • @hwchase17
  • @agola11

Twitter: @UmerHAdil | Discord: RicChilligerDude#7589

UmerHA avatar May 22 '23 12:05 UmerHA

Any comments on this? @hwchase17 @agola11

Would be great to have caching included!

kaikun213 avatar May 24 '23 16:05 kaikun213

Someone please take a look at this. Really need this :) Thanks

abdulzain6 avatar May 29 '23 10:05 abdulzain6

Need this too.

realjustinwu avatar May 31 '23 14:05 realjustinwu

I hope it's reviewed soon; we need caching for ChatModels!

ielmansouri avatar Jun 02 '23 14:06 ielmansouri

import time

import langchain
from langchain.cache import SQLiteCache
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

langchain.llm_cache = SQLiteCache(database_path=".langchain.db")
chat = ChatOpenAI(temperature=0, openai_api_key=get_openai_api_key())  # get_openai_api_key() is a local helper
messages = [
    SystemMessage(content="You are a helpful assistant that translates English to French."),
    HumanMessage(content="I love programming.")
]
start = time.time()
print(chat(messages))
print(f"first time = {time.time() - start}")
start = time.time()
print(chat(messages))
print(f"second time = {time.time() - start}")

I tested this code with this PR. The first request misses the cache, so it works. But the second request hits the cache, and this error occurs:

cls = <class 'langchain.schema.ChatGeneration'>
values = {'generation_info': None, 'text': "J'adore la programmation."}

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
>       values["text"] = values["message"].content
E       KeyError: 'message'

With InMemoryCache, the test code works fine:

import time

import langchain
from langchain.cache import InMemoryCache
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage

langchain.llm_cache = InMemoryCache()
chat = ChatOpenAI(temperature=0, openai_api_key=get_openai_api_key())  # get_openai_api_key() is a local helper
messages = [
    SystemMessage(content="You are a helpful assistant that translates English to French."),
    HumanMessage(content="I love programming.")
]
start = time.time()
print(chat(messages))
print(f"first time = {time.time() - start}")
start = time.time()
print(chat(messages))
print(f"second time = {time.time() - start}")

My guess is that InMemoryCache is just a Python dictionary, so it stores the data as ChatGeneration objects. SQLiteCache, however, is a local database, so it persists the data as the base Generation type. On a cache hit with SQLiteCache (and other persistent caches), the loaded data is therefore a Generation, not a ChatGeneration, and has no "message" property (see the sketch below). A quick fix would be to implement a separate langchain.chat_model cache that saves and loads ChatGeneration objects, but that would mean two cache code paths, LLM and chat model, for every cache implementation. Solving this properly would require modifying BaseCache, I think, but that is complicated.
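
For reference, the mismatch can be reproduced with the schema classes alone (a minimal sketch, assuming the langchain.schema classes of this era):

from langchain.schema import AIMessage, ChatGeneration, Generation

chat_gen = ChatGeneration(message=AIMessage(content="J'adore la programmation."))

# A persistent cache round-trips through the base Generation type,
# so the `message` field is dropped on save:
stored = Generation(text=chat_gen.text, generation_info=chat_gen.generation_info)

# Rehydrating the stored value as a ChatGeneration then fails in set_text,
# because values["message"] was never persisted:
ChatGeneration(**stored.dict())  # raises KeyError: 'message'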

Rienkim avatar Jun 04 '23 13:06 Rienkim

Hey @Rienkim, thanks for pointing that out! I'll take a look & add more tests that cover different caching options.

ETA should be this week. In the meantime, I'll turn this PR into a draft.

UmerHA avatar Jun 05 '23 11:06 UmerHA

@Rienkim Fixed it & added more tests

UmerHA avatar Jun 06 '23 22:06 UmerHA

@UmerHA thank you for working on this.

I found that the _combine_llm_outputs implemented in this PR can be an issue with OpenAICallbackHandler: https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/openai_info.py#LL99C42-L99C42
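
For context, ChatOpenAI._combine_llm_outputs merges the per-call llm_output dicts whose token_usage OpenAICallbackHandler reads at the linked line. Roughly paraphrased (a sketch of the idea, not the exact library code and not this PR's changes):

from typing import Any, Dict, List, Optional


def combine_llm_outputs(llm_outputs: List[Optional[Dict[str, Any]]]) -> Dict[str, Any]:
    # Sum the token_usage counters (prompt_tokens, completion_tokens, ...)
    # across all batched calls into one combined dict.
    overall: Dict[str, int] = {}
    for output in llm_outputs:
        if output is None:
            continue
        for key, value in output.get("token_usage", {}).items():
            overall[key] = overall.get(key, 0) + value
    return {"token_usage": overall}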

deepblue avatar Jun 06 '23 22:06 deepblue

Hi folks, what is the latest on this? Is there a timeline for merging this?

sam-cohan avatar Jun 06 '23 22:06 sam-cohan

@UmerHA thank you for working on this.

I found that _combine_llm_outputs implemented in this PR can be an issue with OpenAICallbackHandler https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/openai_info.py#LL99C42-L99C42

Can you post example code with actual & expected behavior? To me it seems fine, but maybe I'm missing something:

langchain.llm_cache = SQLiteCache(".langchain.test.db")
langchain.llm_cache.clear()

oai_cb = OpenAICallbackHandler()

llm = ChatOpenAI(callbacks=[oai_cb])

messages = [
    SystemMessage(content="You are a helpful assistant that translates English to French."),
    HumanMessage(content="I love programming.")
]

print(llm(messages))
print(oai_cb)

print(llm(messages))
print(oai_cb)

print(llm(messages))
print(oai_cb)

gives first

content="J'adore programmer." additional_kwargs={} example=False
Tokens Used: 34
	Prompt Tokens: 28
	Completion Tokens: 6
Successful Requests: 1
Total Cost (USD): $6.8e-05

then

content="J'adore programmer." additional_kwargs={} example=False
Tokens Used: 34
	Prompt Tokens: 28
	Completion Tokens: 6
Successful Requests: 2
Total Cost (USD): $6.8e-05

and then

content="J'adore programmer." additional_kwargs={} example=False
Tokens Used: 34
	Prompt Tokens: 28
	Completion Tokens: 6
Successful Requests: 3
Total Cost (USD): $6.8e-05

What would you expect to be different? :)

UmerHA avatar Jun 07 '23 09:06 UmerHA

(Quoting UmerHA's reply above in full.)

@deepblue ^^^ (in case you missed it)

pors avatar Jun 09 '23 10:06 pors

@deepblue could you elaborate on your concerns? I am waiting for the feature :)

jakobsa avatar Jun 14 '23 12:06 jakobsa

Just realized I made an error in my code: I used this PR's BaseChatModel._combine_llm_outputs instead of the default ChatOpenAI._combine_llm_outputs for testing. I retested with the right method and it's all clear; the original method doesn't affect our subclassed code. Apologies for the mix-up and the late response.

I tested and confirmed that it's working as expected.

deepblue avatar Jun 15 '23 08:06 deepblue

@hwchase17 @agola11

Good to go?

pors avatar Jun 15 '23 09:06 pors

@hwchase17 @agola11

Any update on this?

kaikun213 avatar Jun 22 '23 09:06 kaikun213

Getting the below error when I use MomentoCache. @UmerHA, please let me know if this is a bug or if I am doing something wrong.

Edit: the error pops up for all calls after the cache has at least one key set.

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
>       values["text"] = values["message"].content
E       KeyError: 'message'

Code:

import langchain
from datetime import timedelta
from langchain.cache import MomentoCache

langchain.llm_cache = MomentoCache.from_client_params("langchain_momento", timedelta(days=1))

# Further code for constructing and calling the chain using ChatOpenAI

ghost avatar Jul 11 '23 11:07 ghost

(Quoting the comment above in full.)

Can you post the full code, error message, and stack trace?

UmerHA avatar Jul 11 '23 11:07 UmerHA