
Update `jax.experimental.array_api` to v2023.12 API

Open • Micky774 opened this issue 4 months ago • 2 comments

This issue tracks the changes needed to adopt the v2023.12 Array API. It was originally mentioned in https://github.com/google/jax/issues/18353. We may already satisfy some of these specifications; however, the vast majority will need changes.

API Updates

  • [x] https://github.com/google/jax/pull/20198 | spec
  • [x] https://github.com/google/jax/pull/20175 | spec
  • [x] https://github.com/google/jax/pull/20194 | spec
  • [x] https://github.com/google/jax/pull/20195 | spec

New API

  • [x] https://github.com/google/jax/pull/20294 | spec
    • Add capabilities | spec
    • Add default_device | spec
    • Add default_dtypes | spec
    • Add dtypes | spec
    • Add devices | spec
  • [x] https://github.com/google/jax/pull/20753
    • Add copysign | spec
    • Add maximum | spec
    • Add minimum | spec
    • Add moveaxis | spec
    • Add repeat | spec
    • Add searchsorted | spec
    • Add signbit | spec
  • [x] https://github.com/google/jax/pull/20550 | spec
  • [x] https://github.com/google/jax/pull/20755 | spec
  • [x] https://github.com/google/jax/pull/20754 | spec
  • [x] https://github.com/google/jax/pull/20756 | spec
  • [x] https://github.com/google/jax/pull/20954
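The inspection utilities listed above (capabilities, default_device, default_dtypes, dtypes, devices) are, in the v2023.12 spec, methods on the object returned by `__array_namespace_info__()`. The following is a minimal, hypothetical sketch of how such an object could sit on top of JAX. The class name `ArrayNamespaceInfo` is illustrative, `dtypes()` is omitted for brevity, and the `capabilities()` values are an assumption based on JAX requiring shapes to be known at trace time. It is not the code from the linked PR.

```python
import jax
import jax.numpy as jnp


class ArrayNamespaceInfo:
  """Hypothetical sketch of the v2023.12 inspection utilities for JAX."""

  def capabilities(self):
    # Assumption: reported as unsupported, because jit-compiled JAX code
    # needs array shapes that are known at trace time.
    return {"boolean indexing": False, "data-dependent shapes": False}

  def default_device(self):
    # None means "no explicit device": arrays are created uncommitted.
    return None

  def devices(self):
    # All devices visible to JAX (CPU, GPU, or TPU).
    return jax.devices()

  def default_dtypes(self, *, device=None):
    # Read the defaults off freshly created arrays so the result follows
    # the current x64 setting (float32/int32 by default, 64-bit with x64).
    return {
        "real floating": jnp.asarray(0.0).dtype,
        "complex floating": jnp.asarray(0j).dtype,
        "integral": jnp.asarray(0).dtype,
        "indexing": jnp.asarray(0).dtype,
    }


info = ArrayNamespaceInfo()
```

Deriving the default dtypes from `jnp.asarray` rather than hard-coding them keeps the answer consistent with whatever `jax_enable_x64` is set to.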

Breaking Changes

For details on what has changed, look for "Changed in version 2023.12 ..." notes on each function's specification page.

  • [x] https://github.com/google/jax/pull/20842

Common Utilities Refactor

  • [ ] Device placement + copy semantics
    • Initial draft can be found here
    • Get the Device from a Device | Sharding argument

Micky774, Mar 12 '24 17:03

Thanks for adding this! One overall note: our eventual goal is to remove jax.experimental.array_api and use the jax.numpy namespace directly. So, as much as possible, API functions in jax.experimental.array_api should contain no logic beyond calling their jax.numpy counterparts.
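As a sketch of that guideline, array-API-style functions can be thin wrappers that only forward to jax.numpy. The wrappers below are hypothetical illustrations (not the code in jax.experimental.array_api); they match the spec's positional-only argument style, and `searchsorted` omits the spec's `sorter` argument for brevity.

```python
import jax.numpy as jnp

# Hypothetical wrappers: each one only forwards to its jax.numpy counterpart,
# so nothing is lost when the wrapper namespace is eventually removed.


def copysign(x1, x2, /):
  return jnp.copysign(x1, x2)


def maximum(x1, x2, /):
  return jnp.maximum(x1, x2)


def moveaxis(x, source, destination, /):
  return jnp.moveaxis(x, source, destination)


def searchsorted(x1, x2, /, *, side="left"):
  return jnp.searchsorted(x1, x2, side=side)


def signbit(x, /):
  return jnp.signbit(x)
```

Because every wrapper is a pure pass-through, a caller could switch from the wrapper namespace to jax.numpy without any behavior change.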

jakevdp, Mar 14 '24 14:03

A note on default_device (cf. a discussion with @yashk2810). JAX doesn't really have a concept of a "default device" in the way that the Array API envisions it. By default, arrays are created uncommitted, so the only way to write a consistent default_device function would be for it to return "uncommitted". Currently, most (if not all) functions that accept a device parameter default to device=None, so our default_device function should probably look like this:

def default_device():
  # Arrays are created uncommitted by default; None means "no explicit device".
  return None

and be documented appropriately. That's the only way for, e.g., jax.device_put(x, device=jnp.default_device()) to behave the same as jax.device_put(x), which seems like a sensible requirement for the concept of a default!
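A quick sketch of that requirement, assuming a standalone `default_device()` that returns None as proposed above (it is defined locally here, since it is not a confirmed jax.numpy function): passing its result to `jax.device_put` should give the same placement as omitting the argument.

```python
import jax
import jax.numpy as jnp


def default_device():
  # Proposed behavior: None, i.e. leave placement to JAX (uncommitted).
  return None


x = jnp.arange(4)
a = jax.device_put(x)                            # no device given
b = jax.device_put(x, device=default_device())   # explicit "default"
```

Any other return value (for example, `jax.devices()[0]`) would commit the array to a device, so the two calls would no longer be equivalent.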

jakevdp, Mar 14 '24 16:03