Update `jax.experimental.array_api` to v2023.12 API
This issue tracks the changes necessary to adopt the v2023.12 Array API. This was originally mentioned in https://github.com/google/jax/issues/18353. Note that we may already satisfy some of these specifications; however, the vast majority will need changes.
API Updates
- [x] https://github.com/google/jax/pull/20198 | spec
- [x] https://github.com/google/jax/pull/20175 | spec
- [x] https://github.com/google/jax/pull/20194 | spec
- [x] https://github.com/google/jax/pull/20195 | spec
New API
- [x] https://github.com/google/jax/pull/20294 | spec
- [x] https://github.com/google/jax/pull/20753
- [x] https://github.com/google/jax/pull/20550 | spec
- [x] https://github.com/google/jax/pull/20755 | spec
- [x] https://github.com/google/jax/pull/20754 | spec
- [x] https://github.com/google/jax/pull/20756 | spec
- [x] https://github.com/google/jax/pull/20954
Breaking Changes
For details on what changed, see each function's specification page for notes marked "Changed in version 2023.12".
Common Utilities Refactor
- [ ] Device placement + copy semantics
  - Initial draft can be found here
  - Get `Device` from `Device | Sharding`
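A minimal sketch of the "Get `Device` from `Device | Sharding`" item, assuming that only a sharding spanning exactly one device should resolve to a `Device` (multi-device shardings raise an error). The helper name `device_from` is hypothetical and not part of JAX; `jax.sharding.Sharding`, its `device_set` property, and `SingleDeviceSharding` are existing JAX APIs.

```python
import jax
from jax.sharding import Sharding, SingleDeviceSharding


def device_from(placement):
    """Hypothetical helper: return the single Device behind a Device or Sharding."""
    if isinstance(placement, Sharding):
        devices = placement.device_set  # set of devices the sharding spans
        if len(devices) != 1:
            raise ValueError(
                f"expected a single-device sharding, got {len(devices)} devices"
            )
        return next(iter(devices))
    # Anything that is not a Sharding is assumed to already be a Device.
    return placement


dev = jax.devices()[0]
from_sharding = device_from(SingleDeviceSharding(dev))
```

Whether multi-device shardings should raise or map to something else is an open design question; this sketch simply picks the strict option.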
Thanks for adding this! One overall note: our eventual goal is to remove `jax.experimental.array_api` and use the `jax.numpy` namespace directly. So, as much as possible, API functions in `jax.experimental.array_api` should contain no logic beyond calling their `jax.numpy` counterparts.
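To illustrate the thin-wrapper goal, here is a hypothetical sketch (not the actual module source) of two Array API functions, `sqrt` and `concat`. Each body only forwards to `jax.numpy`; the only thing `concat` adds is the mapping from the Array API name to `jnp.concatenate`.

```python
import jax.numpy as jnp


def sqrt(x, /):
    # Same name in both namespaces: pure delegation.
    return jnp.sqrt(x)


def concat(arrays, /, *, axis=0):
    # The Array API calls this `concat`; jax.numpy calls it `concatenate`.
    return jnp.concatenate(arrays, axis=axis)


x = jnp.array([1.0, 4.0, 9.0])
roots = sqrt(x)
joined = concat([x, x])
```

Keeping wrappers this thin means that, once `jax.numpy` itself conforms, each one can be replaced by a plain re-export.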
A note on `default_device`
JAX doesn't really have a concept of a "default device" in the way the Array API envisions it (cf. a discussion with @yashk2810). By default, arrays are created uncommitted, so the only consistent `default_device` function would be one that returns "uncommitted". Currently most (if not all) functions that accept a `device` parameter default to `None`, so our `default_device` function should probably look like this:
```python
def default_device():
    # Arrays are created uncommitted by default; None stands for "no explicit device".
    return None
```
and be documented appropriately. That's the only way for, e.g., `jax.device_put(x, device=jnp.default_device())` to behave the same as `jax.device_put(x)`, which seems like a sensible requirement for the concept of a default!
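A quick sketch of that requirement, assuming the proposed `default_device` returns `None`. It is defined locally here because it is a proposal, not an existing `jax.numpy` function. Since `device=None` is already the default of `jax.device_put`, both calls should yield arrays with the same placement and values.

```python
import jax
import jax.numpy as jnp


def default_device():
    # Proposed behavior: None means "no explicit placement" (uncommitted).
    return None


x = jnp.arange(4)
a = jax.device_put(x, device=default_device())  # explicit "default"
b = jax.device_put(x)  # implicit default
```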