langchain
How to add memory to SQLDatabaseChain?
Issue you'd like to raise.
I want to create a chain that runs queries against my database, and I also want to add memory to this chain.
Example of dialogue I want to see:
Query: Who is the owner of the website with domain domain.com?
Answer: Boba Bobovich
Query: Tell me his email
Answer: Boba Bobovich's email is [email protected]
I have this code:
import os
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain, PromptTemplate
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory()
db = SQLDatabase.from_uri(os.getenv("DB_URI"))
llm = OpenAI(temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, memory=memory)
db_chain.run("Who is owner of the website with domain https://damon.name")
db_chain.run("Tell me his email")
print(memory.load_memory_variables({}))
It gives:
> Entering new chain...
Who is owner of the website with domain https://damon.name
SQLQuery:SELECT first_name, last_name FROM owners JOIN websites ON owners.id = websites.owner_id WHERE domain = 'https://damon.name' LIMIT 5;
SQLResult: [('Geo', 'Mertz')]
Answer:Geo Mertz is the owner of the website with domain https://damon.name.
> Finished chain.
> Entering new chain...
Tell me his email
SQLQuery:SELECT email FROM owners WHERE first_name = 'Westley' AND last_name = 'Waters'
SQLResult: [('[email protected]',)]
Answer:Westley Waters' email is [email protected].
> Finished chain.
{'history': "Human: Who is owner of the website with domain https://damon.name\nAI: Geo Mertz is the owner of the website with domain https://damon.name.\nHuman: Tell me his email\nAI: Westley Waters' email is [email protected]."}
Well, it saves the context to memory, but the chain doesn't use it to give a proper answer: the second query looks up a different owner (Westley Waters instead of Geo Mertz), so the email is wrong. How do I fix this?
Also, I don't want to use an agent, because I want to get this working with a simple chain first. Please tell me if it's impossible with a simple chain.
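Why the stored history is ignored: the memory does record both turns (see the history printout above), but the default prompt that SQLDatabaseChain sends to the LLM has no placeholder for it, so the second question reaches the model as a bare "Tell me his email". The sketch below is a plain-Python analogy, not LangChain's code: SQL_PROMPT is a simplified, hypothetical stand-in for the chain's default prompt, and it relies on the fact that Python's str.format silently ignores keyword arguments the template never references.

```python
# Simplified, hypothetical stand-in for the default SQL prompt: it has slots
# for the question and the schema, but none for conversation history.
SQL_PROMPT = (
    "Given the tables below, write a SQL query for the question.\n"
    "{table_info}\n"
    "Question: {input}\n"
    "SQLQuery:"
)

# What ConversationBufferMemory had stored after the first turn.
history = (
    "Human: Who is owner of the website with domain https://damon.name\n"
    "AI: Geo Mertz is the owner of the website with domain https://damon.name."
)

inputs = {"input": "Tell me his email", "table_info": "CREATE TABLE owners (...)"}

# Build the prompt without history, and with history merged into the inputs
# (which is roughly what attaching memory to the outer chain achieves).
prompt_without_history = SQL_PROMPT.format(**inputs)
prompt_with_history = SQL_PROMPT.format(**inputs, history=history)

# str.format drops the unused `history` argument, so both prompts are identical
# and the model never learns who "his" refers to.
print(prompt_without_history == prompt_with_history)
```

The takeaway is that memory only helps if some prompt actually contains a variable (for example {history}) that the memory fills.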
My understanding
It seems that you want to create a chain that queries your database and add memory to it to maintain the context of the conversation. You are using SQLDatabaseChain and ConversationBufferMemory from the LangChain library. However, although the context is saved in memory, the chain does not use it to provide the correct answer.
Resolution
To fix this issue, you can modify your SQLDatabaseChain to use the memory when generating the response. You can achieve this by extending the SQLDatabaseChain class and overriding the run method to include the memory in the query generation process. Here's an example of how you can create a custom SQLDatabaseChain with memory support. Treat this code as a pseudo approach; it may not reflect the exact implementation.
from langchain import SQLDatabaseChain
from langchain.memory import ConversationBufferMemory
class SQLDatabaseChainWithMemory(SQLDatabaseChain):
    def __init__(self, *args, memory: ConversationBufferMemory, **kwargs):
        super().__init__(*args, **kwargs)
        self.memory = memory
    def run(self, inputs):
        # Add the memory variables to the inputs
        inputs_with_memory = {**inputs, **self.memory.load_memory_variables({})}
        # Call the parent class's run method with the updated inputs
        return super().run(inputs_with_memory)
Now, you can use this SQLDatabaseChainWithMemory class to create your chain with memory support:
memory = ConversationBufferMemory()
db_chain = SQLDatabaseChainWithMemory.from_llm(llm, db, verbose=True, memory=memory)
This custom chain will include the memory variables in the inputs when generating the response, allowing it to provide the correct answer based on the conversation context.
Hey, did you find any solution to this problem?
This is not working.
Can you give me a correct implementation?
@mkx000 @charanhu I am working on this exact issue. I will let you know once I find a good solution.
I have started working on a PR for this issue. This will be a good feature to have.
@bleschunov This article will give the solution - https://python.langchain.com/docs/modules/agents/how_to/add_memory_openai_functions
Sample
from langchain import SQLDatabase, SQLDatabaseChain
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.chat_models import ChatOpenAI
from langchain.prompts import MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613") # type: ignore
db = SQLDatabase.from_uri()  # pass your database URI here (omitted in the original)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
# Expose the SQL chain to the agent as a single tool
tools = [
    Tool(
        name="dbchain",
        func=db_chain.run,
        description="Chat with SQLDB",
    )
]
# Reserve a slot in the agent's prompt for the conversation history
agent_kwargs = {
    "extra_prompt_messages": [MessagesPlaceholder(variable_name="memory")],
}
# memory_key must match the placeholder's variable_name
memory = ConversationBufferMemory(memory_key="memory", return_messages=True)
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
    agent_kwargs=agent_kwargs,
    memory=memory
)
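Roughly speaking, MessagesPlaceholder(variable_name="memory") reserves a slot in the agent's chat prompt that is filled with the messages held by ConversationBufferMemory(memory_key="memory", return_messages=True). The plain-Python sketch below models that assembly with (role, content) tuples; build_messages is a hypothetical helper, and the ordering (system message, then history, then the new user input) is an assumption about where extra_prompt_messages end up, so check it against your LangChain version.

```python
from typing import List, Tuple

Message = Tuple[str, str]  # (role, content)


def build_messages(system: str, history: List[Message], user_input: str) -> List[Message]:
    """Assemble a chat prompt: system message, prior turns, then the new question."""
    return [("system", system), *history, ("human", user_input)]


# Prior turns, as a message-returning memory would hold them.
history: List[Message] = [
    ("human", "Is there any realm named demo?"),
    ("ai", "Yes, there is a realm named demo."),
]

# Placeholder system text, not the agent's actual default.
messages = build_messages("You are a helpful assistant.", history, "How many users in the realm?")
for role, content in messages:
    print(f"{role}: {content}")
```

Because the earlier exchange sits right before the follow-up, the model can resolve "the realm" to "demo", which matches the rewritten tool input in the log below.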
@gugupy I tried this out and this does not work.
Works for me. It can remember the previous query.
Ask your query? Is there any realm named demo?
> Entering new chain...
Invoking: `dbchain` with `Is there any realm named demo?`
> Entering new chain...
Is there any realm named demo?
SQLQuery:SELECT name FROM realm WHERE name = 'demo' LIMIT 1;
SQLResult: [('demo',)]
Answer:Yes, there is a realm named demo.
> Finished chain.
Yes, there is a realm named demo.Yes, there is a realm named demo.
> Finished chain.
Yes, there is a realm named demo.
Ask your query? How many users in the realm?
Invoking: `dbchain` with `How many users are in the realm named demo?`
> Entering new chain...
How many users are in the realm named demo?
SQLQuery:SELECT COUNT(*) FROM user_entity WHERE realm_id = (SELECT id FROM realm WHERE name = 'demo');
SQLResult: [(35,)]
Answer:There are 35 users in the realm named demo.
> Finished chain.
There are 35 users in the realm named demo.There are 35 users in the realm named demo.
> Finished chain.
There are 35 users in the realm named demo.
It works sometimes, but it is very unreliable since it crashes with syntax errors. Ideally, the default SQL agent packaged in LangChain with all of the SQL toolkit tools should have memory integrated. I am referring to the ZERO_SHOT_REACT_DESCRIPTION SQL agent with all of the SQL toolkit tools, created by the create_sql_agent function in the sql sub-package of the agents package.
For anyone coming here in the future: currently, passing memory directly to SQLDatabaseChain or SQLDatabaseSequentialChain is not possible, but I am working on that (PR: #7546). However, you can create an SQL agent with memory as follows.
from langchain.agents.agent_toolkits import create_sql_agent,SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
from langchain.chat_models.openai import ChatOpenAI
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(memory_key = 'history' , input_key = 'input')
llm = ChatOpenAI(model_name = GPT_MODEL_NAME , temperature = 0)
db = SQLDatabase.from_uri(CONN_STRING)
suffix = """Begin!
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)
Question: {input}
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
{agent_scratchpad}
"""
executor = create_sql_agent(
    llm = llm,
    toolkit = SQLDatabaseToolkit(db=db, llm=llm),
    agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    input_variables = ["input", "agent_scratchpad","history"],
    suffix = suffix, # must have history as variable
    agent_executor_kwargs = {'memory':memory}
)
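To make the "must have history as variable" comment concrete: the suffix is a template in which {history} is filled from the memory (memory_key='history') on every call, while {input} is the current question (input_key='input'). The sketch below imitates that formatting in plain Python. format_history is a hypothetical helper that mimics the "Human: ... / AI: ..." transcript format ConversationBufferMemory produced in the first post; it is not the library's code.

```python
from typing import List, Tuple

# Shortened copy of the suffix from the snippet above.
SUFFIX = """Begin!
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)
Question: {input}
{agent_scratchpad}"""


def format_history(turns: List[Tuple[str, str]]) -> str:
    """Hypothetical helper: render (question, answer) pairs as a transcript."""
    lines = []
    for question, answer in turns:
        lines.append(f"Human: {question}")
        lines.append(f"AI: {answer}")
    return "\n".join(lines)


turns = [(
    "Describe the actor table",
    "The actor table has columns actor_id, first_name, last_name, and last_update.",
)]

prompt = SUFFIX.format(
    history=format_history(turns),
    input="What table was I talking about?",
    agent_scratchpad="",  # filled with intermediate agent steps at run time
)
print(prompt)
```

If history were left out of the suffix and input_variables, the stored turns would never reach the prompt, even though the memory keeps recording them.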
Thanks for that agent snippet, @keenborder786!
I'll work with you on adding this to SQLDatabaseChain.
@keenborder786 What do you mean by "#must have history as variable,"? And can you kindly explain how to run the agent? Below is how I am running it:
question = "Describe the actor table"
agent_executor.run(question)
question = "What table was I talking about?"
agent_executor.run(question)
And here is the output I got:
> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: ""
Observation: actor, address, category, city, country, customer, film, film_actor, film_category, inventory, language, payment, rental, staff, store
Thought:There are several tables in the database that I can query. I should query the schema of the "actor" table to describe it.
Action: sql_db_schema
Action Input: "actor"
Observation:
CREATE TABLE actor (
actor_id SERIAL NOT NULL,
first_name VARCHAR(45) NOT NULL,
last_name VARCHAR(45) NOT NULL,
last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL,
CONSTRAINT actor_pkey PRIMARY KEY (actor_id)
)
/*
3 rows from actor table:
actor_id first_name last_name last_update
1 Penelope Guiness 2013-05-26 14:47:57.620000
2 Nick Wahlberg 2013-05-26 14:47:57.620000
3 Ed Chase 2013-05-26 14:47:57.620000
*/
Thought:The actor table has columns actor_id, first_name, last_name, and last_update. It contains information about actors in the database.
Final Answer: The actor table has columns actor_id, first_name, last_name, and last_update.
> Finished chain.
'The actor table has columns actor_id, first_name, last_name, and last_update.'
And for the second question:
> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: ""
Observation: actor, address, category, city, country, customer, film, film_actor, film_category, inventory, language, payment, rental, staff, store
Thought:I have a list of tables in the database. Now I can query the schema of each table to find the one I was talking about.
Action: sql_db_schema
Action Input: "actor"
Observation:
CREATE TABLE actor (
actor_id SERIAL NOT NULL,
first_name VARCHAR(45) NOT NULL,
last_name VARCHAR(45) NOT NULL,
last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL,
CONSTRAINT actor_pkey PRIMARY KEY (actor_id)
)
/*
3 rows from actor table:
actor_id first_name last_name last_update
1 Penelope Guiness 2013-05-26 14:47:57.620000
2 Nick Wahlberg 2013-05-26 14:47:57.620000
3 Ed Chase 2013-05-26 14:47:57.620000
*/
Thought:The table I was talking about is the actor table.
Final Answer: actor
> Finished chain.
'actor'
@keenborder786 It works, but for some reason the information retrieved from the database for previous questions is not reused for subsequent questions; a new query is sent to the database every time. BTW, any idea how I can use GPTCache with SQLDatabaseChain? I ran this code:
agent_executor.run("Select 10 first articles")
agent_executor.run("What is the name of the article with the ID number 108775044")
and got this:
Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: ""
Observation: articles, customers, transactions
Thought:The 'articles' table seems to be the most relevant for this query. I should check its schema to understand its structure and the data it contains.
Action: sql_db_schema
Action Input: "articles"
Observation:
CREATE TABLE articles (
article_id INTEGER,
product_code INTEGER,
prod_name VARCHAR(250),
product_type_no INTEGER,
product_type_name VARCHAR(50),
product_group_name VARCHAR(50),
graphical_appearance_no INTEGER,
graphical_appearance_name VARCHAR(50),
colour_group_code INTEGER,
colour_group_name VARCHAR(50),
perceived_colour_value_id INTEGER,
perceived_colour_value_name VARCHAR(50),
perceived_colour_master_id INTEGER,
perceived_colour_master_name VARCHAR(50),
department_no INTEGER,
department_name VARCHAR(50),
index_code VARCHAR(10),
index_name VARCHAR(50),
index_group_no INTEGER,
index_group_name VARCHAR(50),
section_no INTEGER,
section_name VARCHAR(50),
garment_group_no INTEGER,
garment_group_name VARCHAR(50),
detail_desc VARCHAR(1500) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4
/*
3 rows from articles table:
article_id product_code prod_name product_type_no product_type_name product_group_name graphical_appearance_no graphical_appearance_name colour_group_code colour_group_name perceived_colour_value_id perceived_colour_value_name perceived_colour_master_id perceived_colour_master_name department_no department_name index_code index_name index_group_no index_group_name section_no section_name garment_group_no garment_group_name detail_desc
564999003 564999 Drake sweatpants 272 Trousers Garment Lower body 1010016 Solid 93 Dark Green 2 Medium Dusty 19 Gree8748 Young Boy Jersey Fancy I Children Sizes 134-170 4 Baby/Children 47 Young Boy 1005 Jersey Fancy Joggers in sweatshirt fabric with an elasticated drawstring waist, side pockets and tapered legs wit
564999008 564999 Drake sweatpants 272 Trousers Garment Lower body 1010005 Colour blocking 73 Dark Blue 4 Dark 2 Blue8748 Young Boy Jersey Fancy I Children Sizes 134-170 4 Baby/Children 47 Young Boy 1005 Jersey Fancy Joggers in sweatshirt fabric with an elasticated drawstring waist, side pockets and tapered legs wit
564999009 564999 Drake sweatpants 272 Trousers Garment Lower body 1010005 Colour blocking 9 Black 4 Dark 5 Black 8748Young Boy Jersey Fancy I Children Sizes 134-170 4 Baby/Children 47 Young Boy 1005 Jersey Fancy Joggers in sweatshirt fabric with an elasticated drawstring waist, side pockets and tapered legs wit
*/
Thought:The 'articles' table contains information about different articles, including their names, types, and descriptions. To answer the question, I need to select the first 10 articles. I will do this by ordering the articles by 'article_id' and limiting the results to 10. I will only select the 'article_id' and 'prod_name' columns, as these seem to be the most relevant to the question.
Action: sql_db_query_checker
Action Input: "SELECT article_id, prod_name FROM articles ORDER BY article_id LIMIT 10"
Observation: SELECT article_id, prod_name FROM articles ORDER BY article_id LIMIT 10
Thought:The query syntax is correct. Now I can execute it to get the first 10 articles.
Action: sql_db_query
Action Input: "SELECT article_id, prod_name FROM articles ORDER BY article_id LIMIT 10"
Observation: [(0, 'prod_name'), (108775015, 'Strap top'), (108775044, 'Strap top'), (108775051, 'Strap top (1)'), (110065001, 'OP T-shirt (Idro)'), (110065002, 'OP T-shirt (Idro)'), (110065011, 'OP T-shirt (Idro)'), (111565001, '20 den 1p Stockings'), (111565003, '20 den 1p Stockings'), (111586001, 'Shape Up 30 den 1p Tights')]
Thought:I now know the final answer.
Final Answer: The first 10 articles are:
1. Article ID: 108775015, Product Name: Strap top
2. Article ID: 108775044, Product Name: Strap top
3. Article ID: 108775051, Product Name: Strap top (1)
4. Article ID: 110065001, Product Name: OP T-shirt (Idro)
5. Article ID: 110065002, Product Name: OP T-shirt (Idro)
6. Article ID: 110065011, Product Name: OP T-shirt (Idro)
7. Article ID: 111565001, Product Name: 20 den 1p Stockings
8. Article ID: 111565003, Product Name: 20 den 1p Stockings
9. Article ID: 111586001, Product Name: Shape Up 30 den 1p Tights
And the second one:
Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: ""
Observation: articles, customers, transactions
Thought:The 'articles' table seems to be the most relevant one for this query, as it likely contains information about articles. I should check its schema to see if it contains a column for the article name and id.
Action: sql_db_schema
Action Input: "articles"
Observation:
CREATE TABLE articles (
article_id INTEGER,
product_code INTEGER,
prod_name VARCHAR(250),
product_type_no INTEGER,
product_type_name VARCHAR(50),
product_group_name VARCHAR(50),
graphical_appearance_no INTEGER,
graphical_appearance_name VARCHAR(50),
colour_group_code INTEGER,
colour_group_name VARCHAR(50),
perceived_colour_value_id INTEGER,
perceived_colour_value_name VARCHAR(50),
perceived_colour_master_id INTEGER,
perceived_colour_master_name VARCHAR(50),
department_no INTEGER,
department_name VARCHAR(50),
index_code VARCHAR(10),
index_name VARCHAR(50),
index_group_no INTEGER,
index_group_name VARCHAR(50),
section_no INTEGER,
section_name VARCHAR(50),
garment_group_no INTEGER,
garment_group_name VARCHAR(50),
detail_desc VARCHAR(1500) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci
)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4
/*
3 rows from articles table:
article_id product_code prod_name product_type_no product_type_name product_group_name graphical_appearance_no graphical_appearance_name colour_group_code colour_group_name perceived_colour_value_id perceived_colour_value_name perceived_colour_master_id perceived_colour_master_name department_no department_name index_code index_name index_group_no index_group_name section_no section_name garment_group_no garment_group_name detail_desc
564999003 564999 Drake sweatpants 272 Trousers Garment Lower body 1010016 Solid 93 Dark Green 2 Medium Dusty 19 Gree8748 Young Boy Jersey Fancy I Children Sizes 134-170 4 Baby/Children 47 Young Boy 1005 Jersey Fancy Joggers in sweatshirt fabric with an elasticated drawstring waist, side pockets and tapered legs wit
564999008 564999 Drake sweatpants 272 Trousers Garment Lower body 1010005 Colour blocking 73 Dark Blue 4 Dark 2 Blue8748 Young Boy Jersey Fancy I Children Sizes 134-170 4 Baby/Children 47 Young Boy 1005 Jersey Fancy Joggers in sweatshirt fabric with an elasticated drawstring waist, side pockets and tapered legs wit
564999009 564999 Drake sweatpants 272 Trousers Garment Lower body 1010005 Colour blocking 9 Black 4 Dark 5 Black 8748Young Boy Jersey Fancy I Children Sizes 134-170 4 Baby/Children 47 Young Boy 1005 Jersey Fancy Joggers in sweatshirt fabric with an elasticated drawstring waist, side pockets and tapered legs wit
*/
Thought:The 'articles' table has a column named 'article_id' which likely contains the id of the articles, and a column named 'prod_name' which likely contains the name of the articles. I can use these columns to find the name of the article with id 108775044. I will write a SQL query to get this information.
Action: sql_db_query_checker
Action Input: "SELECT prod_name FROM articles WHERE article_id = 108775044 LIMIT 10"
Observation: SELECT prod_name FROM articles WHERE article_id = 108775044 LIMIT 10
Thought:The query syntax is correct. Now I can execute it to get the name of the article with id 108775044.
Action: sql_db_query
Action Input: "SELECT prod_name FROM articles WHERE article_id = 108775044 LIMIT 10"
Observation: [('Strap top',)]
Thought:I now know the final answer
Final Answer: The name of the article with id 108775044 is 'Strap top'.
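On the re-querying: the memory stores only the conversation text (the executor's input and final answer), not the intermediate query results, so the agent inspects the schema and queries again on each turn. GPTCache is not shown here. As a much simpler, hypothetical alternative, identical SQL strings can be cached in process so a repeated query is not sent to the database twice. The sketch uses sqlite3 and functools.lru_cache from the standard library with a toy articles table; run_query is an illustrative name, not a LangChain tool.

```python
import sqlite3
from functools import lru_cache

# Toy in-memory database standing in for the real one.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE articles (article_id INTEGER, prod_name TEXT)")
conn.executemany(
    "INSERT INTO articles VALUES (?, ?)",
    [(108775015, "Strap top"), (108775044, "Strap top"), (108775051, "Strap top (1)")],
)

executed = []  # records which queries actually reached the database


@lru_cache(maxsize=128)
def run_query(sql: str) -> tuple:
    """Execute `sql` once; identical strings afterwards are served from the cache."""
    executed.append(sql)
    return tuple(conn.execute(sql).fetchall())


q = "SELECT prod_name FROM articles WHERE article_id = 108775044"
print(run_query(q))  # executes against the database
print(run_query(q))  # served from the cache
print(len(executed))
```

This only avoids re-executing identical query strings; it does not stop the agent from calling sql_db_list_tables or sql_db_schema again, and cached rows go stale if the underlying tables change.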
I'm not able to get the chat history to populate. I'm still getting an error: ValueError: Missing some input keys: {'history'}
from sqlalchemy import create_engine
from langchain.llms import OpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.agents.agent_toolkits import create_sql_agent,SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
engine_athena = create_engine(connection_string, echo=False)
memory = ConversationBufferMemory(memory_key = 'history' , input_key = 'input')
_DEFAULT_TEMPLATE = """
Go!
Given an input question
Use the following format:
Question: “Question here”
SQLQuery: “SQL Query to run”
SQLResult: “Result of the SQLQuery”
Answer: “Final answer here”
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)
Question: {input}"""
print(memory)
PROMPT = PromptTemplate(
input_variables=["input","history"], template=_DEFAULT_TEMPLATE
)
db = SQLDatabase(engine_athena,
include_tables=[<table_name>],
sample_rows_in_table_info=2)
llm = OpenAI(temperature=0.4, model_name=<model_name>)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=PROMPT, return_intermediate_steps=True, memory=memory)
result = db_chain(<user_question>)
result["intermediate_steps"]
Error:
Entering new SQLDatabaseChain chain...
what is the total cost of our cloud spend in 2018?
SQLQuery:
ValueError Traceback (most recent call last)
in <cell line: 1>()
----> 1 result = db_chain(
6 frames
/usr/local/lib/python3.10/dist-packages/langchain/chains/base.py in _validate_inputs(self, inputs)
     81         missing_keys = set(self.input_keys).difference(inputs)
     82         if missing_keys:
---> 83             raise ValueError(f"Missing some input keys: {missing_keys}")
     84 
     85     def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
ValueError: Missing some input keys: {'history'}
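@ameerhakme Try creating the SQLDatabaseChain as below. When you create an instance with SQLDatabaseChain.from_llm, the inner LLMChain is instantiated without memory, so the history variable in your prompt template is never filled. Instead, build the LLMChain yourself with the memory and a prompt template that contains the history variable.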
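from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
llm = ChatOpenAI(temperature=0, model=openai_model_name, verbose=verbose)
db = SQLDatabase.from_uri(
    CONN_STRING,
    include_tables=include_tables,
    schema=postgresql_schema,
    sample_rows_in_table_info=3
)
memory = ConversationBufferMemory(input_key='input', memory_key="history")
# Attach memory to the inner LLMChain; `prompt` must contain {history} and {input}
dbchain = SQLDatabaseChain(
    llm_chain=LLMChain(llm=llm, prompt=prompt, memory=memory),
    database=db,
    verbose=verbose
)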
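Why this avoids "Missing some input keys: {'history'}": the traceback shows the check missing_keys = set(self.input_keys).difference(inputs), and it fires in the inner LLMChain, whose prompt declares history. With from_llm(..., memory=memory) the memory is attached to the outer SQLDatabaseChain, which, as far as we can tell from langchain_experimental's source, forwards only input, top_k, dialect and table_info (plus a stop sequence) to the inner chain. Attaching the memory to the LLMChain lets that chain load history itself before validating. The toy model below is a plain-Python simplification; ToyChain is hypothetical and only mirrors the "merge memory, then validate" order.

```python
from typing import Dict, Iterable, Optional


class ToyChain:
    """Hypothetical stand-in for a chain: merge memory, then validate inputs."""

    def __init__(self, input_keys: Iterable[str], memory: Optional[Dict[str, str]] = None):
        self.input_keys = set(input_keys)
        self.memory = memory  # what load_memory_variables({}) would return

    def __call__(self, inputs: Dict[str, str]) -> Dict[str, str]:
        if self.memory is not None:
            inputs = {**inputs, **self.memory}
        missing_keys = self.input_keys.difference(inputs)
        if missing_keys:
            raise ValueError(f"Missing some input keys: {missing_keys}")
        return inputs


# Keys the outer SQL chain forwards to its inner LLM chain (no "history").
forwarded = {"input": "what is the total cost?", "top_k": "5",
             "dialect": "sqlite", "table_info": "CREATE TABLE ..."}
prompt_keys = ["input", "history"]

# Memory on the outer chain only: the inner chain never sees "history".
try:
    ToyChain(prompt_keys)(forwarded)
except ValueError as exc:
    print(exc)  # Missing some input keys: {'history'}

# Memory on the inner chain: "history" is loaded before validation.
inner = ToyChain(prompt_keys, memory={"history": "Human: ...\nAI: ..."})
print(sorted(inner(forwarded)))
```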
@gugupy: Using the above approach, most of the time the query itself fails (with syntax errors), whereas the same thing worked OK when we initialized dbchain = SQLDatabaseChain(llm=llm, ...)
@panyamravi You can avoid syntax errors by improving the prompt, with something like "Do not give invalid SQL queries", or by using the use_query_checker and query_checker_prompt parameters in the chain.
@gugupy: I tried adding the query checker as well, but no luck. If I use SQLDBChain without query_checker set to True, it works OK, but memory is not preserved. Thanks
Can you share your code snippet?
@gugupy: Please see below.
I am connecting to a local SQL Server, and it works OK if I use the dbchain below:
dbchain = SQLDatabaseChain(
#llm_chain=LLMChain(llm=llm, prompt=PROMPT, memory=memory),
llm=llm,
database=db,
verbose=True,
use_query_checker=True
)
However, memory is not retained even in the above case. If I use the llm_chain approach (shown below), the query itself does not work.
import os
from dotenv import load_dotenv
from langchain.prompts.prompt import PromptTemplate
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.agents.agent_toolkits import create_sql_agent,SQLDatabaseToolkit
from langchain.agents import create_csv_agent
from langchain.llms import AzureOpenAI
from langchain.agents import load_tools, initialize_agent
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["OPENAI_API_VERSION"] = "2022-12-01"
os.environ["OPENAI_API_BASE"] = "<API_BASE>"
os.environ["OPENAI_API_KEY"] = "<API_KEY>"
server = 'localhost'
database = 'test'
username = 'testuser'
password = 'testuser'
memory = ConversationBufferMemory(memory_key = 'history' , input_key = 'input')
_DEFAULT_TEMPLATE = """
Given an input question, first create a syntactically correct sql server query to run, then look at the results of the query and return the answer.
Do not give invalid SQL queries.
Use the following format:
Question: “Question here”
SQLQuery: “SQL Query to run”
SQLResult: “Result of the SQLQuery”
Answer: “Final answer here”
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)
Question: {input}"""
print(memory)
PROMPT = PromptTemplate(input_variables=["input","history"], template=_DEFAULT_TEMPLATE)
llm = AzureOpenAI(deployment_name="model-gpt-35",model_name="gpt-35-turbo", temperature=0)
# Create the connection string
conn_str = f"mssql+pymssql://{username}:{password}@{server}:1433/{database}"
# Setup database
db = SQLDatabase.from_uri(
conn_str,
)
dbchain = SQLDatabaseChain(
llm_chain=LLMChain(llm=llm, prompt=PROMPT, memory=memory),
#llm=llm,
database=db,
verbose=True,
use_query_checker=True
)
question = "get me the store which sold more bikes"
result = dbchain.run(question)
print(result)
question = "get the address of the above store"
result = dbchain.run(question)
print(result)
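One thing worth checking in the llm_chain version: _DEFAULT_TEMPLATE references only {history} and {input}, so the model is never shown the table schema, which could plausibly explain the failing SQL. Assuming the chain also fills table_info, dialect and top_k (our reading of SQLDatabaseChain; verify it for your installed version), a custom prompt should reference them too. The hypothetical placeholders helper below lists a template's named fields with the standard library's string.Formatter and reports which expected variables are missing; new_template is an illustrative rewrite, not a prompt shipped with LangChain.

```python
from string import Formatter
from typing import Set


def placeholders(template: str) -> Set[str]:
    """Return the named {fields} referenced by a str.format-style template."""
    return {field for _, field, _, _ in Formatter().parse(template) if field}


# Variables we assume the SQL chain passes to its prompt, plus our memory key.
EXPECTED = {"input", "table_info", "dialect", "top_k", "history"}

old_template = """Relevant pieces of previous conversation:
{history}
Question: {input}"""

new_template = """Given an input question, create a syntactically correct {dialect} query
(at most {top_k} rows), then look at the results and return the answer.
Only use the following tables:
{table_info}

Relevant pieces of previous conversation:
{history}

Question: {input}"""

print(sorted(EXPECTED - placeholders(old_template)))  # ['dialect', 'table_info', 'top_k']
print(sorted(EXPECTED - placeholders(new_template)))  # []
```

A template that passes this check could then be wrapped as PromptTemplate(input_variables=["input", "table_info", "dialect", "top_k", "history"], template=new_template) and handed to the LLMChain that holds the memory.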
@keenborder786: Is the memory issue resolved? If so, has it been committed? I went through the PR and made the needed changes to the base.py file in my local copy of langchain_experimental, as per your PR, but SQL query memory is still not retained.
I have the same issue. I want to use the method from this link:
https://python.langchain.com/docs/use_cases/question_answering/how_to/multiple_retrieval
Is there any documentation on implementing memory in SQLDatabaseChain and using it with multiple sources?
Thank you.
@ameerhakme did you manage to solve the problem? I got the same error message
Is memory available for SQLDatabaseChain? If yes, can someone share the right code snippet?
This works for me.
@gugupy thanks!
llm = ChatOpenAI(temperature=0, model=openai_model_name, verbose=verbose)
db = SQLDatabase.from_uri(
    CONN_STRING,
    include_tables=include_tables,
    schema=postgresql_schema,
    sample_rows_in_table_info=3
)
memory = ConversationBufferMemory(input_key='input', memory_key="history")
dbchain = SQLDatabaseChain(
    llm_chain=LLMChain(llm=llm, prompt=prompt, memory=memory),
    database=db,
    verbose=verbose
)
I found a way that works perfectly. This technique uses ChatPromptTemplate plus a memory object that is not passed to SQLDatabaseChain, so the context is saved manually. Here are the steps:
- Create a ChatPromptTemplate (instead of a PromptTemplate) as follows:
  - system: your usual prompt here
  - MessagesPlaceholder: history
  - human: {input}
- Create a memory object
- Create the prompt value as usual, with the required variables along with history = memory.load_memory_variables({})['history']
- Pass the prompt value to SQLDatabaseChain and get the results
- Save the context in memory with the user's input query and the result from the chain
Code
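from langchain.llms import GooglePalm
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
chat_template = """ Based on the schema given {info} write an executable query for the user input.
Execute it in the database and get sql results. Make a response to user from sql results based on
the question.
Input: "user input"
SQL query: "SQL Query here"
"""
chat_prompt = ChatPromptTemplate.from_messages([
    ('system', chat_template),
    MessagesPlaceholder(variable_name='history'),
    ('human', "{input}")
])
llm = GooglePalm(temperature=0.2)
db = SQLDatabase.from_uri('sqlite:///Chinook.db')
table_info = db.table_info
# Window memory keeps only the last 4 exchanges; it is NOT attached to the chain
m1 = ConversationBufferWindowMemory(k=4,return_messages=True)
db_chain = SQLDatabaseChain.from_llm(llm, db,verbose = True)
while True:
    query = input('human:')
    if query != '':
        # Load past messages and render them into the prompt text
        chat = m1.load_memory_variables({})['history']
        prompt = chat_prompt.format(info=table_info, history=chat, input=query)
        response = db_chain.run(prompt)
        # Save this exchange manually so the next turn can see it
        m1.save_context({'input': query}, {'output': response})
    else:
        break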
Hi @Annamalai-S, thanks, I also got it to retain memory. But how do I integrate this into Gradio? Does that mean removing the while True loop?
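On the Gradio question: the while True loop goes away, because the UI calls a function once per user message. The sketch below moves the loop body into a hypothetical respond(message, history) function, the (message, history) shape that gr.ChatInterface passes to its fn, and keeps its own window of recent turns with a deque, imitating ConversationBufferWindowMemory(k=4). fake_chain is a stand-in for db_chain.run so the sketch runs without a database or LLM, and it deliberately ignores Gradio's own history argument in favor of the window.

```python
from collections import deque
from typing import Callable, Deque, List, Tuple

K = 4  # number of past exchanges to keep, like ConversationBufferWindowMemory(k=4)
window: Deque[Tuple[str, str]] = deque(maxlen=K)


def fake_chain(prompt: str) -> str:
    """Stand-in for db_chain.run; reports how many earlier questions it saw."""
    return f"answer based on {prompt.count('Human:')} previous question(s)"


def build_prompt(question: str) -> str:
    """Prepend the windowed history to the new question."""
    history = "\n".join(f"Human: {q}\nAI: {a}" for q, a in window)
    return f"{history}\nQuestion: {question}" if history else f"Question: {question}"


def respond(message: str, history: List, chain: Callable[[str], str] = fake_chain) -> str:
    """One chat turn: build the prompt, run the chain, save the exchange."""
    answer = chain(build_prompt(message))
    window.append((message, answer))
    return answer


# Possible wiring (assumes gradio is installed):
#   import gradio as gr
#   gr.ChatInterface(fn=respond).launch()
print(respond("How many artists are there?", []))
print(respond("List the first five of them", []))
```

A module-level window is shared by every visitor of the app; for multiple users you would keep one window per session (for example via Gradio's state) rather than a global.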