Problem with sql_chain and quotation marks
Hi,
I observed an issue with sql_chain and quotation marks. The generated SQL that was sent to the DB was wrapped in quotation marks, which triggered an error in the DB.
This is the DB engine:
from sqlalchemy import create_engine
engine = create_engine("sqlite:///:memory:")
The solution is very simple. Just detect and remove quotation marks from the beginning and the end of the generated SQL statement.
What do you think?
PS: I cannot replicate the error at the moment, so I cannot provide a concrete error message. Sorry.
PPS: see code to reproduce and error message below
Here is the code to reproduce. It is taken from https://github.com/pinecone-io/examples/blob/master/generation/langchain/handbook/06-langchain-agents.ipynb
from getpass import getpass
OPENAI_API_KEY = getpass()
##
from langchain import OpenAI
llm = OpenAI(
    openai_api_key=OPENAI_API_KEY,
    temperature=0
)
##
from langchain.callbacks import get_openai_callback
def count_tokens(agent, query):
    with get_openai_callback() as cb:
        result = agent(query)
        print(f'Spent a total of {cb.total_tokens} tokens')
    return result
##
from sqlalchemy import MetaData
metadata_obj = MetaData()
##
from sqlalchemy import Column, Integer, String, Table, Date, Float
stocks = Table(
    "stocks",
    metadata_obj,
    Column("obs_id", Integer, primary_key=True),
    Column("stock_ticker", String(4), nullable=False),
    Column("price", Float, nullable=False),
    Column("date", Date, nullable=False),
)
##
from sqlalchemy import create_engine
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)
##
from datetime import datetime
observations = [
    [1, 'ABC', 200, datetime(2023, 1, 1)],
    [2, 'ABC', 208, datetime(2023, 1, 2)],
    [3, 'ABC', 232, datetime(2023, 1, 3)],
    [4, 'ABC', 225, datetime(2023, 1, 4)],
    [5, 'ABC', 226, datetime(2023, 1, 5)],
    [6, 'XYZ', 810, datetime(2023, 1, 1)],
    [7, 'XYZ', 803, datetime(2023, 1, 2)],
    [8, 'XYZ', 798, datetime(2023, 1, 3)],
    [9, 'XYZ', 795, datetime(2023, 1, 4)],
    [10, 'XYZ', 791, datetime(2023, 1, 5)],
]
##
from sqlalchemy import insert
def insert_obs(obs):
    stmt = insert(stocks).values(
        obs_id=obs[0],
        stock_ticker=obs[1],
        price=obs[2],
        date=obs[3]
    )
    with engine.begin() as conn:
        conn.execute(stmt)
##
for obs in observations:
    insert_obs(obs)
##
from langchain.sql_database import SQLDatabase
from langchain.chains import SQLDatabaseChain
db = SQLDatabase(engine)
sql_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)
##
from langchain.agents import Tool
sql_tool = Tool(
    name='Stock DB',
    func=sql_chain.run,
    description="Useful for when you need to answer questions about stocks " \
                "and their prices."
)
##
from langchain.agents import load_tools
tools = load_tools(
    ["llm-math"],
    llm=llm
)
##
tools.append(sql_tool)
##
from langchain.agents import initialize_agent
zero_shot_agent = initialize_agent(
    agent="zero-shot-react-description",
    tools=tools,
    llm=llm,
    verbose=True,
    max_iterations=3,
)
##
result = count_tokens(
    zero_shot_agent,
    "What is the multiplication of the ratio between stock prices for 'ABC' and 'XYZ' in January 3rd and the ratio between the same stock prices in January the 4th?"
)
The error message is:
(langchain_dev) mike@MacBook-Air-Philip:~/code/git/langchain$ /Users/mike/miniconda3/envs/langchain_dev/bin/python /Users/mike/code/git/langchain/own_tests/i18n_nc.py
Password:
> Entering new AgentExecutor chain...
I need to compare the stock prices of 'ABC' and 'XYZ' on two different days
Action: Stock DB
Action Input: Stock prices of 'ABC' and 'XYZ' on January 3rd and January 4th
> Entering new SQLDatabaseChain chain...
Stock prices of 'ABC' and 'XYZ' on January 3rd and January 4th
SQLQuery: "SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"
Traceback (most recent call last):
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1900, in _execute_context
self.dialect.do_execute(
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 736, in do_execute
cursor.execute(statement, parameters)
sqlite3.OperationalError: near ""SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"": syntax error
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/mike/code/git/langchain/own_tests/i18n_nc.py", line 121, in <module>
result = count_tokens(
File "/Users/mike/code/git/langchain/own_tests/i18n_nc.py", line 17, in count_tokens
result = agent(query)
File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 116, in __call__
raise e
File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 113, in __call__
outputs = self._call(inputs)
File "/Users/mike/code/git/langchain/langchain/agents/agent.py", line 792, in _call
next_step_output = self._take_next_step(
File "/Users/mike/code/git/langchain/langchain/agents/agent.py", line 695, in _take_next_step
observation = tool.run(
File "/Users/mike/code/git/langchain/langchain/tools/base.py", line 107, in run
raise e
File "/Users/mike/code/git/langchain/langchain/tools/base.py", line 104, in run
observation = self._run(*tool_args, **tool_kwargs)
File "/Users/mike/code/git/langchain/langchain/agents/tools.py", line 31, in _run
return self.func(*args, **kwargs)
File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 213, in run
return self(args[0])[self.output_keys[0]]
File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 116, in __call__
raise e
File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 113, in __call__
outputs = self._call(inputs)
File "/Users/mike/code/git/langchain/langchain/chains/sql_database/base.py", line 86, in _call
result = self.database.run(sql_cmd)
File "/Users/mike/code/git/langchain/langchain/sql_database.py", line 220, in run
cursor = connection.execute(text(command))
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1380, in execute
return meth(self, multiparams, params, _EMPTY_EXECUTION_OPTS)
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/sql/elements.py", line 334, in _execute_on_connection
return connection._execute_clauseelement(
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1572, in _execute_clauseelement
ret = self._execute_context(
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1943, in _execute_context
self._handle_dbapi_exception(
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 2124, in _handle_dbapi_exception
util.raise_(
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/util/compat.py", line 211, in raise_
raise exception
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1900, in _execute_context
self.dialect.do_execute(
File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 736, in do_execute
cursor.execute(statement, parameters)
sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near ""SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"": syntax error
[SQL: "SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"]
(Background on this error at: https://sqlalche.me/e/14/e3q8)
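The failure can probably be reproduced without the LLM or LangChain at all. The following is a minimal sketch using only the standard-library sqlite3 module; the table contents and the variable names (plain_sql, quoted_sql, quoted_failed, error_text) are made up for illustration. It runs the same statement once as-is and once wrapped in double quotes, the way the chain sent it:

```python
import sqlite3

# In-memory database with a simplified version of the stocks table.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE stocks (obs_id INTEGER PRIMARY KEY, stock_ticker TEXT, price REAL)"
)
conn.execute("INSERT INTO stocks VALUES (1, 'ABC', 200.0)")

plain_sql = "SELECT stock_ticker, price FROM stocks LIMIT 5"
# Same statement, but wrapped in double quotes like the generated SQLQuery.
quoted_sql = f'"{plain_sql}"'

# The unquoted statement runs normally.
rows = conn.execute(plain_sql).fetchall()

# The quoted statement is rejected by SQLite before it is ever executed.
try:
    conn.execute(quoted_sql)
    quoted_failed = False
    error_text = ""
except sqlite3.OperationalError as exc:
    quoted_failed = True
    error_text = str(exc)

print(rows)
print(quoted_failed, error_text)
```

SQLite treats the whole double-quoted string as a single token rather than as a statement, which is consistent with the "syntax error" in the traceback above.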
IMO after this line:
https://github.com/hwchase17/langchain/blob/acfd11c8e424a456227abde8df8b52a705b63024/langchain/chains/sql_database/base.py#L83
We should add a normalization step that:
- calls strip() on the SQL string
- checks whether quotation marks are at the beginning and the end, and removes them if so.
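A minimal sketch of what such a normalization could look like. The function name normalize_sql is hypothetical and this is not the code from the PR; it only illustrates the two steps listed above, assuming that a single pair of matching outer quotes is all that needs to go:

```python
def normalize_sql(sql_cmd: str) -> str:
    """Strip whitespace and at most one pair of matching outer quotes."""
    cleaned = sql_cmd.strip()
    # Unwrap only when the same quote character both opens and closes
    # the string, so quotes inside the statement stay untouched.
    if (
        len(cleaned) >= 2
        and cleaned[0] == cleaned[-1]
        and cleaned[0] in ('"', "'", "`")
    ):
        cleaned = cleaned[1:-1].strip()
    return cleaned


wrapped = ' "SELECT price FROM stocks WHERE stock_ticker = \'ABC\' LIMIT 5" '
print(normalize_sql(wrapped))
```

Because only the outermost pair is removed, the single-quoted literal 'ABC' inside the statement is preserved, and a statement that is not wrapped at all passes through unchanged.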
Started a PR: https://github.com/hwchase17/langchain/pull/3385
For what it's worth, adding this to the default prompt template for a database chain worked:
Only have single quotes on any sql command sent to the engine.
Hope it's of some help.
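As a rough illustration of that workaround, here is a sketch using plain string formatting. BASE_TEMPLATE is a made-up stand-in, not LangChain's actual default prompt, and build_prompt is a hypothetical helper; only the quoted instruction sentence comes from the comment above:

```python
# Hypothetical base template; the real default prompt is worded differently.
BASE_TEMPLATE = (
    "Given an input question, create a syntactically correct {dialect} query to run.\n"
    "Question: {input}\n"
    "SQLQuery:"
)

# The extra instruction suggested in this thread.
QUOTE_RULE = "Only have single quotes on any sql command sent to the engine.\n"


def build_prompt(dialect: str, question: str) -> str:
    """Prepend the quoting rule to the (assumed) base template."""
    return QUOTE_RULE + BASE_TEMPLATE.format(dialect=dialect, input=question)


prompt = build_prompt("sqlite", "What was the price of 'ABC' on January 3rd?")
print(prompt)
```

This approach relies on the model following the instruction, so it may be less robust than stripping the quotes in code.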
Duplicate of #2027
I'm not sure, but to me the problem is that the prompt template is wrong in asking for an SQL-statement with quotes. I think it is better to just not do that: https://github.com/hwchase17/langchain/pull/4101
Hi, @PhilipMay! I'm Dosu, and I'm helping the LangChain team manage their backlog. I wanted to let you know that we are marking this issue as stale.
From what I understand, the issue is with the sql_chain function adding quotation marks around the generated SQL statement, which causes a syntax error in the database. You have started a PR to address this issue, and other users have suggested solutions as well. isdsava suggests adding a default prompt template that only allows single quotes in SQL commands, while hansvdam suggests modifying the prompt template to not ask for an SQL statement with quotes.
Before we close this issue, we wanted to check if it is still relevant to the latest version of the LangChain repository. If it is, please let us know by commenting on the issue. Otherwise, feel free to close the issue yourself, or it will be automatically closed in 7 days.
Thank you for your contribution and understanding!
Facing the same issue. Also, mixed-case column names are a problem on OracleDB.