Problem with sql_chain and quotation marks

Open PhilipMay opened this issue 3 years ago • 6 comments

Hi,

I observed an issue with sql_chain and quotation marks. The SQL that was sent had quotation marks around it and triggered an error in the DB.

This is the DB engine:

from sqlalchemy import create_engine
engine = create_engine("sqlite:///:memory:")

The solution is very simple: just detect and remove quotation marks at the beginning and the end of the generated SQL statement.

What do you think?

PS: I cannot replicate the error at the moment, so I cannot provide a concrete error message. Sorry.

PPS: See the code to reproduce the error and the error message below.

PhilipMay, Apr 21 '23

Here is the code to reproduce. It is taken from https://github.com/pinecone-io/examples/blob/master/generation/langchain/handbook/06-langchain-agents.ipynb

from getpass import getpass
OPENAI_API_KEY = getpass()

##

from langchain import OpenAI
llm = OpenAI(
    openai_api_key=OPENAI_API_KEY,
    temperature=0
)

##

from langchain.callbacks import get_openai_callback
def count_tokens(agent, query):
    with get_openai_callback() as cb:
        result = agent(query)
        print(f'Spent a total of {cb.total_tokens} tokens')

    return result

##

from sqlalchemy import MetaData
metadata_obj = MetaData()

##

from sqlalchemy import Column, Integer, String, Table, Date, Float
stocks = Table(
    "stocks",
    metadata_obj,
    Column("obs_id", Integer, primary_key=True),
    Column("stock_ticker", String(4), nullable=False),
    Column("price", Float, nullable=False),
    Column("date", Date, nullable=False),
)

##

from sqlalchemy import create_engine
engine = create_engine("sqlite:///:memory:")
metadata_obj.create_all(engine)

##

from datetime import datetime
observations = [
    [1, 'ABC', 200, datetime(2023, 1, 1)],
    [2, 'ABC', 208, datetime(2023, 1, 2)],
    [3, 'ABC', 232, datetime(2023, 1, 3)],
    [4, 'ABC', 225, datetime(2023, 1, 4)],
    [5, 'ABC', 226, datetime(2023, 1, 5)],
    [6, 'XYZ', 810, datetime(2023, 1, 1)],
    [7, 'XYZ', 803, datetime(2023, 1, 2)],
    [8, 'XYZ', 798, datetime(2023, 1, 3)],
    [9, 'XYZ', 795, datetime(2023, 1, 4)],
    [10, 'XYZ', 791, datetime(2023, 1, 5)],
]

##

from sqlalchemy import insert
def insert_obs(obs):
    stmt = insert(stocks).values(
        obs_id=obs[0],
        stock_ticker=obs[1],
        price=obs[2],
        date=obs[3],
    )
    with engine.begin() as conn:
        conn.execute(stmt)

##

for obs in observations:
    insert_obs(obs)

##

from langchain.sql_database import SQLDatabase
from langchain.chains import SQLDatabaseChain
db = SQLDatabase(engine)
sql_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

##

from langchain.agents import Tool
sql_tool = Tool(
    name='Stock DB',
    func=sql_chain.run,
    description="Useful for when you need to answer questions about stocks "
                "and their prices."
)

##

from langchain.agents import load_tools
tools = load_tools(
    ["llm-math"],
    llm=llm
)

##

tools.append(sql_tool)

##

from langchain.agents import initialize_agent
zero_shot_agent = initialize_agent(
    agent="zero-shot-react-description",
    tools=tools,
    llm=llm,
    verbose=True,
    max_iterations=3,
)

##

result = count_tokens(
    zero_shot_agent,
    "What is the multiplication of the ratio between stock prices for 'ABC' and 'XYZ' in January 3rd and the ratio between the same stock prices in January the 4th?"
)

The error message is:

(langchain_dev) mike@MacBook-Air-Philip:~/code/git/langchain$ /Users/mike/miniconda3/envs/langchain_dev/bin/python /Users/mike/code/git/langchain/own_tests/i18n_nc.py
Password: 


> Entering new AgentExecutor chain...
 I need to compare the stock prices of 'ABC' and 'XYZ' on two different days
Action: Stock DB
Action Input: Stock prices of 'ABC' and 'XYZ' on January 3rd and January 4th

> Entering new SQLDatabaseChain chain...
Stock prices of 'ABC' and 'XYZ' on January 3rd and January 4th
SQLQuery: "SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"
Traceback (most recent call last):
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1900, in _execute_context
    self.dialect.do_execute(
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 736, in do_execute
    cursor.execute(statement, parameters)
sqlite3.OperationalError: near ""SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"": syntax error

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/mike/code/git/langchain/own_tests/i18n_nc.py", line 121, in <module>
    result = count_tokens(
  File "/Users/mike/code/git/langchain/own_tests/i18n_nc.py", line 17, in count_tokens
    result = agent(query)
  File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 116, in __call__
    raise e
  File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 113, in __call__
    outputs = self._call(inputs)
  File "/Users/mike/code/git/langchain/langchain/agents/agent.py", line 792, in _call
    next_step_output = self._take_next_step(
  File "/Users/mike/code/git/langchain/langchain/agents/agent.py", line 695, in _take_next_step
    observation = tool.run(
  File "/Users/mike/code/git/langchain/langchain/tools/base.py", line 107, in run
    raise e
  File "/Users/mike/code/git/langchain/langchain/tools/base.py", line 104, in run
    observation = self._run(*tool_args, **tool_kwargs)
  File "/Users/mike/code/git/langchain/langchain/agents/tools.py", line 31, in _run
    return self.func(*args, **kwargs)
  File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 213, in run
    return self(args[0])[self.output_keys[0]]
  File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 116, in __call__
    raise e
  File "/Users/mike/code/git/langchain/langchain/chains/base.py", line 113, in __call__
    outputs = self._call(inputs)
  File "/Users/mike/code/git/langchain/langchain/chains/sql_database/base.py", line 86, in _call
    result = self.database.run(sql_cmd)
  File "/Users/mike/code/git/langchain/langchain/sql_database.py", line 220, in run
    cursor = connection.execute(text(command))
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1380, in execute
    return meth(self, multiparams, params, _EMPTY_EXECUTION_OPTS)
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/sql/elements.py", line 334, in _execute_on_connection
    return connection._execute_clauseelement(
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1572, in _execute_clauseelement
    ret = self._execute_context(
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1943, in _execute_context
    self._handle_dbapi_exception(
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 2124, in _handle_dbapi_exception
    util.raise_(
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/util/compat.py", line 211, in raise_
    raise exception
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/base.py", line 1900, in _execute_context
    self.dialect.do_execute(
  File "/Users/mike/miniconda3/envs/langchain_dev/lib/python3.9/site-packages/sqlalchemy/engine/default.py", line 736, in do_execute
    cursor.execute(statement, parameters)
sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near ""SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"": syntax error
[SQL:  "SELECT stock_ticker, price, date FROM stocks WHERE (stock_ticker = 'ABC' OR stock_ticker = 'XYZ') AND (date = '2023-01-03' OR date = '2023-01-04') LIMIT 5"]
(Background on this error at: https://sqlalche.me/e/14/e3q8)
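The failure does not need LangChain or an LLM to reproduce. As a minimal sketch using only Python's built-in `sqlite3` module (the table and query are cut-down stand-ins for the ones above), executing a statement wrapped in double quotes gives the same kind of `near ""SELECT ..."": syntax error`: SQLite treats the whole double-quoted text as a single quoted-identifier token, so it never sees a `SELECT` keyword. The same statement without the wrapping quotes runs normally.

```python
import sqlite3

# In-memory database with a cut-down version of the "stocks" table above.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE stocks (stock_ticker TEXT, price REAL, date TEXT)")
conn.execute("INSERT INTO stocks VALUES ('ABC', 232, '2023-01-03')")

# Valid SQL, then the same SQL wrapped in double quotes as in the log above.
sql = "SELECT stock_ticker, price FROM stocks WHERE stock_ticker = 'ABC'"
quoted_sql = f'"{sql}"'

error_message = None
try:
    conn.execute(quoted_sql)
except sqlite3.OperationalError as exc:
    # SQLite sees one double-quoted token where a statement should start.
    error_message = str(exc)
print(error_message)

# Without the wrapping quotes the same statement runs fine.
rows = conn.execute(sql).fetchall()
print(rows)
```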

PhilipMay, Apr 23 '23

IMO after this line:

https://github.com/hwchase17/langchain/blob/acfd11c8e424a456227abde8df8b52a705b63024/langchain/chains/sql_database/base.py#L83

We should add a normalization that:

  1. calls strip() on the SQL string
  2. checks whether quotation marks are at the beginning and the end and, if so, removes them.
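The two normalization steps could look roughly like the sketch below. `normalize_sql` is a hypothetical helper, not an existing LangChain function, and treating both double and single quotes as possible wrappers is an assumption (the log above only shows double quotes). The key design choice is to remove a quote pair only when the same character is both first and last, rather than calling `sql.strip('"')`, which would also eat a legitimate closing quote in a statement such as `SELECT * FROM "Stocks"` (quoted identifiers matter, for example, for mixed-case column names).

```python
import sqlite3

# Assumption: characters an LLM might wrap around the whole statement.
WRAPPING_QUOTES = ('"', "'")


def normalize_sql(sql: str) -> str:
    """Strip whitespace, then remove one pair of matching wrapping quotes."""
    # Step 1: strip() the SQL string.
    sql = sql.strip()
    # Step 2: only if the SAME quote character is first and last, drop that
    # single pair. A plain sql.strip('"') would also remove the closing quote
    # in a statement like: SELECT * FROM "Stocks"
    if len(sql) >= 2 and sql[0] == sql[-1] and sql[0] in WRAPPING_QUOTES:
        sql = sql[1:-1].strip()
    return sql


# Usage: a quoted statement like the one in the log now runs against SQLite.
conn = sqlite3.connect(":memory:")
result = conn.execute(normalize_sql(' "SELECT 1" \n')).fetchone()
print(result)
```

In the chain itself this would amount to something like `sql_cmd = normalize_sql(sql_cmd)` right before `self.database.run(sql_cmd)`; the name `sql_cmd` is taken from the traceback above, and the exact placement depends on the code at the linked line.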

PhilipMay, Apr 23 '23

Started a PR: https://github.com/hwchase17/langchain/pull/3385

PhilipMay, Apr 23 '23

For what it's worth, adding this to the default prompt template for a database chain worked: "Only have single quotes on any sql command sent to the engine." Hope it's of some help.

isdsava, May 03 '23

Duplicate of #2027

hansvdam, May 04 '23

I'm not sure, but to me the problem is that the prompt template wrongly asks for an SQL statement in quotes. I think it is better to just not do that: https://github.com/hwchase17/langchain/pull/4101
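Both prompt-side suggestions in this thread can be sketched with plain strings. Assuming, as hansvdam suggests, that the template's format section showed the query in quotes, a model that imitates that example will wrap its SQL the same way. The wording and the `build_prompt` helper below are illustrative assumptions, not the actual LangChain template; the unquoted variant also appends isdsava's instruction as an explicit rule.

```python
# Illustrative format section with quoted placeholders (assumed wording).
# A model copying this example tends to emit: SQLQuery: "SELECT ..."
QUOTED_FORMAT = """Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
"""

# Same structure without quotes, plus isdsava's instruction as a rule.
UNQUOTED_FORMAT = """Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only have single quotes on any sql command sent to the engine.
"""


def build_prompt(format_section: str, dialect: str, table_info: str, question: str) -> str:
    """Assemble a minimal text-to-SQL prompt (hypothetical helper)."""
    return (
        f"Given an input question, create a syntactically correct {dialect} query to run.\n\n"
        f"{format_section}\n"
        f"Only use the following tables:\n{table_info}\n\n"
        f"Question: {question}\nSQLQuery:"
    )


prompt = build_prompt(
    UNQUOTED_FORMAT,
    dialect="sqlite",
    table_info="CREATE TABLE stocks (obs_id INTEGER, stock_ticker VARCHAR(4), price FLOAT, date DATE)",
    question="Stock prices of 'ABC' and 'XYZ' on January 3rd",
)
print(prompt)
```

The prompt ends with `SQLQuery:` so that the model's completion starts directly with the statement; a prompt change like this lowers the chance of quoted output but does not guarantee it, which is why a normalization step on the returned SQL can still be useful.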

hansvdam, May 04 '23

Hi, @PhilipMay! I'm Dosu, and I'm helping the LangChain team manage their backlog. I wanted to let you know that we are marking this issue as stale.

From what I understand, the issue is that the SQL statement generated by sql_chain is wrapped in quotation marks, which causes a syntax error in the database. You have started a PR to address this issue, and other users have suggested solutions as well: isdsava suggests adding an instruction to the default prompt template so that SQL commands only use single quotes, while hansvdam suggests changing the prompt template so that it no longer asks for an SQL statement in quotes.

Before we close this issue, we wanted to check if it is still relevant to the latest version of the LangChain repository. If it is, please let us know by commenting on the issue. Otherwise, feel free to close the issue yourself, or it will be automatically closed in 7 days.

Thank you for your contribution and understanding!

dosubot[bot], Sep 17 '23

Facing the same issue. Also, mixed-case columns are a problem on OracleDB.

pvangara, Oct 11 '23