ibis
bug: row_number() with `.select()` doesn't have `OVER ...` clause when used after join (`.mutate()` works correctly)
What happened?
import ibis
be = ibis.get_backend()
t1 = be.create_table("t1", {"x": [1, 2, 3]}, overwrite=True)
t2 = be.create_table("t2", {"x": [2, 3, 4]}, overwrite=True)
j = t1.join(t2, "x")
rn_mutate = j.mutate(rn=ibis.row_number())
ibis.to_sql(rn_mutate)
# SELECT
#   "t4"."x",
#   ROW_NUMBER() OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "rn"
# FROM (
#   SELECT
#     "t2"."x"
#   FROM "memory"."main"."t1" AS "t2"
#   INNER JOIN "memory"."main"."t2" AS "t3"
#     ON "t2"."x" = "t3"."x"
# ) AS "t4"
rn_select = j.select(*j.columns, rn=ibis.row_number())
ibis.to_sql(rn_select)
# SELECT
#   "t2"."x",
#   ROW_NUMBER() AS "rn"
# FROM "memory"."main"."t1" AS "t2"
# INNER JOIN "memory"."main"."t2" AS "t3"
#   ON "t2"."x" = "t3"."x"
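Executing `rn_select` fails with `Catalog Error: Scalar Function with name row_number does not exist!`: without an `OVER (...)` clause, DuckDB looks up `ROW_NUMBER()` as a plain scalar function and finds none.
The only thing missing is that `OVER (...)` clause. Adding it by hand makes the query work: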
be.sql(
    """
    SELECT
      "t2"."x",
      ROW_NUMBER() OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS "rn"
    FROM "memory"."main"."t1" AS "t2"
    INNER JOIN "memory"."main"."t2" AS "t3"
      ON "t2"."x" = "t3"."x"
    """
).execute()
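The failure mode is not specific to ibis or DuckDB: any SQL engine with standard window functions rejects a bare `ROW_NUMBER()`. The sketch below assumes Python's built-in `sqlite3` module as a stand-in for DuckDB (with a bundled SQLite of 3.25 or later, the first release with window functions) and runs the same two query shapes against an equivalent join. Table and column names are simplified; this is not the SQL ibis actually emits for SQLite.

```python
import sqlite3

# In-memory SQLite database standing in for DuckDB (assumption: the bundled
# SQLite is 3.25+, the first release that supports window functions).
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t1 (x INTEGER)")
con.execute("CREATE TABLE t2 (x INTEGER)")
con.executemany("INSERT INTO t1 VALUES (?)", [(1,), (2,), (3,)])
con.executemany("INSERT INTO t2 VALUES (?)", [(2,), (3,), (4,)])

JOIN = "FROM t1 INNER JOIN t2 ON t1.x = t2.x"

# Shape of the SQL the select() path emits: ROW_NUMBER() with no OVER clause.
broken_sql = f"SELECT t1.x, ROW_NUMBER() AS rn {JOIN}"

# Shape of the SQL the mutate() path emits: an explicit window over all rows.
fixed_sql = (
    "SELECT t1.x, "
    "ROW_NUMBER() OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) - 1 AS rn "
    + JOIN
)

try:
    con.execute(broken_sql).fetchall()
    broken_error = None
except sqlite3.OperationalError as exc:
    # SQLite rejects the bare call; its message differs from DuckDB's Catalog Error.
    broken_error = str(exc)

rows = con.execute(fixed_sql).fetchall()
print("broken:", broken_error)
print("fixed:", rows)
```

The join matches `x = 2` and `x = 3`, so the fixed query returns two rows numbered 0 and 1. Without an `ORDER BY` inside the window, which of the two rows gets 0 is unspecified.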
What version of ibis are you using?
main
What backend(s) are you using, if any?
DuckDB (the default backend, which produced the error above). Other backends are probably affected too.
OK, I don't know how to fix this yet, but here's the issue.
When you call `mutate` on a `JoinChain`, we first finalize the `JoinChain` and then call `mutate` on the result. That puts `row_number()` in a separate `Project`, where it is correctly inferred as a window function:
r0 := DatabaseTable: memory.main.t1
  x int64

r1 := DatabaseTable: memory.main.t2
  x int64

r2 := JoinChain[r0]
  JoinLink[inner, r1]
    r0.x == r1.x
  values:
    x: r0.x

Project[r2]
  x:  r2.x
  rn: WindowFunction(func=RowNumber(), how='rows')
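The `mutate` path can be modelled with a few toy classes. `Field`, `RowNumber`, `WindowFunction`, `Project`, `bind_window`, and `mutate` below are hypothetical stand-ins that only loosely mirror ibis's internal op names; ibis's real rewrite logic is more involved. What the sketch shows is the ordering: every new value passes through a window-binding step before it lands in the `Project`.

```python
from dataclasses import dataclass


# Hypothetical toy nodes, loosely named after ibis ops (not ibis's classes).
@dataclass(frozen=True)
class Field:
    rel: str
    name: str


@dataclass(frozen=True)
class RowNumber:
    pass


@dataclass(frozen=True)
class WindowFunction:
    func: object
    how: str = "rows"


@dataclass(frozen=True)
class Project:
    parent: str
    values: tuple  # (output name, value) pairs


def bind_window(value):
    """Hypothetical rewrite: wrap a bare analytic call in a default window."""
    if isinstance(value, RowNumber):
        return WindowFunction(func=value, how="rows")
    return value


def mutate(parent, columns, **new):
    """Toy mutate on an already-finalized join named `parent`."""
    # Existing columns are plain field references on the finalized join.
    values = [(c, Field(parent, c)) for c in columns]
    # New values go through the window-binding step first.
    values += [(k, bind_window(v)) for k, v in new.items()]
    return Project(parent, tuple(values))


proj = mutate("r2", ["x"], rn=RowNumber())
for name, value in proj.values:
    print(f"{name}: {value}")
# x: Field(rel='r2', name='x')
# rn: WindowFunction(func=RowNumber(), how='rows')
```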
But `select` on a `Join` goes through special handling that dereferences field names against the join's inputs, and that path never builds the window function, so `RowNumber()` ends up bare in the `JoinChain`'s values:
r0 := DatabaseTable: memory.main.t1
  x int64

r1 := DatabaseTable: memory.main.t2
  x int64

JoinChain[r0]
  JoinLink[inner, r1]
    r0.x == r1.x
  values:
    x:  r0.x
    rn: RowNumber()
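The `select` path can be sketched with the same kind of toy stand-ins. `Field`, `RowNumber`, `WindowFunction`, `bind_window`, and this simplified `JoinChain` are hypothetical and are not ibis's actual classes. Dereferencing alone leaves `RowNumber()` bare, matching the dump above, while running the window-binding step after dereferencing yields the same windowed value that `mutate` produces. `select_fixed` shows only one possible fix direction, not the project's actual patch.

```python
from dataclasses import dataclass


# Hypothetical toy nodes, loosely named after ibis ops (not ibis's classes).
@dataclass(frozen=True)
class Field:
    rel: str
    name: str


@dataclass(frozen=True)
class RowNumber:
    pass


@dataclass(frozen=True)
class WindowFunction:
    func: object
    how: str = "rows"


def bind_window(value):
    """Hypothetical rewrite: wrap a bare analytic call in a default window."""
    if isinstance(value, RowNumber):
        return WindowFunction(func=value, how="rows")
    return value


@dataclass
class JoinChain:
    first: str
    values: dict  # output name -> value expression

    def _dereference(self, value):
        # Toy dereferencing: a column name becomes a field of the first table.
        if isinstance(value, str):
            return Field(self.first, value)
        return value

    def select_buggy(self, *names, **named):
        # Dereferences names but never binds windows, so RowNumber() stays
        # bare and would compile to ROW_NUMBER() with no OVER clause.
        out = {n: self._dereference(n) for n in names}
        out.update({k: self._dereference(v) for k, v in named.items()})
        return JoinChain(self.first, out)

    def select_fixed(self, *names, **named):
        # Possible fix direction: apply the same window-binding step after
        # dereferencing, so select and mutate produce the same node.
        out = {n: bind_window(self._dereference(n)) for n in names}
        out.update(
            {k: bind_window(self._dereference(v)) for k, v in named.items()}
        )
        return JoinChain(self.first, out)


j = JoinChain("r0", {"x": Field("r0", "x")})
print(j.select_buggy("x", rn=RowNumber()).values["rn"])
# RowNumber()
print(j.select_fixed("x", rn=RowNumber()).values["rn"])
# WindowFunction(func=RowNumber(), how='rows')
```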