
GH-35166: [C++] Increase precision of decimals in aggregate functions

Open · opened by khwilson · 17 comments

Rationale for this change

As documented in #35166, when Arrow performs a sum, product, or mean of an array of type decimalX(P, S), it returns a scalar of type decimalX(P, S). This is true even if the aggregate does not fit in the specified precision. For instance, a sum of two decimal128(1, 0)'s such as 1 + 9 is a decimal128(2, 0). But (in Python):

import pyarrow as pa
from decimal import Decimal

arr = pa.array([Decimal("1"), Decimal("9")], type=pa.decimal128(1, 0))
assert arr.sum().type == pa.decimal128(1, 0)

This is recognized in the rules for binary addition and multiplication of decimals (see footnote 1 in this section), but those rules do not apply to array aggregates.
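
For contrast, the binary kernels do widen. A quick check of that behavior, assuming the promotion rules described in that footnote:

import pyarrow as pa
import pyarrow.compute as pc
from decimal import Decimal

a = pa.array([Decimal("1")], type=pa.decimal128(1, 0))
b = pa.array([Decimal("9")], type=pa.decimal128(1, 0))
# Element-wise add follows the documented decimal promotion rules,
# so the result type widens to decimal128(2, 0).
assert pc.add(a, b).type == pa.decimal128(2, 0)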

In #35166 I did a bit of research following a question from @westonpace, and it seems that there's no standard approach to this across DBMSs, but a common solution is to set the precision of the result of a sum to the maximum possible precision of the underlying type. That is, a sum of decimal128(1, 0)'s becomes a decimal128(38, 0).

However, products and means differ further. For instance, for both of these aggregates duckdb converts the decimal to a double, which makes sense, as the precision of the product of an array of decimals would likely be huge: a product over an array of N precision-2 decimals can require precision up to 2N.

This PR implements the minimal possible change: the return type of a product, sum, or mean aggregate becomes decimal128(38, S) for a decimal128(P, S) input and decimal256(76, S) for a decimal256(P, S) input.
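
Concretely, revisiting the example above (the comments describe the proposed behavior; released pyarrow still returns the input type):

import pyarrow as pa
import pyarrow.compute as pc
from decimal import Decimal

arr = pa.array([Decimal("1"), Decimal("9")], type=pa.decimal128(1, 0))
s = pc.sum(arr)
# Before this PR: s.type == pa.decimal128(1, 0)
# After this PR:  s.type == pa.decimal128(38, 0); likewise a decimal256(P, S)
#                 sum comes back as decimal256(76, S).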

Please note, this PR is not done (see the checklist below), as it would be good to get feedback on the following before going through the whole checklist:

  • Whether this is the correct change to make (especially for products and means); and
  • Whether there is a better way to approach the implementation, which currently relies on overriding "out types" (the way the existing mean implementation works).

What changes are included in this PR?

  • [x] Update C++ kernels to support the change
  • [x] Update docs to reflect the change
  • [x] Fix tests in the languages that depend on the C++ engine
  • [x] Determine if there are other languages which do not depend on the C++ engine which should also be updated

Are these changes tested?

They are tested in the following implementations:

  • [x] Python
  • [x] C++
  • [ ] ~~Java~~ Java does not currently implement any dependent tests

Are there any user-facing changes?

Yes. This changes the return type of a scalar aggregate of decimals.

This PR includes breaking changes to public APIs.

Specifically, the return type of a scalar aggregate of decimals changes. This is unlikely to break downstream applications as the underlying data has not changed, but if an application relies on the (incorrect!) type information for some reason, it would break.
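
For example, hypothetical downstream code like the following, which assumes the aggregate keeps the input type, would need a small update:

import pyarrow as pa
import pyarrow.compute as pc
from decimal import Decimal

decimal_array = pa.array([Decimal("1"), Decimal("9")], type=pa.decimal128(1, 0))
# This assumption holds before the change and would fail after it:
assert pc.sum(decimal_array).type == decimal_array.type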

  • GitHub Issue: #35166

khwilson (Sep 23 '24)


@zeroshade I noticed in #43957 that you were adding in Decimal32/64 types, which I think will have the same problem that this PR addresses. I was curious if you might have interest in reviewing this PR?

khwilson (Sep 30 '24)

@khwilson Sure thing, I'll try to take a look at this in the next day or so

zeroshade (Sep 30 '24)

Hi @zeroshade just checking in! Thanks again for taking a look

khwilson (Oct 08 '24)

Thanks for the review!

By a user-side cast, do you mean that users should essentially do:

select avg(cast(blah as decimal(big-precision)))

instead of

select avg(blah)

or do you mean that this code should "inject" a cast on the "user" side?

If you mean putting the cast onto the user, then I would think you'd want to add an error if the answer can't fit into the default precision, but that seems like it would be more disruptive (and out of step with how other systems handle decimal aggregates).

If you mean "injecting" the cast on the user side, would that end up creating a copy of the array?

khwilson (Oct 09 '24)

@khwilson Hey, sorry for the delay from me here, I've been traveling a lot lately for work and have been at ASF Community Over Code this week. I promise I'll get to this soon. In the meantime, you're in the very capable hands of @mapleFU

zeroshade (Oct 09 '24)

No problem! Hope your travels were fun!

khwilson (Oct 09 '24)

Generally this method is ok for me, but I'm not so familiar with the "common solutions" here. I'll dive into Presto/ClickHouse to see the common pattern here

mapleFU (Oct 11 '24)

I enumerated several here: https://github.com/apache/arrow/issues/35166#issuecomment-2336776704

ClickHouse, for instance, just ignores precision.

khwilson (Oct 11 '24)

Would you mind making this Ready for review?

mapleFU (Oct 11 '24)

Sure!

khwilson (Oct 11 '24)

@mapleFU I believe this is done now. Some notes on the diff:

  • The hash aggregates had to be updated as well (missed them in the first pass)
  • I've also added Decimal32/64 support for the basic aggregates (sum, product, mean, min/max, index). However, there's still quite a lot of missing support for these types in compute (most notably in casts)
  • Docs are updated to reflect the change

And a note that quite a few tests are failing for what appears to be the same reason as #41390. Happy to address them if you'd like.

khwilson (Oct 13 '24)

I'm lukewarm about the approach here. Silently casting to the max precision discards metadata about the input; it also risks producing errors further down the line (if e.g. the max precision is deemed too large for other operations). It also doesn't automatically eliminate any potential overflow, for example:

>>> a = pa.array([789.3] * 20).cast(pa.decimal128(38, 35))
>>> a
<pyarrow.lib.Decimal128Array object at 0x7f0f103ca7a0>
[
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440,
  789.29999999999995452526491135358810440
]
>>> pc.sum(a)
<pyarrow.Decimal128Scalar: Decimal('-1228.11834604692408266343214451664848480')>

We should instead check that the result of an aggregate fits into the resulting Decimal type, while overflows currently pass silently:

>>> a = pa.array([123., 456., 789.]).cast(pa.decimal128(4, 1))
>>> a
<pyarrow.lib.Decimal128Array object at 0x7f0ed06261a0>
[
  123.0,
  456.0,
  789.0
]
>>> pc.sum(a)
<pyarrow.Decimal128Scalar: Decimal('1368.0')>
>>> pc.sum(a).validate(full=True)
Traceback (most recent call last):
  ...
ArrowInvalid: Decimal value 13680 does not fit in precision of decimal128(4, 1)

pitrou (Oct 15 '24)

Two problems with just validating afterward: First, I'd expect validation to fail in perfectly reasonable cases. A sum of 1M decimals of approximately the same magnitude should be expected to need about 6 more digits of precision. I assume this is why all the DBMSs I looked at increase the precision by default.
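
A concrete instance of that first concern (a sketch; a million one-digit decimals legitimately need about seven digits):

import pyarrow as pa
import pyarrow.compute as pc
from decimal import Decimal

a = pa.array([Decimal("9")] * 1_000_000, type=pa.decimal128(1, 0))
# The true sum, 9000000, needs precision 7, so a post-hoc fit check against
# decimal128(1, 0) would reject a perfectly ordinary aggregation.
print(pc.sum(a))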

Second, just checking for overflow doesn't solve the underlying problem. Consider:

import pyarrow as pa
import pyarrow.compute as pc

a = pa.array([789.3] * 18).cast(pa.decimal128(38, 35))
print(pc.sum(a))
pc.sum(a).validate(full=True)  # passes

In duckdb, they implement an intermediate check to make sure that there's not an internal overflow:

import duckdb

# `a` is the 18-element decimal128(38, 35) array from the previous snippet
tab = pa.Table.from_pydict({"a": a})
duckdb.query("select sum(a) from tab")
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# duckdb.duckdb.OutOfRangeException: Out of Range Error: Overflow in HUGEINT addition: 
# 157859999999999990905052982270717620880 + 78929999999999995452526491135358810440

Notably, this lack of overflow checking also applies to integer sums in arrow:

>>> pa.array([9223372036854775800] * 2, type=pa.int64())
<pyarrow.lib.Int64Array object at 0x10c1d8b80>
[
  9223372036854775800,
  9223372036854775800
]
>>> pc.sum(pa.array([9223372036854775800] * 2, type=pa.int64()))
<pyarrow.Int64Scalar: -16>
>>> pc.sum(pa.array([9223372036854775800] * 2, type=pa.int64())).validate(full=True)

khwilson (Oct 15 '24)

Two problems with just validating afterward: First, I'd expect validation to fail in perfectly reasonable cases. A sum of 1M decimals of approximately the same magnitude should be expected to need about 6 more digits of precision.

That obviously depends on whether all the decimals have the same sign, and on their actual magnitude.

Second, just checking for overflow doesn't solve the underlying problem.

In the example above, I used a validate call simply to show that the result was indeed erroneous. I didn't mean we should actually call validation afterwards. We should instead check for overflow at each individual aggregation step (for each add or multiply, for example). This is required even if we were to bump the result's precision to the max.
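
To make that concrete, a rough pure-Python sketch of per-step checking on the unscaled representation (not the actual C++ kernel):

MAX_UNSCALED_128 = 10**38 - 1  # largest unscaled value representable at precision 38

def checked_decimal_sum(unscaled_values):
    # Check at every addition step instead of validating the final result.
    total = 0
    for v in unscaled_values:
        total += v
        if abs(total) > MAX_UNSCALED_128:
            raise OverflowError("decimal128 overflow in running sum")
    return total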

pitrou (Oct 15 '24)

Notably, this lack of overflow checking also applies to integer sums in arrow:

Yes, and there's already a bug open for it: https://github.com/apache/arrow/issues/37090

pitrou (Oct 15 '24)

Nice! I'm excited for the checked variants of sum and product!

With the integer overflow example, I only meant to point out that the compute module currently allows overflows, so I think it would be unexpected for sum to complain about an overflow only if the underlying type was a decimal. But if the goal with #37090 is to replace sum with a checked version, then the solution of erroring makes a lot of sense, and I'd be happy to implement it when #37536 gets merged. :-)

Still, I do think users would find it unexpected to get an error when the sum fits in the underlying storage, since that is how all the databases I've used (and the four I surveyed in #35166) behave.

khwilson (Oct 16 '24)

Hi @pitrou sorry for dropping the ball on this a bit earlier. Are you still interested in merging this at some point? I'm happy to explore some alternative options for how to do the decimal expansion, or alternatively am happy to push #37536 forward

khwilson (Nov 23 '24)

Hey @pitrou i was thinking about this PR again recently. Would you still be interested in merging if I cleaned it up a bit?

khwilson (Apr 24 '25)

@khwilson I'm still not sure this is actually desirable (@zanmato1984 what do you think?).

What I think is definitely desirable is to check for overflow when doing sum/mean/product of decimals, regardless of which data type is selected for output.

pitrou (Apr 30 '25)

@khwilson I'm still not sure this is actually desirable (@zanmato1984 what do you think?).

Most DBMSes do precision promotion for sum aggregation, and most promotions are arbitrary. So I think the idea of this issue/PR makes a lot of sense. I will look further into the implementation later.

What I think is definitely desirable is to check for overflow when doing sum/mean/product of decimals, regardless of which data type is selected for output.

Agreed, checked versions are necessary, though imo, independent of precision promotion.

zanmato1984 (Apr 30 '25)

Cool. I’m happy to do some more clean up or do a little work on creating the checked versions of the aggregates to push this along

khwilson (May 01 '25)

Cool. I’m happy to do some more clean up or do a little work on creating the checked versions of the aggregates to push this along

Thank you for the contribution and passion. How about we keep this PR focused solely on the precision promotion? And the potential checked version can be done in another PR.

zanmato1984 (May 01 '25)

That makes sense. From a sequencing standpoint, should the checked versions come first, or should we get this through first?

khwilson (May 01 '25)

I think they are fairly independent, so feel free to start anytime (reviewers may take them one by one, though). And thank you for being willing to take this on!

zanmato1984 (May 01 '25)

OK, I think I've correctly rebased (letting all the tests run now to find out).

I can remove the Decimal32/64 support once I check if I've rebased correctly.

Will also look at the inline comments once the rebase is completed.

khwilson (May 06 '25)

OK @zanmato1984, it took a bit given the way everything gets set up. Updates:

  • Removed all the Decimal32/64 logic (except in WidenDecimalToMaxPrecision)
  • Introduced a PromoteDecimal template variable to determine whether a SumLike should actually promote the decimal
  • Reverted the Mean and Product promotions, as we haven't agreed on how to handle those.

As for the mean and product, there are a few actions we could take:

  • Status quo: Input type is the same as output type
  • Promote both to maximum precision
  • Promote only product to max precision: abs(mean) <= max(abs(max), abs(min)) so not really necessary for mean
  • Promote only product to a double: this is what, e.g., duckdb does

I think the best course of action is actually the fourth choice: Promote only product to double and leave mean the same as the input type. Thoughts?

khwilson (May 11 '25)

As for the mean and product, there are a few actions we could take:

  • Status quo: Input type is the same as output type
  • Promote both to maximum precision
  • Promote only product to max precision: abs(mean) <= max(abs(max), abs(min)) so not really necessary for mean
  • Promote only product to a double: this is what, e.g., duckdb does

"Promoting" to double is a demotion: it loses a lot of potential precision. The mean should certainly stay the same as the input type (and care should be taken to avoid overflows during intermediate computations, if at all possible).

As for the product, it would also be better to not lose too much precision, and therefore stay in the decimal domain. But it sounds less common to compute the product of a bunch of decimals, anyway.

pitrou (May 11 '25)

If I understand the code correctly, what happens for the mean of decimals is that it first computes the sum using the Decimal type's underlying +, then computes the count of values as a long, and does an integer division of the two. As seen in this commit, that + operator essentially ignores precision. I think this is a reasonable implementation except:

  • There should be a checked variant;
  • Potentially, you'd want to add some extra scale, in line with the binary division operation.
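
Roughly, in Python terms, my reading of the kernel is as follows (a sketch on the unscaled integers, ignoring rounding details):

def decimal_mean_unscaled(unscaled_values):
    # Sum with plain integer addition (the Decimal operator+ effectively
    # ignores precision), then integer-divide by the count; the result
    # keeps the input scale.
    total = 0
    for v in unscaled_values:
        total += v
    return total // len(unscaled_values)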

For comparison, I checked how a few other DBMSs handle the mean:

  • postgresql: Drops all precision and scale information in mean
  • duckdb: Turns into a double
  • mariadb: Appears to just add 4 to scale and precision

For product, it's not very common for databases to provide a product aggregator of any kind. I quickly looked at ANSI SQL, postgres, mysql, and clickhouse, none of which provide one. The "standard" workaround for such aggregates tracks the sign separately and computes exp(sum(ln(abs(col)))), but that requires converting to a double. Moreover, the product of basically any column with at least 100 values will overflow a decimal's precision.
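
For illustration, that workaround in Python (a sketch; it ignores zeros and NULLs and, as noted, goes through double precision):

import math

def product_via_logs(values):
    # The sign is negative iff an odd number of inputs are negative.
    negatives = sum(1 for v in values if v < 0)
    sign = -1 if negatives % 2 else 1
    return sign * math.exp(sum(math.log(abs(float(v))) for v in values))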

So I'd propose for product either demoting to double (will give a semantically correct value but lose precision) or maxing out the width of the decimal type. Of course, there should also be a checked variant (but I'm deferring that until after this commit is complete).

khwilson (May 13 '25)

If I understand the code correctly, what happens for the mean of decimals is that it first computes the sum using the Decimal type's underlying +, then computes the count of values as a long, and does an integer division of the two. As seen in this commit, that + operator essentially ignores precision. I think this is a reasonable implementation except:

  • There should be a checked variant;
  • Potentially, you'd want to add some extra scale, in line with the binary division operation.

Ideally, the scale should be "floating" just as in floating-point arithmetic, depending on the current running sum (the running sum can be very large if all data is positive, or very small if the data is centered around zero). It is then normalized to the original scale at the end. But of course that makes the algorithm more involved.

(that would also eliminate the need for a checked variant?)

So I'd propose for product either demoting to double (will give a semantically correct value but lose precision) or maxing out the width of the decimal type. Of course, there should also be a checked variant (but I'm deferring that until after this commit is complete).

Either is fine with me.

pitrou (May 13 '25)