[SPARK-47070] Fix invalid aggregation after in-subquery rewrite
What changes were proposed in this pull request?
tl;dr: This PR fixes a bug in which an exists variable introduced by an incorrect subquery rewrite ends up in an invalid aggregation.
Details: Imagine we had a plan with a subquery:
Aggregate [a1#0] [CASE WHEN a1#0 IN (list#3999 []) THEN Hello ELSE Hi END AS strCol#13]
:  +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
+- LocalRelation <empty>, [a1#0, a2#1, a3#2]
During correlated subquery rewrite, the rule RewritePredicateSubquery would rewrite the expression a1#0 IN (list#3999 []) into exists#12 and replace the subquery with an ExistenceJoin, like so:
Aggregate [a1#0] [CASE WHEN exists#12 THEN Hello ELSE Hi END AS strCol#13]
+- Join ExistenceJoin(exists#12), (a1#0 = b1#3)
   :- LocalRelation <empty>, [a1#0, a2#1, a3#2]
   +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
Note that exists#12 appears neither in the grouping expressions nor inside any aggregate function. This is an invalid aggregation. In particular, the aggregate pushdown rule rewrites this plan into:
Project [CASE WHEN exists#12 THEN Hello WHEN true THEN Hi END AS strCol#13]
+- AggregatePart [a1#0], true
+- AggregatePart [a1#0], false
+- Join ExistenceJoin(exists#12), (a1#0 = b1#3)
:- AggregatePart [a1#0], false
: +- LocalRelation <empty>, [a1#0, a2#1, a3#2]
+- AggregatePart [b1#3], false
+- LocalRelation <empty>, [b1#3, b2#4, b3#5]
The decision was to fix the bug in RewritePredicateSubquery by enforcing the condition that newly introduced variables, if referenced among the aggregate expressions, must either participate in aggregate functions or appear in the grouping keys.
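The condition above can be sketched as a small check in plain Python. This is a hypothetical toy model, not Spark's actual API: an aggregate is valid only if every attribute referenced directly by its result expressions (i.e. outside any aggregate function) is also a grouping key.

```python
# Toy model (not Spark code) of the invariant the fix enforces.
def is_valid_aggregate(grouping_keys, direct_refs):
    """grouping_keys: set of attribute names used as group-by keys.
    direct_refs: attributes referenced outside any aggregate function."""
    return dire_refs_subset(direct_refs, grouping_keys)

def dire_refs_subset(direct_refs, grouping_keys):
    # Every directly referenced attribute must be a grouping key.
    return set(direct_refs) <= set(grouping_keys)

# Before the fix: exists#12 is referenced directly but is not a grouping key.
print(is_valid_aggregate({"a1#0"}, {"a1#0", "exists#12"}))               # False
# After the fix: exists#12 has been added to the grouping keys.
print(is_valid_aggregate({"a1#0", "exists#12"}, {"a1#0", "exists#12"}))  # True
```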
With the fix, the plan after RewritePredicateSubquery will look like:
Aggregate [a1#0, exists#12] [CASE WHEN exists#12 THEN Hello ELSE Hi END AS strCol#13]
+- Join ExistenceJoin(exists#12), (a1#0 = b1#3)
   :- LocalRelation <empty>, [a1#0, a2#1, a3#2]
   +- LocalRelation <empty>, [b1#3, b2#4, b3#5]
NOTE: It is still possible to manually construct an ExistenceJoin (e.g. via the DSL) and an Aggregate on top of it that violate this condition.
Does this PR introduce any user-facing change?
No
How was this patch tested?
Query tests
@jchen5 @cloud-fan
I think this approach makes sense and works, but have we considered these alternative approaches:
- I think it could also be fixed by pulling the in/exists expression into a Project node above the Aggregate prior to RewritePredicateSubquery. If we do that, the existence-join evaluation happens after the aggregate, which would often be better for performance. It also seems like a simpler query plan.
- I think another possibility is wrapping the output expression in an any_value aggregate.
Adding the exists to the group-by seems slightly worse because it is essentially redundant (because it will only have one value given the rest of the group-by columns), and it looks like it would have some minor performance downsides - making the hash group-by keys larger, and perhaps making optimizer analysis of the group-by harder and limiting other optimizations in some cases.
That said, I'm fine with current approach overall because the differences mentioned above don't seem like big deals and fixing the bug is the more important thing. Just wanted to bring it up and discuss which approach would be better.
@jchen5 I considered the two other suggested approaches. I agree that performance-wise both are probably better, although it doesn't seem critical. However, I couldn't yet get a quick, clean implementation, so I propose to merge as is for now. I've created a ticket to look into it further: https://issues.apache.org/jira/browse/SPARK-47171
@cloud-fan could you please take a look at this PR? Thanks!
UPD: the current proposed approach actually looks incorrect to me. It fails for queries like the one below (the rewrite adds an extra group-by column and thus returns 2 rows):
SELECT
sum(salary),
sum(salary) FILTER (WHERE EXISTS (SELECT 1
FROM dept
WHERE emp.dept_id = dept.dept_id))
FROM emp;
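The failure mode can be illustrated with a toy model in plain Python (not Spark; the rows and salaries are made up for illustration). The query above is a global aggregate, so it must return exactly one row; but if the rewrite adds the per-row exists flag to the group-by, rows where the subquery matched and rows where it did not land in different groups, producing two output rows.

```python
from itertools import groupby

# (salary, exists): exists says whether the EXISTS subquery matched the row.
emp = [(100, True), (200, False)]

# Intended semantics: a single row; the FILTER clause sums only matching rows.
global_result = [(sum(s for s, _ in emp),
                  sum(s for s, e in emp if e))]

# Buggy rewrite: grouping by the exists flag yields one row per flag value.
rows = sorted(emp, key=lambda r: r[1])
grouped_result = [
    (sum(s for s, _ in grp), sum(s for s, e in grp if e))
    for _, g in groupby(rows, key=lambda r: r[1])
    for grp in [list(g)]
]

print(len(global_result))   # 1 row, as the query requires
print(len(grouped_result))  # 2 rows -- one too many
```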
Will work on fixing this.
Yeah, in the prior examples it worked because the group-by columns determine the value of the exists - but that only holds when the aggregates do not use any non-group-by columns of the table. When we do have an aggregate using a non-grouped column of the outer table, it doesn't work.
This aggregate doesn't require any changes to avoid the invalid aggregation error, right? I think this is another reason to go towards something like wrapping expressions in any_value aggregate when needed.
Changed the approach to wrap the expression in max(). Ideally we'd use any_value(), but it does not currently work well in Spark.
@cloud-fan
Semantically, we should wrap in any_value(), but any_value() throws "RuntimeReplaceableAggregate.aggBufferAttributes should not be called" and is not fully supported.
It's because any_value is not executable and needs to be rewritten by ReplaceExpressions. We can use its executable version, the first function.
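The final approach can be sketched in plain Python (a toy model, not Spark code; the rows are made up). Rather than adding exists#12 to the group-by, its reference in the result expression is wrapped in an aggregate function such as max(). Within one group the flag is constant, since it is determined by the grouping key a1#0, so max() simply returns that value and the aggregation becomes valid without changing the number of output rows.

```python
# (a1, exists) pairs: exists is constant within each a1 group.
rows = [(1, True), (1, True), (2, False)]

result = []
for a1 in sorted({a for a, _ in rows}):
    group = [e for a, e in rows if a == a1]
    e = max(group)  # max over a constant-per-group boolean returns that value
    result.append((a1, "Hello" if e else "Hi"))

print(result)  # one row per group, same as before the fix
```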
@cloud-fan updated the code & PR description!
thanks, merging to master!