
WIP `jax.numpy` array API compliance finalization

Open · Micky774 opened this issue 1 year ago · 1 comment

Towards https://github.com/google/jax/issues/21088

Changes

  • Adds the following attributes to ArrayImpl:
    • __array_namespace__ property
    • to_device method
    • device property
  • Adds the following to jax.numpy:
    • __array_namespace_info__
    • __array_api_version__
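For readers unfamiliar with the array API protocol, the attributes listed above can be sketched with a toy, pure-Python array. This is not JAX's ArrayImpl. `ToyArray`, `toy_np`, the `"cpu:0"`/`"gpu:0"` device strings, and the `"2022.12"` version string are all hypothetical stand-ins chosen for illustration, assuming the 2022 revision of the standard mentioned in the notes below.

```python
import types

# Hypothetical namespace standing in for jax.numpy; only the two
# module-level names listed above are modeled here.
toy_np = types.ModuleType("toy_np")
toy_np.__array_api_version__ = "2022.12"  # assumed standard revision


class _Info:
    """Hypothetical inspection object returned by __array_namespace_info__()."""

    def default_device(self):
        return "cpu:0"

    def devices(self):
        return ["cpu:0", "gpu:0"]


# In the standard, __array_namespace_info__ is a callable that returns
# an inspection object; calling the class gives us exactly that.
toy_np.__array_namespace_info__ = _Info


class ToyArray:
    """Hypothetical array type, NOT jax's ArrayImpl."""

    def __init__(self, data, device="cpu:0"):
        self._data = list(data)
        self._device = device

    def __array_namespace__(self, *, api_version=None):
        # Callers may ask for a specific API version; reject others.
        if api_version is not None and api_version != toy_np.__array_api_version__:
            raise ValueError(f"unsupported api_version: {api_version!r}")
        return toy_np

    @property
    def device(self):
        # A property, so it is accessed as x.device, not x.device().
        return self._device

    def to_device(self, device, /, *, stream=None):
        # Returns a copy on the target device; the original is unchanged.
        if device not in toy_np.__array_namespace_info__().devices():
            raise ValueError(f"unknown device: {device!r}")
        return ToyArray(self._data, device=device)


x = ToyArray([1, 2, 3])
xp = x.__array_namespace__()   # array-API-generic code starts here
y = x.to_device("gpu:0")
print(xp.__array_api_version__, x.device, y.device)
```

Note that `to_device` returns a new array rather than moving the existing one in place, which is how the standard describes the method.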

Notes

This PR is a draft for now: these changes should land last, since adding them publicly advertises jax.numpy as an array API compliant namespace. @vfdev-5 can take over this PR later, once the rest of the work is complete.

This does not need to wait for all of the ongoing array API-related deprecations to be completed, since some of them are only required by the 2023 standard; we can likely adopt the 2022 standard first.

It may make sense to break the to_device and device changes for ArrayImpl off into a small separate PR, since they don't imply explicit compliance by themselves. I kept them in this PR in case there are caveats about ArrayImpl vs. Tracer behavior that we should discuss first (based on an old TODO note).

Micky774 · May 21 '24 01:05

We only removed the arr.device() method in JAX v0.4.27. To avoid confusing users, I think we should wait one more release (0.4.29) before adding the arr.device property, so that it first ships in 0.4.30. What do you think?

jakevdp · May 21 '24 03:05
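The concern about reusing the `device` name can be illustrated with a toy sketch: once `device` becomes a property, old call sites written as `arr.device()` try to call the returned value and fail with a `TypeError`. The classes `OldArray` and `NewArray` and the helper `device_of` are hypothetical and only illustrate the method-to-property switch; they are not part of JAX.

```python
class OldArray:
    """Hypothetical array in the old style: device is a method."""

    def device(self):
        return "cpu:0"


class NewArray:
    """Hypothetical array in the array API style: device is a property."""

    @property
    def device(self):
        return "cpu:0"


def device_of(x):
    """Hypothetical transition helper that accepts either style."""
    d = x.device
    # Bound methods are callable; the property's return value is not.
    return d() if callable(d) else d


print(device_of(OldArray()), device_of(NewArray()))

try:
    NewArray().device()  # an old-style call site after the switch
except TypeError as err:
    print("old call site breaks:", err)
```

Leaving a release gap between removing the method and adding the property means stale `arr.device()` calls fail with an `AttributeError` first, which points more directly at the removal than a "not callable" error does.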

@Micky774 let me close this PR as Array API compliance is already finalized on main.

vfdev-5 · Aug 13 '24 12:08