Tracking issue: support Array API
Goal: make JAX support the Python array API standard (https://data-apis.org/array-api/latest/)
Related to #19246
TODO
- Initial Implementation
  - [x] Add initial implementation in `jax.experimental.array_api` #16099
  - [x] Add CI test based on https://github.com/data-apis/array-api-tests #16099
  - [x] Add smoke test for normal CI runs #18685
  - [ ] Enable `fft_tests` (requires waiting on upstream test fixes)
- JAX API fixes
  - [x] Add JAX support for scalar boolean indexing #19722 #21305
  - [x] Fix NaN identity issue within `unique` #19090
  - [x] Add `descending` argument to `sort` and `argsort` #19201
- Make `jax.Array` conform to the API spec
  - [x] Deprecate `device()` method #18730
  - [ ] Add `device` property (after the `device()` method is removed; ~March 2024)
  - [ ] Add `to_device()` method
  - [ ] Add `device` keyword to `zeros`, `ones`, `arange`, etc. (#19445, #19466, #19470, #19504)
  - [ ] Add `__array_namespace__` method
- Add Array API functions to the standard `jax.numpy` namespace
  - [x] `jnp.bool` #19403
  - [x] `jnp.isdtype` #19400
  - [x] `jnp.astype` #18757
  - [x] `unique_all`, `unique_counts`, `unique_inverse`, `unique_values` #19088
  - [x] `concat` #19323
  - [x] `permute_dims` #19244
  - [x] `acos`, `acosh`, `asin`, `asinh`, `atan`, `atanh`, `atan2` #19054
  - [x] `bitwise_left_shift`, `bitwise_right_shift`, `bitwise_invert` #19278
  - [x] `copy` keyword argument for `jnp.asarray` #19186
  - [x] `jnp.linalg`:
    - [x] `diagonal` #19321
    - [x] `cross` #18928
    - [x] `matmul` #19042
    - [x] `matrix_norm` #19005
    - [x] `matrix_transpose` #19005
    - [x] `outer` #18928
    - [x] `svdvals` #19042
    - [x] `tensordot` #19042
    - [x] `vecdot` #19005
    - [x] `vector_norm` #19005
    - [x] `eigh` returns `NamedTuple` #19347
    - [x] `qr` returns `NamedTuple` #19347
    - [x] `slogdet` returns `NamedTuple` #19347
    - [x] `svd` returns `NamedTuple` #19347
    - [x] `cholesky` `upper` argument #19606
    - [x] `solve` vectorization update #19674
- Update to v2023.12 APIs and behavior (see changelog)
- Consider removing `jax.experimental.array_api` and making `jax.numpy` itself fully compliant with the array API.
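For orientation, here is a minimal sketch exercising a few of the checked-off `jax.numpy` additions next to the older JAX/NumPy-style spellings they mirror. It assumes a JAX release recent enough to include the PRs listed above; this issue does not state exact version requirements.

```python
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)

# Array API names in jax.numpy, next to the older spelling they mirror
y = jnp.permute_dims(x, (1, 0))          # like jnp.transpose(x, (1, 0))
z = jnp.concat([x, x], axis=0)           # like jnp.concatenate([x, x], axis=0)
f = jnp.astype(x, jnp.float32)           # function form of x.astype(jnp.float32)
shifted = jnp.bitwise_left_shift(x, 1)   # like jnp.left_shift(x, 1)
angle = jnp.acos(jnp.array(1.0))         # like jnp.arccos(1.0)
mask = jnp.astype(x, jnp.bool)           # jnp.bool is the Array API boolean dtype

print(y.shape, z.shape, f.dtype)              # (3, 2) (4, 3) float32
print(jnp.isdtype(f.dtype, "real floating"))  # True
print(shifted)
```

The older names remain available, so existing code should keep working; the Array API spellings matter mainly when the same function must also run against other compliant array libraries.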
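The `descending` argument to `sort`/`argsort` (#19201) and the `unique_*` family (#19088) can be sketched as follows. The example checks `x[order]` against the sorted values rather than hard-coding indices, so it does not depend on how ties are ordered.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 3, 1, 3])

desc = jnp.sort(x, descending=True)       # [3 3 3 2 1 1]
order = jnp.argsort(x, descending=True)   # indices that put x in descending order

# unique_counts returns a named tuple with `values` and `counts` fields
res = jnp.unique_counts(x)

print(desc)
print(x[order])                           # same values as desc
print(res.values, res.counts)             # [1 2 3] [2 1 3]
```

Because the number of unique values depends on the data, calling these functions under `jax.jit` generally requires fixing the output size ahead of time; `jnp.unique` does this through its `size` argument.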
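Several `jnp.linalg` items above change how results are returned rather than what is computed. This sketch uses a small symmetric positive-definite matrix chosen purely for illustration to show some of the new functions, named-tuple results (#19347), and the `upper` argument to `cholesky` (#19606).

```python
import jax.numpy as jnp

a = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # illustrative symmetric positive-definite matrix
v = jnp.array([3.0, 4.0])

norm = jnp.linalg.vector_norm(v)          # Euclidean norm: 5.0
dot = jnp.linalg.vecdot(v, v)             # dot product along the last axis: 25.0
fro = jnp.linalg.matrix_norm(a)           # Frobenius norm by default
at = jnp.linalg.matrix_transpose(a)       # swaps the last two axes

# Decompositions return named tuples: access fields by name or unpack positionally
eig = jnp.linalg.eigh(a)
w, vecs = eig                             # same as eig.eigenvalues, eig.eigenvectors
sign, logabsdet = jnp.linalg.slogdet(a)   # det(a) = 8, so sign = 1 and logabsdet = log(8)

# upper=True returns the upper-triangular factor U such that a == U.T @ U
u = jnp.linalg.cholesky(a, upper=True)

print(norm, dot, fro)
print(eig.eigenvalues)
print(jnp.allclose(u.T @ u, a))
```

Since named tuples are still tuples, code that unpacks `eigh`, `qr`, `slogdet`, or `svd` positionally should be unaffected by the change to named results.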
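The end goal, with `jax.numpy` itself serving as an array API namespace, is what lets library code avoid importing a specific array library. Below is a minimal sketch of that pattern. `get_namespace` and `standardize` are hypothetical helpers. The code assumes a `jax.Array` may or may not expose `__array_namespace__()` depending on the JAX version, and falls back to `jax.numpy` when it does not.

```python
import jax.numpy as jnp


def get_namespace(x):
    """Hypothetical helper: return the array API namespace for x.

    Uses the standard __array_namespace__() method when the array provides
    one; otherwise (e.g. older JAX versions) assumes jax.numpy.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return jnp


def standardize(x):
    """Hypothetical library function written only against array API names."""
    xp = get_namespace(x)
    mean = xp.mean(x, axis=0)
    centered = x - mean
    scale = xp.sqrt(xp.mean(centered * centered, axis=0))
    return centered / scale


x = jnp.array([[1.0, 10.0], [3.0, 30.0]])
out = standardize(x)
print(out)  # each column rescaled to zero mean and unit (population) standard deviation
```

The fallback to `jax.numpy` only keeps this sketch self-contained. Real libraries typically delegate the lookup to a compatibility layer such as `array-api-compat`, whose `array_namespace()` helper serves this purpose.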