Wrong result for unsigned dtype input into `jax.numpy.partition` and `jax.numpy.argpartition`
Description
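`partition` and `argpartition` return wrong results for arrays with an unsigned dtype. The problem is probably caused by this line: `bottom_ind = lax.top_k(-arr, kth + 1)[1]`. For unsigned dtypes, `-arr` wraps around instead of reversing the order: `0` stays `0`, while every nonzero `x` becomes `2**nbits - x` (for `uint8`, `-1` is `255`). The negation therefore reverses the order of the nonzero values but leaves zeros at the bottom, so `top_k(-arr, ...)` ranks zeros as if they were the largest elements.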
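To illustrate the suspected cause, the sketch below shows the wraparound of unary minus on `uint8` and one order-reversing key that avoids it. `bottom_k_indices` is a hypothetical helper written for this illustration, not JAX's API or its actual fix; it assumes that for integer dtypes bitwise NOT (`~x`) is an acceptable sort key, since it is strictly decreasing and cannot overflow.

```python
import jax.numpy as jnp
from jax import lax

u = jnp.asarray([0, 5, 3], dtype=jnp.uint8)

# Unary minus on uint8 wraps modulo 256: 0 stays 0, nonzero x becomes 256 - x.
print(-u)  # values: 0, 251, 253
# top_k(-u, 1) therefore picks 3 (index 2) as the "smallest" element, not 0.
print(lax.top_k(-u, 1)[1])  # index: 2


def bottom_k_indices(arr, k):
    """Hypothetical helper: indices of the k smallest entries along the last axis."""
    if jnp.issubdtype(arr.dtype, jnp.integer):
        # Bitwise NOT is strictly decreasing and never overflows for integer
        # dtypes (~x == iinfo.max - x for unsigned, -x - 1 for signed).
        key = ~arr
    else:
        # Plain negation is order-reversing for floating-point values.
        key = -arr
    return lax.top_k(key, k)[1]


print(bottom_k_indices(u, 1))  # index: 0
```

Using `~arr` for every integer dtype (not only unsigned ones) also sidesteps the signed edge case where negating the minimum value, e.g. `-(-128)` in `int8`, overflows back to itself.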
POC:
```python
>>> from jax import numpy as np
>>> np.partition(np.asarray([0,1], dtype=np.uint8), 1)
Array([1, 0], dtype=uint8)
>>> np.partition(np.asarray([0,1], dtype=np.int8), 1)
Array([0, 1], dtype=int8)
>>> np.argpartition(np.asarray([0,1], dtype=np.uint8), 1)
Array([1, 0], dtype=int32)
>>> np.argpartition(np.asarray([0,1], dtype=np.int8), 1)
Array([0, 1], dtype=int32)
```
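A possible user-side workaround, sketched under stated assumptions: upcast `uint8`/`uint16` input to a signed dtype that holds its full range, partition, then cast the values back (indices from `argpartition` need no cast). `partition_workaround`, `argpartition_workaround`, and `_WIDER_SIGNED` are hypothetical names for this sketch. `uint32`/`uint64` are not handled, because they would need 64-bit signed types, which JAX only enables when `jax_enable_x64` is set.

```python
import jax.numpy as jnp

# Hypothetical map: unsigned dtype -> signed dtype that covers its whole range.
_WIDER_SIGNED = {jnp.dtype(jnp.uint8): jnp.int16, jnp.dtype(jnp.uint16): jnp.int32}


def partition_workaround(arr, kth):
    wider = _WIDER_SIGNED.get(arr.dtype)
    if wider is None:
        # Other dtypes (including uint32/uint64) go through unchanged.
        return jnp.partition(arr, kth)
    # Partition in the signed dtype, then restore the original dtype.
    return jnp.partition(arr.astype(wider), kth).astype(arr.dtype)


def argpartition_workaround(arr, kth):
    wider = _WIDER_SIGNED.get(arr.dtype)
    if wider is None:
        return jnp.argpartition(arr, kth)
    # Indices are dtype-independent, so no cast back is needed.
    return jnp.argpartition(arr.astype(wider), kth)
```

Upcasting keeps every value non-negative and well inside the signed range, so negation inside `partition` no longer wraps.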
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.0
python: 3.9.19 (main, May 6 2024, 20:12:36) [MSC v.1916 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Yeetus-Screetus', release='10', version='10.0.19045', machine='AMD64')
```