test-suite-sql-eval
Bug with table aliases
Currently, the mapping between aliases and actual table names is built globally for the whole SQL statement, which is not correct for more complicated queries.
Let's consider the following SQL, taken directly from SPIDER (there are more queries like this):
SELECT T1.asset_id , T1.asset_details FROM Assets AS T1 JOIN Asset_Parts AS T2 ON T1.asset_id = T2.asset_id GROUP BY T1.asset_id HAVING count(*) = 2
INTERSECT
SELECT T1.asset_id , T1.asset_details FROM Assets AS T1 JOIN Fault_Log AS T2 ON T1.asset_id = T2.asset_id GROUP BY T1.asset_id HAVING count(*) < 2
The first subquery maps T2 to Asset_Parts, but the second subquery maps T2 to Fault_Log. The current evaluation script cannot handle this. It may let false positives pass, and it may throw exceptions because the script searches for a column in the wrong table.
Relevant code fragment: https://github.com/taoyds/test-suite-sql-eval/blob/e97acc546ecbee8fa27fa8dbf025ef61493a876c/process_sql.py#L150-L156
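To make the failure mode concrete, here is a minimal sketch. It does not reproduce the parser in process_sql.py; `alias_map` and its regex are hypothetical helpers that only recognize the `table AS alias` form. It shows how a single map built over the whole statement lets the last binding of T2 overwrite the earlier one, while one map per subquery keeps both.

```python
import re

# Hypothetical helper for illustration only; not the logic used in process_sql.py.
# It recognizes just the "<table> AS <alias>" form.
ALIAS_RE = re.compile(r"\b(\w+)\s+AS\s+(\w+)", re.IGNORECASE)


def alias_map(sql: str) -> dict[str, str]:
    """Map each alias to its table name; a later binding overwrites an earlier one."""
    return {alias.lower(): table.lower() for table, alias in ALIAS_RE.findall(sql)}


first = (
    "SELECT T1.asset_id FROM Assets AS T1 "
    "JOIN Asset_Parts AS T2 ON T1.asset_id = T2.asset_id"
)
second = (
    "SELECT T1.asset_id FROM Assets AS T1 "
    "JOIN Fault_Log AS T2 ON T1.asset_id = T2.asset_id"
)

# One map for the whole statement: the binding T2 -> Asset_Parts is lost.
global_map = alias_map(first + " INTERSECT " + second)

# One map per subquery: each T2 keeps its own meaning.
scoped_maps = [alias_map(first), alias_map(second)]

print(global_map)
print(scoped_maps)
```

With the global map, any column referenced as T2.x in the first subquery would be looked up in fault_log instead of asset_parts, which is the kind of wrong-table lookup described above.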
Hi, I just wrote a table alias expander with sqlglot to eliminate all table aliases in the SQL before feeding it into this test suite for EM.
import sqlglot
from sqlglot import exp

def expand_alias(expr: exp.Expression) -> exp.Expression:
    # Expand each side of a set operation separately, so that every
    # subquery resolves its aliases in its own scope.
    for iue in expr.find_all(exp.Union):  # match all INTERSECT, UNION, EXCEPT
        assert isinstance(iue.left, exp.Expression) and isinstance(iue.right, exp.Expression)
        iue.set("left", expand_alias(iue.left))
        iue.set("right", expand_alias(iue.right))

    # Collect alias -> table name and strip the alias from each table.
    alias_to_table: dict[str, str] = {}
    for table in expr.find_all(exp.Table):
        if not table.alias:
            continue
        alias = table.alias.lower()
        tablename = table.name.lower()
        alias_to_table[alias] = tablename
        table.set("this", tablename)
        table.set("alias", None)

    # Rewrite column qualifiers that still refer to an alias.
    for column in expr.find_all(exp.Column):
        if column.table:
            column_tablename = column.table.lower()
            if column_tablename in alias_to_table:
                actual_table = alias_to_table[column_tablename]
                column.set("table", actual_table)
    return expr

sql = """
select T1.aaa from table1 as T1 join table2 as T2 on T1.aaa = T2.aaa
intersect
select T1.aaa from table1 as T1 join table3 as T2 on T1.aaa = T2.aaa
"""
expr = sqlglot.parse_one(sql)
refined_sql = expand_alias(expr).sql()
print(refined_sql)
# OUTPUT:
# SELECT table1.aaa FROM table1 JOIN table2 ON table1.aaa = table2.aaa
# INTERSECT
# SELECT table1.aaa FROM table1 JOIN table3 ON table1.aaa = table3.aaa
NOTES:
- sqlglot prefers <> over !=, but this repo prefers !=, so you would probably want p_str_new = expand_alias(sqlglot.parse_one(p_str)).sql().replace(" <> ", " != ")
- sqlglot prefers WHERE NOT aaa IN bbb, but this repo only supports WHERE aaa NOT IN bbb. To add support for the former, check for 'not' before calling parse_val_unit here: https://github.com/taoyds/test-suite-sql-eval/blob/e97acc546ecbee8fa27fa8dbf025ef61493a876c/process_sql.py#L309-L310
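If patching the parser is not an option, both notes can also be handled as a string post-processing step on the SQL that sqlglot generates. The sketch below is an alternative workaround, not the parser change suggested above. `normalize_for_eval` is a hypothetical name, and the regex assumes uppercase keywords and a plain (optionally table-qualified) column as the operand of NOT ... IN. It works on raw text, so it would also rewrite matching text inside string literals.

```python
import re

# Matches "NOT <column> IN" where <column> is a bare or table-qualified name.
# Assumption: keywords are uppercase, as in SQL generated by sqlglot.
NOT_IN_RE = re.compile(r"\bNOT\s+([\w.]+)\s+IN\b")


def normalize_for_eval(sql: str) -> str:
    """Rewrite two spellings into the forms this repo's parser expects.

    Hypothetical helper: plain text rewriting, not SQL-aware.
    """
    # "<>" and "!=" are equivalent; the evaluation script expects "!=".
    sql = sql.replace(" <> ", " != ")
    # "NOT col IN (...)" becomes "col NOT IN (...)".
    return NOT_IN_RE.sub(r"\1 NOT IN", sql)


generated = "SELECT a FROM t WHERE NOT t.a IN (SELECT a FROM u) AND t.b <> 1"
normalized = normalize_for_eval(generated)
print(normalized)
```

The function leaves SQL that already uses != and NOT IN unchanged, so it is safe to apply more than once.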
It also supports table aliases without the AS keyword, so it's helpful for #17 as well.
Maybe someone should refactor this repo with sqlglot for better readability and extensibility...
Thanks for the nice piece of code! It seems to be the easiest solution. I'll keep this issue open to show the presence of this problem.