Quadratic scaling in expression depth when collecting expressions
Checks
- [X] I have checked that this issue has not already been reported.
- [X] I have confirmed this bug exists on the latest version of Polars.
Reproducible example
```python
import sys
import time

import polars as pl

expr = pl.col("a")
for i in range(int(sys.argv[1])):
    expr += pl.col("a")

df = pl.LazyFrame({"a": [1]})
q = df.select(expr)
ldf = q._ldf.optimization_toggle(
    True,
    True,
    True,
    True,
    True,
    True,
    True,
    False,
    False,
)

start = time.perf_counter()
ldf.collect()
print(f"{sys.argv[1]}: {time.perf_counter() - start}")
```
```
$ for i in 1 200 400 800 1600; do python run.py $i; done
1: 0.0029712169998674653
200: 0.0200193919990852
materialized names collided in common subexpression elimination.
400: 0.07028934900154127
materialized names collided in common subexpression elimination.
800: 0.28725086899794405
materialized names collided in common subexpression elimination.
1600: 1.2589194770007452
```
Log output
```
backtrace and run without CSE
backtrace and run without CSE
backtrace and run without CSE
```
Issue description
Optimising the logical plan has, I think, performance that is quadratic in the expression depth. I think the culprit is `AExpr::to_field`, which is called to attach a dtype at every node in the graph. But since the result is not cached, this is O(N) for a node, and done O(N) times.

Running `samply python run.py 1600` shows almost all the time is in `get_arithmetic_field -> to_field -> stacker::maybe_grow -> get_arithmetic_field -> ...`
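For intuition, here is a toy Python model of that call pattern (`Col`, `Add`, and `to_field` below are hypothetical stand-ins, not the actual Polars internals): calling an uncached recursive `to_field` once per node of a left-leaning chain does roughly $N^2$ work in total.

```python
# Toy stand-in for AExpr::to_field (hypothetical, not the real internals):
# recompute the dtype of a node by recursing over its whole subtree.
class Col:
    def __init__(self, name):
        self.name = name

class Add:
    def __init__(self, left, right):
        self.left, self.right = left, right

calls = 0

def to_field(node, schema):
    global calls
    calls += 1
    if isinstance(node, Col):
        return schema[node.name]
    # No caching: the entire subtree below this node is re-walked.
    left_dt = to_field(node.left, schema)
    to_field(node.right, schema)  # also re-walked; result unused in this toy
    return left_dt                # toy "supertype" rule for arithmetic

for depth in (200, 400, 800):
    expr = Col("a")
    for _ in range(depth):
        expr = Add(expr, Col("a"))
    calls = 0
    # The optimizer attaches a dtype at every node, so call to_field once
    # per Add node; each call recurses over everything below it.
    node = expr
    while isinstance(node, Add):
        to_field(node, {"a": "Int64"})
        node = node.left
    print(depth, calls)  # 200: 40400, 400: 160800, 800: 641600 — quadratic
```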
Expected behavior
Although this is somewhat pathological, I would expect this to be linear in the expression depth.
Installed versions
```
--------Version info---------
Polars: 0.20.26
Index type: UInt32
Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
Python: 3.11.9 (main, Apr 3 2024, 16:33:49) [GCC 11.4.0]
----Optional dependencies----
adbc_driver_manager: <not installed>
cloudpickle: <not installed>
connectorx: <not installed>
deltalake: <not installed>
fastexcel: <not installed>
fsspec: <not installed>
gevent: <not installed>
hvplot: <not installed>
matplotlib: <not installed>
nest_asyncio: <not installed>
numpy: <not installed>
openpyxl: <not installed>
pandas: <not installed>
pyarrow: <not installed>
pydantic: <not installed>
pyiceberg: <not installed>
pyxlsb: <not installed>
sqlalchemy: <not installed>
torch: <not installed>
xlsx2csv: <not installed>
xlsxwriter: <not installed>
```
How does it perform if you turn off `comm_subexpr_elim`?

And how do you get so deep expressions? 😉
> How does it perform if you turn off `comm_subexpr_elim`?
Same; indeed, with all optimisations off it's still quadratic, just faster overall:
```
$ for i in 1 200 400 800 1600; do python slow-optimise.py ${i}; done
1: 0.0027743840000766795
200: 0.010377078000601614
400: 0.030803929999819957
800: 0.12257709299956332
1600: 0.47342776200093795
```
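As a quick sanity check on these numbers (plain arithmetic on the timings above): if the cost is quadratic, doubling N should roughly quadruple the runtime.

```python
# Ratio test on the measured timings above: O(N^2) cost implies that
# doubling N multiplies the runtime by roughly 4.
times = {200: 0.010377, 400: 0.030804, 800: 0.122577, 1600: 0.473428}
ns = sorted(times)
for a, b in zip(ns, ns[1:]):
    print(f"{a} -> {b}: x{times[b] / times[a]:.1f}")
# 200 -> 400: x3.0, 400 -> 800: x4.0, 800 -> 1600: x3.9
```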
> And how do you get so deep expressions? 😉
I've seen things...
It was more that I was transpiling the plan and noticed performance slowdowns when annotating the logical plan with dtypes for every expression node.
Will see if we can apply some memoization in the `to_field` call.
I am not entirely sure the tree itself isn't quadratic. I need to do double the `to_field` calls of the depth, because at every depth we branch.
I don't think so. The expression is (for $N = 4$)
```
((((a + a) + a) + a) + a)
```
Or:
```
            (+)
           /   \
         (+)    a
        /   \
      (+)    a
     /   \
   (+)    a
  /   \
 a     a
```
So if you need to call `to_field` on every node, that's $2N + 1$ calls, which is why you're getting double the depth, I think. But at every node, `to_field` needs to recurse into the subtree if we're not memoising, or otherwise doing a one-pass bottom-up production of the dtypes of every node given a schema context.
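A sketch of that one-pass, bottom-up alternative in toy Python (the node classes and the supertype rule are illustrative assumptions, not the Polars IR): visit leaves first and record each node's dtype exactly once, which brings the total work back to linear.

```python
# One-pass, bottom-up production of dtypes over a toy expression tree.
class Col:
    def __init__(self, name):
        self.name = name

class Add:
    def __init__(self, left, right):
        self.left, self.right = left, right

def annotate_dtypes(root, schema):
    """Return {id(node): dtype} for every node, visiting each node once."""
    dtypes = {}

    def visit(node):
        if id(node) in dtypes:      # already typed: shared nodes are free
            return dtypes[id(node)]
        if isinstance(node, Col):
            dt = schema[node.name]
        else:
            dt = visit(node.left)   # toy supertype rule: take the left dtype
            visit(node.right)
        dtypes[id(node)] = dt
        return dt

    visit(root)
    return dtypes

expr = Col("a")
for _ in range(500):
    expr = Add(expr, Col("a"))

# Every node is typed in O(number of nodes), i.e. linear in the depth.
print(len(annotate_dtypes(expr, {"a": "Int64"})))  # 1001
```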
BTW: should I expect that the `AExpr` processing treats expressions with common sub-expressions like DAGs, or like trees (increasing the perceived size)? That is, if I write:
expr = pl.col("a")
expr = expr + expr
expr2 = expr + expr
Which is:
```
 (+)
 / \
 \ /
 (+)
 / \
 \ /
  a
```
Is it "seen" as:
(+)
/ \
(+) (+)
/ | | \
a a a a
It is seen as trees. They are turned into DAGs during execution if CSE recognized it.
Thanks. I suspect that means there are (perhaps pathological) cases where expression processing time can be exponential: any time there is sharing in the expression. Though perhaps that is an unlikely case for typical queries.
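A toy model of that worry (illustrative only, not the `AExpr` machinery): with self-referencing expressions, the tree view doubles in perceived size at every level of sharing, while an identity-based DAG walk stays linear.

```python
# Toy demonstration: with sharing, the tree view grows exponentially while
# the DAG view stays linear.
class Col:
    def __init__(self, name):
        self.name = name

class Add:
    def __init__(self, left, right):
        self.left, self.right = left, right

def tree_size(node):
    """Walk as a tree: shared children are counted once per reference."""
    if isinstance(node, Col):
        return 1
    return 1 + tree_size(node.left) + tree_size(node.right)

def dag_size(node, seen):
    """Walk as a DAG: each distinct object is counted once."""
    if id(node) in seen:
        return 0
    seen.add(id(node))
    if isinstance(node, Col):
        return 1
    return 1 + dag_size(node.left, seen) + dag_size(node.right, seen)

expr = Col("a")
for _ in range(20):
    expr = Add(expr, expr)      # both children are the *same* object

print(tree_size(expr))          # 2097151 nodes (2^21 - 1): exponential
print(dag_size(expr, set()))    # 21 unique nodes: linear
```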
Alright, did a simple pointer print. On 10 iterations, we get 270 calls, of which 20 are unique.

Will do 2 things to improve:
- add memoization.
- add top-level fields to the IR state.

Edit: memoization doesn't work, as those calls are not from the same call stack. We must cache the dtype on the IR state instead.
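A minimal sketch of that idea in toy Python (the real fix would live in the Rust IR; `Col`, `Add`, and `to_field` here are hypothetical stand-ins): store the computed dtype on the node itself, so it survives across separate call stacks, unlike a memo table that is local to one traversal.

```python
# Toy model of "cache the dtype on the IR state": the dtype lives on the
# node, so any later call stack can reuse it.
class Col:
    def __init__(self, name):
        self.name = name
        self.dtype = None       # filled in by the first inference

class Add:
    def __init__(self, left, right):
        self.left, self.right = left, right
        self.dtype = None

def to_field(node, schema):
    if node.dtype is not None:  # cache hit: no recursion at all
        return node.dtype
    if isinstance(node, Col):
        node.dtype = schema[node.name]
    else:
        node.dtype = to_field(node.left, schema)  # toy supertype rule
        to_field(node.right, schema)
    return node.dtype

expr = Col("a")
for _ in range(800):
    expr = Add(expr, Col("a"))

node = expr
while isinstance(node, Add):    # one to_field per node, as the optimizer does
    to_field(node, {"a": "Int64"})
    node = node.left
# The first call types the whole chain; every later call is a cache hit,
# so the total work is O(N) instead of O(N^2).
```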
> I've seen things...
👀🤣