jax
Refactored common upcast for integral-type accumulators
Towards https://github.com/google/jax/issues/20200
This PR refactors several instances of NumPy-style selective upcasting of integral accumulators into a common jax.numpy.util function, _promote_to_default_integral_dtype, and removes unnecessary casts in the jax.experimental.array_api namespace, since all jax.numpy reductions already handle this promotion by default.
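To illustrate the kind of promotion rule being centralized, here is a minimal sketch. It is not the PR's _promote_to_default_integral_dtype: the name promote_to_default_integral_dtype below is hypothetical, and the sketch assumes NumPy's rule is the target (bool and narrow signed integers go to the default integer, narrow unsigned integers go to the default unsigned integer, everything else is unchanged).

```python
import numpy as np
import jax.numpy as jnp
from jax import dtypes


def promote_to_default_integral_dtype(dtype):
    """Hypothetical sketch: pick a NumPy-style accumulator dtype."""
    dtype = np.dtype(dtype)
    # canonicalize_dtype maps 64-bit types to 32-bit ones unless
    # jax_enable_x64 is set, so these follow JAX's configured defaults.
    default_int = dtypes.canonicalize_dtype(np.int64)
    default_uint = dtypes.canonicalize_dtype(np.uint64)
    if dtype == np.bool_:
        return default_int
    if np.issubdtype(dtype, np.signedinteger) and dtype.itemsize < default_int.itemsize:
        return default_int
    if np.issubdtype(dtype, np.unsignedinteger) and dtype.itemsize < default_uint.itemsize:
        return default_uint
    # Floats, complex, and already-wide integers are left alone.
    return dtype


# 3 * 100 would overflow int8, but not the promoted accumulator.
x = jnp.array([100, 100, 100], dtype=jnp.int8)
total = jnp.sum(x, dtype=promote_to_default_integral_dtype(x.dtype))
```

Per the PR description, jax.numpy reductions already apply this promotion by default, so a helper like this matters mostly for code paths that compute the accumulator dtype themselves.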
It also temporarily disables the tests for prod, sum, and trace, since the 2023 version of the array API standard introduced breaking changes that the tests repository does not yet account for.
@jakevdp this should be ready for review.
Could you sync against the updated main branch and resolve conflicts? Thanks!
Synced!
It looks like this breaks some tests in lax_numpy_reducers_test.py. You should be able to repro by running
JAX_NUM_GENERATED_CASES=90 pytest -n auto tests/lax_numpy_reducers_test.py -k testCumulativeSum
It looks like this somehow changed the behavior for integer inputs narrower than int32.
It looks like the core issue is bool dtypes. We don't allow bool for reductions because the add and mul primitives are limited to non-boolean numeric types, so we work around this by upcasting bool to int_ before accumulating. This differs from NumPy: with dtype=bool, NumPy converts the data to bool and implements addition as a logical OR, rather than as a count of True values.
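The NumPy behavior described above can be checked directly. This snippet uses only NumPy and contrasts an explicit dtype=bool accumulation with the default, integer-promoting one.

```python
import numpy as np

x = np.array([False, True, False, True])

# With dtype=bool, NumPy's add on booleans acts as logical OR,
# so the running result stays True once any True has been seen.
or_accumulated = np.cumsum(x, dtype=bool)

# Without a dtype, bool is promoted to the default integer and
# the running result counts True values instead.
counted = np.cumsum(x)
```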
~For now I've explicitly disallowed the use of dtype=bool, since I think adjusting that behavior is actually a bit of a deeper change.~
Update: I've kept the bool-->int_ upcast for accumulation and added a downcast back to bool afterwards. This should match NumPy and satisfy the array API dtype behavior as well.
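A minimal JAX sketch of the bool-->int_-->bool round trip is below. The helper name cumsum_as_bool is hypothetical, the default integer is assumed to be whatever jax.dtypes.canonicalize_dtype gives for int64, and the PR's actual code may be structured differently. Because a running count of True values is nonzero exactly when some earlier element was True, casting the count back to bool reproduces NumPy's logical-OR accumulation.

```python
import numpy as np
import jax.numpy as jnp
from jax import dtypes


def cumsum_as_bool(a, axis=None):
    """Hypothetical sketch of the bool -> int_ -> bool round trip."""
    a = jnp.asarray(a)
    default_int = dtypes.canonicalize_dtype(np.int64)
    # Cast to bool first so any nonzero input counts as a single True,
    # then upcast so the add primitive can accumulate it.
    as_int = a.astype(bool).astype(default_int)
    # Running count of True values; nonzero iff an earlier element was True.
    counts = jnp.cumsum(as_int, axis=axis)
    # Downcast back to bool to get logical-OR accumulation.
    return counts.astype(bool)


x = np.array([False, True, False, True])
result = cumsum_as_bool(x)
```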
@jakevdp should be good for another look now
Good point. I've changed it so that we use a helper reduction, _cumsum_with_promotion (as opposed to the usual cumsum), which preserves the API for the other cumulative reductions.
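One way such a helper could be organized is sketched below, assuming only cumsum needs the bool special case while jnp.cumprod and the other cumulative reductions keep their usual implementations. The name cumsum_with_promotion and its signature are hypothetical, not the PR's _cumsum_with_promotion.

```python
import numpy as np
import jax.numpy as jnp
from jax import dtypes


def cumsum_with_promotion(a, axis=None, dtype=None):
    """Hypothetical cumsum-only wrapper; other reductions are untouched."""
    a = jnp.asarray(a)
    default_int = dtypes.canonicalize_dtype(np.int64)
    if dtype is not None and np.dtype(dtype) == np.bool_:
        # Requested bool output: count Trues in an integer accumulator,
        # then cast back, which matches NumPy's logical-OR accumulation.
        counts = jnp.cumsum(a.astype(bool).astype(default_int), axis=axis)
        return counts.astype(bool)
    if dtype is None and a.dtype == np.bool_:
        # Bool input without a dtype: accumulate as default-integer counts.
        dtype = default_int
    if dtype is not None:
        a = a.astype(dtype)
    # All remaining cases defer to the regular jnp.cumsum.
    return jnp.cumsum(a, axis=axis)
```

Keeping the special case in a cumsum-specific wrapper means the shared cumulative-reduction machinery does not need a bool code path at all.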