langchain icon indicating copy to clipboard operation
langchain copied to clipboard

pref: reduce DB query error rate

Open Undertone0809 opened this issue 1 year ago • 1 comments

Reduce DB query error rate

If you use sql agent of SQLDatabaseToolkit to query data, it is prone to errors in query fields and often uses fields that do not exist in database tables for queries. However, the existing prompt does not effectively make the agent aware that there are problems with the fields they query. At this time, we urgently need to improve the prompt so that the agent realizes that they have queried non-existent fields and allows them to use the schema_sql_db, that is, ListSQLDatabaseTool first queries the corresponding fields in the table in the database, and then uses QuerySQLDatabaseTool for querying.

There is a demo of my project to show this problem.

Original Agent

def create_mysql_kit():
    db = SQLDatabase.from_uri("mysql+pymysql://xxxxxxx")
    llm = OpenAI(temperature=0)

    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    agent_executor = create_sql_agent(
        llm=OpenAI(temperature=0),
        toolkit=toolkit,
        verbose=True
    )
    agent_executor.run("Who are the users of sysuser in this system? Tell me the username of all users")


if __name__ == '__main__':
    create_mysql_kit()

original output

> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input: ""
Observation: app_sysrole_menus, app_bimfacemodel, app_project_users, app_measuringpointdata, auth_user, auth_user_groups, django_apscheduler_djangojobexecution, app_project, app_elementpoint, django_apscheduler_djangojob, django_content_type, app_sysrole, django_admin_log, app_bimfaceaccount, app_measuringpoint_warning_thresholds, app_measuringpoint, app_company, auth_group_permissions, app_sysuser, app_sysuser_companies, app_sysmenu, app_datawarningthreshold, auth_group, auth_permission, app_datawarningrecord, auth_user_user_permissions, app_bimfaceaccount_bimface_models, django_migrations, app_measuringitem, django_session
Thought: I should query the app_sysuser table to get the users in the system.
Action: query_sql_db
Action Input: SELECT username FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT username FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT username FROM app_sysuser LIMIT 10;
Observation: 

SELECT username FROM app_sysuser LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT username FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT username FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT name FROM app_sysuser LIMIT 10;
Observation: 

