GH-35166: [C++] Increase precision of decimals in aggregate functions
Rationale for this change
As documented in #35166, when Arrow performs a sum, product, or mean of an array of type decimalX(P, S), it returns a scalar of type decimalX(P, S). This is true even if the aggregate does not fit in the specified precision. For instance, a sum of two decimal128(1, 0)'s such as 1 + 9 is a decimal128(2, 0). But (in Python):
import pyarrow as pa
from decimal import Decimal
arr = pa.array([Decimal("1"), Decimal("9")], type=pa.decimal128(1, 0))
assert arr.sum().type == pa.decimal128(1, 0)
This is recognized in the rules for binary addition and multiplication of decimals (see footnote 1 in this section), but those rules do not apply to array aggregates.
In #35166 I did a bit of research following a question from @westonpace, and it seems there's no standard approach to this across DBMSs, but a common solution is to set the precision of the result of a sum to the maximum possible precision of the underlying type. That is, a sum of decimal128(1, 0)'s becomes a decimal128(38, 0).
However, products and means differ further. For both, duckdb converts the decimal to a double, which makes sense, as the precision of the product of an array of decimals can grow very quickly: the product of an array of N decimals, each with precision 2, can require up to 2N digits of precision.
This PR implements the minimum possible change: replace all return types of a product, sum, or mean aggregate of decimal128(P, S) to decimal128(38, S) and decimal256(P, S) to decimal256(76, S).
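In other words (a hypothetical pure-Python sketch of the proposed rule; `widen_to_max_precision` is an illustrative name, not an Arrow API):

```python
# Widen precision to the storage maximum, leaving the scale unchanged.
def widen_to_max_precision(bit_width: int, precision: int, scale: int) -> tuple:
    max_precision = {128: 38, 256: 76}[bit_width]
    return (max_precision, scale)

print(widen_to_max_precision(128, 1, 0))   # (38, 0)
print(widen_to_max_precision(256, 10, 2))  # (76, 2)
```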
Please note, this PR is not done (see the checklist below), as it would be good to get feedback on the following before going through the whole checklist:
- Is this the correct change to make (especially for products and means); and
- The implementation relies on overriding "out types," which is how the current mean implementation works, but perhaps there's a better way to approach this.
What changes are included in this PR?
- [x] Update C++ kernels to support the change
- [x] Update docs to reflect the change
- [x] Fix tests in the languages that depend on the C++ engine
- [x] Determine if there are other languages which do not depend on the C++ engine which should also be updated
Are these changes tested?
They are tested in the following implementations:
- [x] Python
- [x] C++
- [ ] ~~Java~~ Java does not currently implement any dependent tests
Are there any user-facing changes?
Yes. This changes the return type of a scalar aggregate of decimals.
This PR includes breaking changes to public APIs.
Specifically, the return type of a scalar aggregate of decimals changes. This is unlikely to break downstream applications as the underlying data has not changed, but if an application relies on the (incorrect!) type information for some reason, it would break.
- GitHub Issue: #35166
:warning: GitHub issue #35166 has been automatically assigned in GitHub to PR creator.
@zeroshade I noticed in #43957 that you were adding in Decimal32/64 types, which I think will have the same problem that this PR addresses. I was curious if you might have interest in reviewing this PR?
@khwilson Sure thing, I'll try to take a look at this in the next day or so
Hi @zeroshade just checking in! Thanks again for taking a look
Thanks for the review!
By a user-side cast, do you mean that users should essentially do:
`select avg(cast(blah as decimal(big-precision)))`
instead of
`select avg(blah)`
or do you mean that this code should "inject" a cast on the "user" side?
If you mean putting the cast onto the user, then I would think you'd want to add an error if the answer can't fit into the default precision, but that seems like it would be more disruptive (and out of step with how other systems handle decimal aggregates).
If you mean "injecting" the cast on the user side, would that end up creating a copy of the array?
@khwilson Hey sorry for the delay from me here, I've been traveling a lot lately for work and have been at ASF Community Over Code this week. I promise I'll get to this soon. In the meantime, you're in the very capable hands of @mapleFU
No problem! Hope your travels were fun!
Generally this method is ok for me, but I'm not so familiar with the "common solutions" here. I'll dive into Presto/ClickHouse to see the common pattern.
I enumerated several here: https://github.com/apache/arrow/issues/35166#issuecomment-2336776704
Clickhouse for instance just ignores precision.
Would you mind making this Ready for review?
Sure!
@mapleFU I believe this is done now. Some notes on the diff:
- The hash aggregates had to be updated as well (missed them in the first pass)
- I've also added Decimal32/64 support to the basic aggregates (sum, product, mean, min/max, index). However, there's still quite a lot of missing support for these types in compute (most notably in casts)
- Docs are updated to reflect the change
And a note that quite a few tests are failing for what appears to be the same reason as #41390. Happy to address them if you'd like.
I'm lukewarm about the approach here. Silently casting to the max precision discards metadata about the input; it also risks producing errors further down the line (if e.g. the max precision is deemed too large for other operations). It also doesn't automatically eliminate any potential overflow, for example:
>>> a = pa.array([789.3] * 20).cast(pa.decimal128(38, 35))
>>> a
<pyarrow.lib.Decimal128Array object at 0x7f0f103ca7a0>
[
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440,
789.29999999999995452526491135358810440
]
>>> pc.sum(a)
<pyarrow.Decimal128Scalar: Decimal('-1228.11834604692408266343214451664848480')>
We should instead check that the result of an aggregate fits into the resulting Decimal type; currently, overflows pass silently:
>>> a = pa.array([123., 456., 789.]).cast(pa.decimal128(4, 1))
>>> a
<pyarrow.lib.Decimal128Array object at 0x7f0ed06261a0>
[
123.0,
456.0,
789.0
]
>>> pc.sum(a)
<pyarrow.Decimal128Scalar: Decimal('1368.0')>
>>> pc.sum(a).validate(full=True)
Traceback (most recent call last):
...
ArrowInvalid: Decimal value 13680 does not fit in precision of decimal128(4, 1)
Two problems with just validating afterward: First, I'd expect validation to fail in perfectly reasonable cases: a sum of 1M decimals of approximately the same magnitude needs about 6 more digits of precision than the inputs. I assume this is why all the DBMSs I looked at increase the precision by default.
Second, just checking for overflow doesn't solve the underlying problem. Consider:
a = pa.array([789.3] * 18).cast(pa.decimal128(38, 35))
print(pc.sum(a))
pc.sum(a).validate(full=True) # passes
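Some back-of-the-envelope arithmetic shows why validation passes here (using the exact value 789.3 for illustration; the actual stored values differ slightly after the float cast):

```python
from decimal import Decimal

INT128_MAX = 2**127 - 1                      # ~1.70e38
unscaled = int(Decimal("789.3").scaleb(35))  # integer stored at scale 35
total = 18 * unscaled                        # ~1.42e39

# The true sum exceeds the 128-bit storage...
assert total > INT128_MAX

# ...so it wraps around in two's complement...
wrapped = (total + 2**127) % 2**128 - 2**127

# ...to a value that fits precision 38 again, so validation passes.
assert abs(wrapped) < 10**38
```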
In duckdb, they implement an intermediate check to make sure that there's not an internal overflow:
tab = pa.Table.from_pydict({"a": a})
duckdb.query("select sum(a) from tab")
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# duckdb.duckdb.OutOfRangeException: Out of Range Error: Overflow in HUGEINT addition:
# 157859999999999990905052982270717620880 + 78929999999999995452526491135358810440
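A checked accumulation step in the same spirit might look like the following (illustrative sketch, not Arrow's API):

```python
INT128_MAX = 2**127 - 1
INT128_MIN = -(2**127)

def checked_add_int128(acc: int, value: int) -> int:
    """Add two int128 values, raising instead of wrapping on overflow."""
    result = acc + value
    if not (INT128_MIN <= result <= INT128_MAX):
        raise OverflowError(f"Overflow in int128 addition: {acc} + {value}")
    return result
```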
Notably, this lack of overflow checking also applies to integer sums in arrow:
>>> pa.array([9223372036854775800] * 2, type=pa.int64())
<pyarrow.lib.Int64Array object at 0x10c1d8b80>
[
9223372036854775800,
9223372036854775800
]
>>> pc.sum(pa.array([9223372036854775800] * 2, type=pa.int64()))
<pyarrow.Int64Scalar: -16>
>>> pc.sum(pa.array([9223372036854775800] * 2, type=pa.int64())).validate(full=True)
> Two problems with just validating afterward: First, I'd expect in reasonable cases for the validation to fail. A sum of 1m decimals of approximately the same size you'd expect to have 6 more digits of precision.
It depends obviously if all decimals are of the same sign, and what their actual magnitude is.
> Second, just checking for overflow doesn't solve the underlying problem.
In the example above, I used a validate call simply to show that the result was indeed erroneous. I didn't mean we should actually call validation afterwards. We should instead check for overflow at each individual aggregation step (for each add or multiply, for example). This is required even if we were to bump the result's precision to the max.
> Notably, this lack of overflow checking also applies to integer sums in arrow:
Yes, and there's already a bug open for it: https://github.com/apache/arrow/issues/37090
Nice! I'm excited for the checked variants of sum and product!
With the integer overflow example, I only meant to point out that the compute module currently allows overflows, so I think it would be unexpected for sum to complain about an overflow only if the underlying type was a decimal. But if the goal with #37090 is to replace sum with a checked version, then the solution of erroring makes a lot of sense, and I'd be happy to implement it when #37536 gets merged. :-)
Still, I do think that users would find it unexpected to get an error if the sum fit in the underlying storage since this is how all the databases I've used (and the four I surveyed in #35166) have operated.
Hi @pitrou sorry for dropping the ball on this a bit earlier. Are you still interested in merging this at some point? I'm happy to explore some alternative options for how to do the decimal expansion, or alternatively am happy to push #37536 forward
Hey @pitrou i was thinking about this PR again recently. Would you still be interested in merging if I cleaned it up a bit?
@khwilson I'm still not sure this is actually desirable (@zanmato1984 what do you think?).
What I think is definitely desirable is to check for overflow when doing sum/mean/product of decimals, regardless of which data type is selected for output.
> @khwilson I'm still not sure this is actually desirable (@zanmato1984 what do you think?).
Most DBMSes do precision promotion for sum aggregation, and most promotions are arbitrary. So I think the idea of this issue/PR makes very much sense. I will look further into the implementation later.
> What I think is definitely desirable is to check for overflow when doing sum/mean/product of decimals, regardless of which data type is selected for output.
Agreed, checked versions are necessary, though imo, independent of precision promotion.
Cool. I’m happy to do some more clean up or do a little work on creating the checked versions of the aggregates to push this along
> Cool. I'm happy to do some more clean up or do a little work on creating the checked versions of the aggregates to push this along
Thank you for the contribution and passion. How about we keep this PR focused solely on the precision promotion? And the potential checked version can be done in another PR.
That makes sense. From a sequencing stand point, should the checked versions come first or should we get this through first?
I think they are fully independent, so feel free to start anytime (reviewers may take them one by one though). And thank you for being willing to take this!
OK, I think I've correctly rebased (letting all the tests run now to find out).
I can remove the Decimal32/64 support once I check if I've rebased correctly.
Will also look at the inline comments once the rebase is completed.
OK @zanmato1984 it took a bit given the way everything gets set up. Updates:
- Removed all the Decimal32/64 logic (except in `WidenDecimalToMaxPrecision`)
- Introduced a `PromoteDecimal` template variable to determine whether a `SumLike` should actually promote the decimal
- Reverted the Mean and Product promotions as we're not agreed on how to handle those.
As for the mean and product, there are a few actions we could take:
- Status quo: Input type is the same as output type
- Promote both to maximum precision
- Promote only product to max precision: since `abs(mean) <= max(abs(max), abs(min))`, promotion isn't really necessary for mean
- Promote only product to a double: this is what, e.g., duckdb does

I think the best course of action is actually the fourth choice: promote only product to double and leave mean the same as the input type. Thoughts?
> As for the mean and product, there's a few actions we could take:
> - Status quo: Input type is the same as output type
> - Promote both to maximum precision
> - Promote only product to max precision: `abs(mean) <= max(abs(max), abs(min))` so not really necessary for mean
> - Promote only product to a double: this is what, e.g., duckdb does
"Promoting" to double is a demotion: it loses a lot of potential precision. The mean should certainly stay the same as the input type (and care should be taken to avoid overflows during intermediate computations, if at all possible).
As for the product, it would also be better to not lose too much precision, and therefore stay in the decimal domain. But it sounds less common to compute the product of a bunch of decimals, anyway.
If I understand the code correctly, what happens for the mean of decimals is that it first computes the sum using the underlying `+` of the Decimal type, then computes the count of values as a long, and does an integer division of the two. As seen in this commit, that `+` operator essentially ignores precision. I think this is a reasonable implementation except:
- There should be a checked variant;
- Potentially, you'd want to add some extra scale, in line with the binary division operation.
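A rough pure-Python model of that flow (a sketch of the description above, not the actual kernel code):

```python
from decimal import Decimal

def decimal_mean(values, scale):
    """Mean as described above: unscaled sum, then integer division by count."""
    unscaled = [int(v.scaleb(scale)) for v in values]  # stored integers
    total = sum(unscaled)                    # underlying "+": precision ignored
    mean_unscaled = total // len(unscaled)   # integer division by the count
    return Decimal(mean_unscaled).scaleb(-scale)

print(decimal_mean([Decimal("1.0"), Decimal("2.0"), Decimal("4.0")], 1))  # 2.3
```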
For comparison, I checked how a few other DBMS's handle the mean:
- postgresql: Drops all precision and scale information in mean
- duckdb: Turns into a double
- mariadb: Appears to just add 4 to scale and precision
For product, it's not very common for databases to provide a product aggregator of any kind. I quickly looked at ANSI SQL, postgres, mysql, and clickhouse, none of which provide a product. The "standard" product implementation for such aggregates combines a sign term with a log-sum, e.g., `(1 - 2 * (sum(case when col < 0 then 1 else 0 end) % 2)) * exp(sum(ln(abs(col))))`, but that requires transforming into a double. Moreover, basically any column with at least 100 values will overflow a decimal's precision.
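That log-based trick, written out in Python (sign from the parity of the count of negatives, magnitude via exp/log, which is inherently floating-point):

```python
import math

def product_via_logs(xs):
    # Zero short-circuits: log(0) is undefined and the product is 0 anyway.
    if any(x == 0 for x in xs):
        return 0.0
    negatives = sum(1 for x in xs if x < 0)
    sign = -1.0 if negatives % 2 else 1.0
    return sign * math.exp(sum(math.log(abs(x)) for x in xs))

print(product_via_logs([2.0, -3.0, 4.0]))  # ~ -24.0
```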
So I'd propose for product either demoting to double (will give a semantically correct value but lose precision) or maxing out the width of the decimal type. Of course, there should also be a checked variant (but I'm deferring that until after this commit is complete).
> If I understand the code correctly, what happens for the mean for decimals is that it first computes the sum by using the underlying `+` of the Decimal type, then computes the count of values as a long, and does an integer division of the two. As seen in this commit, that `+` operator essentially ignores precision. I think this is a reasonable implementation except:
> - There should be a checked variant;
> - Potentially, you'd want to add some extra scale, in line with the binary division operation.
Ideally, the scale should be "floating" just as in floating-point arithmetic, depending on the current running sum (the running sum can be very large if all data is positive, or very small if the data is centered around zero). It is then normalized to the original scale at the end. But of course that makes the algorithm more involved.
(that would also eliminate the need for a checked variant?)
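One way to read that floating-scale idea as a sketch (names and the rescaling policy are made up for illustration): keep the running sum as an unscaled integer plus a current scale, and shed a digit of scale whenever the next add would overflow.

```python
INT128_MAX = 2**127 - 1

def accumulate(acc, acc_scale, value, value_scale):
    # Align the incoming value to the accumulator's (possibly reduced) scale.
    value //= 10 ** (value_scale - acc_scale)
    # Drop scale digits until the add fits in 128 bits.
    while abs(acc + value) > INT128_MAX:
        acc //= 10
        value //= 10
        acc_scale -= 1
    return acc + value, acc_scale
```

At the end, the result would be rescaled back to the declared output scale; this sketch also glosses over rounding for negative values (`//` truncates toward negative infinity).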
> So I'd propose for product either demoting to double (will give a semantically correct value but lose precision) or maxing out the width of the decimal type. Of course, there should also be a checked variant (but I'm deferring that until after this commit is complete).
Either is fine to me.