jax
WIP `jax.numpy` array API compliance finalization
Towards https://github.com/google/jax/issues/21088
Changes
- Adds the following attributes to `ArrayImpl`:
  - `__array_namespace__` property
  - `to_device` method
  - `device` property
- Adds the following to `jax.numpy`:
  - `__array_namespace_info__`
  - `__array_api_version__`
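For context, here is a minimal sketch of what these `ArrayImpl` attributes enable for library-agnostic code: a consumer looks up the namespace through `__array_namespace__()` and never imports the array library directly. `ToyArray`, `toy_namespace`, `sum_of_squares`, and the device strings are hypothetical stand-ins (for `ArrayImpl`, `jax.numpy`, and real devices); the `to_device(device, /, *, stream=None)` signature follows the array API standard.

```python
import math
import types


class ToyArray:
    """Hypothetical minimal array type standing in for JAX's ArrayImpl.

    It models only the attributes this PR adds: the __array_namespace__
    hook, the to_device method, and the device property.
    """

    def __init__(self, data, device="cpu"):
        self.data = list(data)
        self._device = device

    def __array_namespace__(self, *, api_version=None):
        # Array API hook: return the namespace whose functions operate on
        # this array type (jax.numpy in the JAX case). api_version is
        # accepted but not checked in this sketch.
        return toy_namespace

    @property
    def device(self):
        # Read-only property: accessed as `x.device`, not `x.device()`.
        return self._device

    def to_device(self, device, /, *, stream=None):
        # Return a copy placed on `device`; `stream` is accepted for
        # signature compatibility with the standard and ignored here.
        return ToyArray(self.data, device=device)


# Hypothetical namespace standing in for jax.numpy. It implements only the
# functions used below; `sum` returns a plain Python number for simplicity
# (the standard returns a 0-D array).
toy_namespace = types.SimpleNamespace(
    __array_api_version__="2022.12",
    multiply=lambda a, b: ToyArray(
        [u * v for u, v in zip(a.data, b.data)], device=a.device
    ),
    sum=lambda a: math.fsum(a.data),
)


def sum_of_squares(x):
    """Library-agnostic code: touches x only through its namespace."""
    xp = x.__array_namespace__()
    return xp.sum(xp.multiply(x, x))


x = ToyArray([1.0, 2.0, 3.0])
print(sum_of_squares(x))       # 14.0
y = x.to_device("accelerator:0")
print(x.device, y.device)      # cpu accelerator:0
```

In this PR's design, a JAX array would take the place of `ToyArray`, and `x.__array_namespace__()` would hand back `jax.numpy`.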
Notes
This PR is a draft for now because these changes should be merged last: they are what publicly declare `jax.numpy` an array API compliant namespace. @vfdev-5 can take over this PR once the remaining work is completed.
This does not need to wait for all of the ongoing array API-related deprecations to be completed, since some of them are only required for the 2023 standard; we can likely adopt the 2022 standard first.
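Adopting the 2022 standard first is workable because consumers can gate on the version a namespace declares through `__array_api_version__`. The sketch below assumes only that this attribute is a `"YYYY.MM"` string, as in the standard's `"2022.12"` and `"2023.12"` releases; `api_version_at_least` and `ns_2022` are hypothetical helpers for illustration, not part of JAX.

```python
import types


def api_version_at_least(xp, required):
    """Return True if namespace `xp` declares an array API version >= `required`.

    Versions are "YYYY.MM" strings, so comparing them as (year, month)
    integer tuples gives chronological order.
    """

    def parse(version):
        year, month = version.split(".")
        return int(year), int(month)

    declared = getattr(xp, "__array_api_version__", None)
    if declared is None:
        # The namespace makes no compliance claim at all.
        return False
    return parse(declared) >= parse(required)


# Hypothetical namespace that has adopted only the 2022 standard.
ns_2022 = types.SimpleNamespace(__array_api_version__="2022.12")

print(api_version_at_least(ns_2022, "2022.12"))  # True
print(api_version_at_least(ns_2022, "2023.12"))  # False
```

Code that needs 2023-only behavior would take a fallback path when the check fails, so declaring `"2022.12"` first does not break such consumers.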
It may make sense to break the `to_device` and `device` changes for `ArrayImpl` off into a small separate PR, since they don't imply explicit compliance by themselves. I kept them together in this PR in case there are any caveats with respect to `ArrayImpl` vs. `Tracer` behaviors that we should discuss first (based on an old TODO note).
We only removed the `arr.device()` method in JAX v0.4.27. To avoid confusing users, I think we should wait for one more release (0.4.29) before adding the `arr.device` property, so that it lands in 0.4.30. What do you think?
@Micky774, let me close this PR, as array API compliance has already been finalized on main.