SELECT name FROM app_sysuser LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT name FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'name' in 'field list'")
[SQL: SELECT name FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT id, username FROM app_sysuser LIMIT 10;
Observation: 

SELECT id, username FROM app_sysuser LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT id, username FROM app_sysuser LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'id' in 'field list'")
[SQL: SELECT id, username FROM app_sysuser LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT user_id, username FROM app_sysuser LIMIT 10;
Observation: 

SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT user_id, username FROM app_sysuser WHERE user_id IS NOT NULL AND user_id BETWEEN 0 AND 10 UNION ALL SELECT user_id, username FROM app_sysuser WHERE username IS NOT NULL ORDER BY user_id ASC LIMIT 10;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should double check my query before executing it.
Action: query_checker_sql_db
Action Input: SELECT user_id, user_name FROM app_sysuser LIMIT 10;
Observation: 

SELECT user_id, user_name FROM app_sysuser WHERE user_id NOT IN (NULL) AND user_name BETWEEN 'A' AND 'Z' UNION ALL SELECT user_id, user_name FROM app_sysuser WHERE user_id IS NOT NULL AND user_name LIKE '%a%' LIMIT 10;
Thought: The query looks correct, so I should execute it.
Action: query_sql_db
Action Input: SELECT user_id, user_name FROM app_sysuser WHERE user_id NOT IN (NULL) AND user_name BETWEEN 'A' AND 'Z' UNION ALL SELECT user_id, user_name FROM app_sysuser WHERE user_id IS NOT NULL AND user_name LIKE '%a%' LIMIT 10;
Observation: []
Thought: I now know the final answer.
Final Answer: There are no users in the sysuser table.

> Finished chain.

Process finished with exit code 0

As you can see, agent always use incorrect sql to query the data. Actually, the username field name of my app_sysuser table is user_name, not the username believed by the agent.

Optimized agent output

> Entering new AgentExecutor chain...
Action: list_tables_sql_db
Action Input: ""
Observation: app_project_users, django_content_type, app_measuringpoint_warning_thresholds, app_bimfaceaccount_bimface_models, django_session, app_sysuser, django_migrations, app_bimfacemodel, app_elementpoint, app_measuringpoint, django_apscheduler_djangojobexecution, auth_permission, app_sysuser_companies, app_measuringpointdata, app_measuringitem, app_sysrole_menus, django_admin_log, auth_group, django_apscheduler_djangojob, app_sysmenu, app_project, app_bimfaceaccount, app_datawarningthreshold, app_datawarningrecord, auth_user_groups, auth_user_user_permissions, auth_group_permissions, auth_user, app_company, app_sysrole
Thought: I should query the app_sysuser table to get the usernames of all the users.
Action: query_sql_db
Action Input: SELECT username FROM app_sysuser
Observation: Error: (pymysql.err.OperationalError) (1054, "Unknown column 'username' in 'field list'")
[SQL: SELECT username FROM app_sysuser]
(Background on this error at: https://sqlalche.me/e/20/e3q8)
Thought: I should use the schema_sql_db tool to get the correct table fields.
Action: schema_sql_db
Action Input: app_sysuser
Observation: 
CREATE TABLE app_sysuser (
	user_id INTEGER NOT NULL AUTO_INCREMENT, 
	user_account VARCHAR(50) COLLATE utf8mb4_general_ci NOT NULL, 
	user_password VARCHAR(255) COLLATE utf8mb4_general_ci NOT NULL, 
	user_name VARCHAR(50) COLLATE utf8mb4_general_ci NOT NULL, 
	user_avatar VARCHAR(255) COLLATE utf8mb4_general_ci, 
	user_email VARCHAR(254) COLLATE utf8mb4_general_ci, 
	user_phone VARCHAR(20) COLLATE utf8mb4_general_ci, 
	user_create_time DATETIME(6) NOT NULL, 
	user_update_time DATETIME(6) NOT NULL, 
	user_status SMALLINT, 
	role_id_id INTEGER, 
	PRIMARY KEY (user_id), 
	CONSTRAINT app_sysuser_role_id_id_51b50f39_fk_app_sysrole_role_id FOREIGN KEY(role_id_id) REFERENCES app_sysrole (role_id)
)COLLATE utf8mb4_general_ci DEFAULT CHARSET=utf8mb4 ENGINE=InnoDB

/*
3 rows from app_sysuser table:
user_id	user_account	user_password	user_name	user_avatar	user_email	user_phone	user_create_time	user_update_time	user_status	role_id_id
xxxxxxxxxxxxxx
*/
Thought: I should query the app_sysuser table to get the usernames of all the users.
Action: query_sql_db
Action Input: SELECT user_account FROM app_sysuser LIMIT 10
Observation: [('baiyun',), ('eatrice',), ('lisi',), ('pingxiang',), ('wangwu',), ('zeeland',), ('zsj',), ('zzw',)]
Thought: I now know the final answer
Final Answer: The usernames of the users in the sysuser table are baiyun, eatrice, lisi, pingxiang, wangwu, zeeland, zsj, and zzw.

> Finished chain.

Process finished with exit code 0

I have tested about 10 related prompts and they all work properly, with a much lower error rate compared to before

Who can review?

@vowelparrot

Undertone0809 avatar May 27 '23 11:05 Undertone0809

@vowelparrot plz review

Undertone0809 avatar May 27 '23 11:05 Undertone0809

@hwchase17 Actually, schema_sql_db and the other three are only used by the SQLDatabaseToolkit. It is used together. In addition, the other three tools all have mutual references. Eg InfoSQLDatabaseTool description and QueryCheckerTool description

InfoSQLDatabaseTool reference list_tables_sql_db. QueryCheckerTool reference query_sql_db.

Undertone0809 avatar May 28 '23 06:05 Undertone0809

I think it's a good solution right now. It can well improve the operation effect of agent, the previous operation effect is not so good. Perhaps we should try to build some kind of abstraction to make toolkit more compatible with different tools.

Undertone0809 avatar May 29 '23 13:05 Undertone0809

I think it's a good solution right now. It can well improve the operation effect of agent, the previous operation effect is not so good. Perhaps we should try to build some kind of abstraction to make toolkit more compatible with different tools.

i get that its currently only used in toolkit, but theres nothign to garuntee that thats the case

What I would do is: (1) keep this description as is by default (it is a good default for isolation, which is what defaults should be), (2) pass in the updated description when initializing inside the toolkit, eg QuerySQLDataBaseTool(db=self.db, description="...") in agent_toolkits/sql/toolkit.py

hwchase17 avatar May 29 '23 13:05 hwchase17

Thanks for your suggestion! Passing parameter is a better solution. What do you think about new commit?

Undertone0809 avatar May 29 '23 14:05 Undertone0809

I don't know how to solve this github action problem. Can you give me some suggestion?

Undertone0809 avatar May 29 '23 15:05 Undertone0809

@hwchase17 plz review

Undertone0809 avatar May 31 '23 06:05 Undertone0809