
Wrong result for unsigned dtype input into `jax.numpy.partition` and `jax.numpy.argpartition`

Open JuliaPoo opened this issue 1 week ago • 0 comments

Description

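`partition` and `argpartition` return wrong results for arrays with unsigned dtypes. This is probably caused by the following line: `bottom_ind = lax.top_k(-arr, kth + 1)[1]`. Negating an unsigned array wraps around (for `uint8`, `-0` stays `0` but `-1` becomes `255`), so `-arr` no longer orders the elements in reverse, and `top_k(-arr, ...)` does not select the smallest elements.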
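The wraparound can be reproduced in isolation. This is a minimal sketch assuming the same JAX 0.4.30 CPU setup; it mirrors only the quoted `lax.top_k(-arr, kth + 1)` line, not the full `argpartition` code:

```python
import jax.numpy as jnp
from jax import lax

arr = jnp.asarray([0, 1], dtype=jnp.uint8)

# Unary minus on an unsigned array wraps modulo 2**8:
# -0 stays 0, but -1 becomes 255.
neg = -arr

# top_k(-arr, k) is meant to pick the k smallest elements of arr.
# After wraparound, the element 1 (now 255) is ranked as the smallest.
_, bottom_ind = lax.top_k(neg, 2)

print(neg.tolist())         # [0, 255]
print(bottom_ind.tolist())  # [1, 0]
```

The index order `[1, 0]` is exactly the wrong `argpartition` result shown for `uint8` in the POC.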

POC:

>>> from jax import numpy as np
>>> np.partition(np.asarray([0,1], dtype=np.uint8), 1) 
Array([1, 0], dtype=uint8)
>>> np.partition(np.asarray([0,1], dtype=np.int8), 1)  
Array([0, 1], dtype=int8)
>>> np.argpartition(np.asarray([0,1], dtype=np.uint8), 1) 
Array([1, 0], dtype=int32)
>>> np.argpartition(np.asarray([0,1], dtype=np.int8), 1)  
Array([0, 1], dtype=int32)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.9.19 (main, May  6 2024, 20:12:36) [MSC v.1916 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Yeetus-Screetus', release='10', version='10.0.19045', machine='AMD64')

JuliaPoo, Jun 27 '24 10:06