[SPARK-46743] Count bug after constant folding
What changes were proposed in this pull request?
This PR fixes a corner case in COUNT bug handling. (The COUNT bug: when a correlated scalar subquery containing COUNT is decorrelated into a left outer join, outer rows with no matching inner rows would get NULL instead of the 0 that COUNT returns on empty input.) Currently, the handling is split across two rules, PullupCorrelatedPredicates and RewriteCorrelatedScalarSubquery. The first marks subqueries that may be affected by the COUNT bug, and the second performs more precise detection. Both rely on the Aggregate remaining at the top of the subquery, which is usually a safe assumption. However, when the subquery can be constant folded, the Aggregate is replaced with a Project, and the second part of the COUNT bug handling no longer applies.
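To make the COUNT bug concrete, here is a minimal sketch in plain Python. It is not Spark code. The tables `t1` and `t2` and the three helper functions are hypothetical and exist only for illustration. The sketch evaluates `(SELECT count(*) FROM t2 WHERE t2.a = t1.a)` for each row of `t1` in three ways: directly, via a naive "aggregate then left outer join" rewrite, and via a rewrite that substitutes count's empty-input value for unmatched rows.

```python
from collections import Counter

# Hypothetical toy tables: each row is a dict with a single column "a".
t1 = [{"a": 1}, {"a": 2}]
t2 = [{"a": 1}, {"a": 1}]


def correlated_count(outer, inner):
    """Reference semantics: evaluate the scalar subquery once per outer row."""
    return [sum(1 for r in inner if r["a"] == o["a"]) for o in outer]


def naive_decorrelated_count(outer, inner):
    """Naive rewrite: group inner by a, then LEFT OUTER JOIN to outer.
    Outer rows without a matching group get NULL (None) instead of 0."""
    groups = Counter(r["a"] for r in inner)
    # dict.get does not call Counter.__missing__, so unmatched keys give None.
    return [groups.get(o["a"]) for o in outer]


def fixed_decorrelated_count(outer, inner, empty_value=0):
    """Rewrite that compensates for the COUNT bug: unmatched outer rows get
    the aggregate's value on empty input (0 for count), i.e. coalesce(cnt, 0)."""
    groups = Counter(r["a"] for r in inner)
    return [groups.get(o["a"], empty_value) for o in outer]


print(correlated_count(t1, t2))          # [2, 0]
print(naive_decorrelated_count(t1, t2))  # [2, None]
print(fixed_decorrelated_count(t1, t2))  # [2, 0]
```

The compensating rewrite needs to know which aggregate sits on top of the subquery in order to pick the empty-input value, which is why the detection described above depends on the Aggregate still being there.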
An example is in https://issues.apache.org/jira/browse/SPARK-46743. It involves a temp view, which gets inlined and thereby allows the subquery to be constant folded. (For the same reason, replacing the temp view with an actual table makes the query return correct results.)
This PR ensures that the Aggregate stays on top of the subquery body until RewriteCorrelatedScalarSubquery runs. Constant folding still runs afterwards, so constant aggregates are still folded away, just at a later point.
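The rule-ordering issue can be sketched with a toy optimizer for a single outer row. This is a simplified model under stated assumptions, not Catalyst's API: the classes `Aggregate` and `Project` and the functions `fold_constants`, `rewrite_subquery`, and `optimize` are hypothetical. The only behavior modeled is that folding replaces the Aggregate with a Project, and that the rewrite applies the empty-input value only when it still sees an Aggregate on top.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Aggregate:
    """Hypothetical aggregate node; empty_value is its result on empty input (0 for count)."""
    empty_value: Optional[int]


@dataclass
class Project:
    """Hypothetical project node left behind after constant folding."""


Node = Union[Aggregate, Project]


def fold_constants(top: Node) -> Node:
    # Toy stand-in for constant folding: the Aggregate becomes a Project,
    # losing the information about its empty-input value.
    return Project() if isinstance(top, Aggregate) else top


def rewrite_subquery(top: Node, matched_value: int, has_match: bool) -> Optional[int]:
    # Toy stand-in for the scalar-subquery rewrite for one outer row.
    if has_match:
        return matched_value
    if isinstance(top, Aggregate):
        return top.empty_value  # COUNT bug compensation applies
    return None  # plain left-join semantics: NULL


def optimize(top: Node, fold_before_rewrite: bool,
             matched_value: int, has_match: bool) -> Optional[int]:
    if fold_before_rewrite:
        top = fold_constants(top)  # old order: the Aggregate is gone too early
    result = rewrite_subquery(top, matched_value, has_match)
    if not fold_before_rewrite:
        top = fold_constants(top)  # new order: folding still happens, afterwards
    return result


print(optimize(Aggregate(0), fold_before_rewrite=True, matched_value=0, has_match=False))   # None
print(optimize(Aggregate(0), fold_before_rewrite=False, matched_value=0, has_match=False))  # 0
```

In this model, folding before the rewrite yields NULL for an unmatched row, while keeping the Aggregate until after the rewrite yields the expected 0.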
Why are the changes needed?
Correctness bug. See the reasoning above.
Does this PR introduce any user-facing change?
Yes. Queries affected by this bug previously returned incorrect results; they now return correct results.
How was this patch tested?
Added a query test.
Was this patch authored or co-authored using generative AI tooling?
No
@jchen5
What if there's another node above the aggregate in the subquery, such as a filter after the aggregate (a HAVING clause)?
Added a test. However, any non-trivial node above the aggregate (such as a filter) results in a DomainJoin, so constant folding does not kick in.
Thanks for the fix, looks good overall.
Let's add a gating flag for this change, just in case of any issues.
Added a flag.
@cloud-fan
thanks, merging to master!