
Tracking issue: support Array API

Open jakevdp opened this issue 8 months ago • 0 comments

Goal: make JAX support the Python array API standard: https://data-apis.org/array-api/latest/

Related to #19246

TODO

  1. Initial Implementation

    • [x] Add initial implementation in jax.experimental.array_api #16099
    • [x] Add CI test based on https://github.com/data-apis/array-api-tests #16099
    • [x] Add smoketest for normal CI runs #18685
    • [ ] Enable fft_tests (requires waiting on upstream test fixes)
  2. JAX API fixes

    • [x] Add JAX support for scalar boolean indexing #19722 #21305
    • [x] Fix NaN identity issue within unique #19090
    • [x] Add descending argument to sort and argsort #19201
  3. Make jax.Array conform to the API spec

    • [x] Deprecate device() method #18730
    • [ ] Add device property (after device() method is removed; ~March 2024)
    • [ ] Add to_device() method
    • [ ] Add device keyword to zeros, ones, arange, etc. (#19445, #19466, #19470, #19504)
    • [ ] Add __array_namespace__ method
  4. Add Array API functions to the standard jax.numpy namespace

    • [x] jnp.bool #19403
    • [x] jnp.isdtype #19400
    • [x] jnp.astype #18757
    • [x] unique_all, unique_counts, unique_inverse, unique_values #19088
    • [x] concat #19323
    • [x] permute_dims #19244
    • [x] acos, acosh, asin, asinh, atan, atanh, atan2 #19054
    • [x] bitwise_left_shift, bitwise_right_shift, bitwise_invert #19278
    • [x] copy keyword argument for jnp.asarray #19186
    • [x] jnp.linalg:
      • [x] diagonal #19321
      • [x] cross #18928
      • [x] matmul #19042
      • [x] matrix_norm #19005
      • [x] matrix_transpose #19005
      • [x] outer #18928
      • [x] svdvals #19042
      • [x] tensordot #19042
      • [x] vecdot #19005
      • [x] vector_norm #19005
      • [x] eigh returns NamedTuple #19347
      • [x] qr returns NamedTuple #19347
      • [x] slogdet returns NamedTuple #19347
      • [x] svd returns NamedTuple #19347
      • [x] cholesky upper argument #19606
      • [x] solve vectorization update #19674
  5. Update to v2023.12 APIs and behavior (see changelog)

  6. Consider removing jax.experimental.array_api and making jax.numpy itself fully compliant with the Array API.
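
For illustration, here is a minimal sketch of how a few of the checked-off items above look from user code. It assumes a JAX release recent enough to include those changes; the input values and variable names are made up for the example and are not taken from any of the linked PRs.

```python
import jax.numpy as jnp

# Illustrative input; the values are arbitrary.
x = jnp.array([3, 1, 2, 3, 1])

# Section 4: dtype query and functional astype in the jax.numpy namespace.
is_int = jnp.isdtype(x.dtype, "integral")
xf = jnp.astype(x, jnp.float32)

# Section 4: Array API spellings for unique, concatenate and transpose.
unique = jnp.unique_values(x)
joined = jnp.concat([x, x])
transposed = jnp.permute_dims(jnp.ones((2, 3)), (1, 0))

# Section 2: `descending` argument on sort.
desc = jnp.sort(x, descending=True)

# Section 4: jnp.linalg additions; qr returns a NamedTuple with Q and R fields.
norm = jnp.linalg.vector_norm(jnp.array([3.0, 4.0]))
qr_result = jnp.linalg.qr(jnp.eye(2))
q, r = qr_result.Q, qr_result.R
```

Because the results are NamedTuples, existing code that unpacks them positionally (`q, r = jnp.linalg.qr(a)`) should keep working alongside the attribute access shown here.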

jakevdp · Nov 02 '23 15:11