Refactored common upcast for integral-type accumulators

Open Micky774 opened this issue 1 year ago • 1 comments

Towards https://github.com/google/jax/issues/20200

This PR refactors a few instances of NumPy-style selective upcasting of integral accumulators into a common jax.numpy.util function, _promote_to_default_integral_dtype. It also removes unnecessary casts in the jax.experimental.array_api namespace, since all jax.numpy reductions already perform this upcast by default.
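For reference, the selective upcast follows NumPy's documented rule for sum and prod: integer inputs narrower than the platform default integer accumulate in the default integer type of matching signedness. Below is a minimal sketch in plain NumPy. The helper name is hypothetical and this is not the body of _promote_to_default_integral_dtype. JAX's default integer width also depends on whether 64-bit mode is enabled, which this sketch does not model.

```python
import numpy as np


def promote_to_default_integral_dtype_sketch(dtype):
    """Hypothetical helper: pick a NumPy-style accumulator dtype for sum/prod."""
    dtype = np.dtype(dtype)
    default_int = np.dtype(np.int_)   # platform default signed integer
    default_uint = np.dtype(np.uint)  # unsigned integer of the same width

    # Booleans are counted, so they accumulate in the default integer type.
    if dtype == np.bool_:
        return default_int
    # Narrow signed integers are widened to the default signed integer.
    if np.issubdtype(dtype, np.signedinteger) and dtype.itemsize < default_int.itemsize:
        return default_int
    # Narrow unsigned integers are widened to the default unsigned integer.
    if np.issubdtype(dtype, np.unsignedinteger) and dtype.itemsize < default_uint.itemsize:
        return default_uint
    # Everything else (wide integers, floats, complex) is left unchanged.
    return dtype
```

The result can be compared against NumPy directly, for example by checking that the value returned for np.int8 equals np.sum(np.ones(3, dtype=np.int8)).dtype.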

Temporarily disables tests for prod, sum, and trace, since the 2023 API included breaking changes that are not yet accounted for in the tests repository.

Micky774 avatar Apr 19 '24 19:04 Micky774

@jakevdp should be ready for review

Micky774 avatar Apr 25 '24 19:04 Micky774

Could you sync against the updated main branch and resolve conflicts? Thanks!

jakevdp avatar May 02 '24 19:05 jakevdp

Synced!

Micky774 avatar May 02 '24 20:05 Micky774

It looks like this breaks some tests in lax_numpy_reducers_test.py. You should be able to repro by running

JAX_NUM_GENERATED_CASES=90 pytest -n auto tests/lax_numpy_reducers_test.py -k testCumulativeSum

It looks like this somehow changed the behavior for integer inputs narrower than int32.

jakevdp avatar May 02 '24 22:05 jakevdp

It looks like the core issue is that we don't allow bool dtypes for reductions, because the add and mul primitives are limited to non-boolean numerical types. We adjust for this by upcasting bool to int_ before accumulation, but this differs from NumPy, where the data is converted to bool and addition is implemented as an or operation rather than a count of True values.

~For now I've explicitly disallowed the use of dtype=bool, since I think adjusting that behavior is actually a bit of a deeper change.~

Update: I've set it to keep the bool-->int_ upcast for accumulation and then downcast the result back to bool. This should match NumPy and satisfy the array API dtype behavior as well.
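A small check in plain NumPy illustrates why accumulating in an integer type and casting back gives the same answer as NumPy's boolean accumulation. The helper name is hypothetical and this is only a sketch of the idea, not the code in this PR:

```python
import numpy as np


def cumsum_bool_via_int_sketch(x):
    """Hypothetical helper: accumulate booleans as integers, then cast back to bool."""
    # Running count of True values seen so far.
    counts = np.cumsum(np.asarray(x, dtype=bool).astype(np.int_))
    # A nonzero count means at least one True has been seen, which is a running OR.
    return counts.astype(bool)


mask = np.array([False, True, False, True])

# NumPy adds booleans with logical OR, so dtype=bool yields a running "any".
numpy_result = np.cumsum(mask, dtype=bool)
sketch_result = cumsum_bool_via_int_sketch(mask)
```

Both results are [False, True, True, True] for this input, since a running count is nonzero exactly when the running OR is True.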

Micky774 avatar May 03 '24 20:05 Micky774

@jakevdp should be good for another look now

Micky774 avatar May 03 '24 21:05 Micky774

Good point. I've changed it so that we make a helper reduction _cumsum_with_promotion (as opposed to the usual cumsum), preserving the API for the other cumulative reductions.

Micky774 avatar May 03 '24 21:05 Micky774