
Tracking bug: Integration of Array in JAX.

Open yashk2810 opened this issue 3 years ago • 4 comments

  • [ ] Add tests for _committed attribute in Array.

  • [ ] Figure out how to resolve the local imports in dispatch.py, such as `from jax.experimental import pjit`.

  • [x] Make testArrayCopy in lax_numpy_test.py work with Array. There is a TODO in that test with this bug number.

  • [x] Make xmap work properly with SingleDeviceSharding Arrays, or make xmap behave as it did with DA (DeviceArray) and SDA (ShardedDeviceArray).

  • [x] Implement no-copy indexing similar to SDA, i.e. like the `__getitem__` of SDA.

  • [x] Make sure that non-committed Arrays on a single device (e.g. `jnp.arange(10)` with a mesh decorator) are sharded properly in pjit and xmap.

  • [ ] Add a check for non-XLACompatibleShardings similar to the one in the constructor of Array (there is a TODO there).
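For the first item, a test of `_committed` could start from the two placement cases sketched below. This is a minimal sketch, not the test that was eventually written: it assumes `_committed` is the private attribute on `jax.Array` named in the checklist, and the helper `committed_flags` is hypothetical.

```python
import jax
import jax.numpy as jnp


def committed_flags():
    # Hypothetical helper: creating an Array without explicit placement
    # (e.g. jnp.arange) should leave it uncommitted.
    uncommitted = jnp.arange(10)
    # Placing an Array on an explicit device with device_put commits it there.
    committed = jax.device_put(jnp.arange(10), jax.devices()[0])
    # _committed is private; reading it here is an assumption for testing only.
    return uncommitted._committed, committed._committed


print(committed_flags())
```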

yashk2810 · Aug 19 '22

I think this would be much better using the GitHub projects / boards feature :)

mjsML · Aug 21 '22

I have moved other things to this board: https://github.com/orgs/google/projects/25/views/1, but I am keeping this bug alive for the things I am tracking here.

yashk2810 · Sep 08 '22

@yashk2810 Unfortunately, https://github.com/orgs/google/projects/25/views/1 does not seem to be accessible outside Google (i.e. to the public). I suppose this is a JAX-specific board; why don't we have GitHub Projects/boards within google/jax?

wookayin · Sep 19 '22

Because the boards can only be created at the project IIUC. If there is a different solution, I can use that :)

yashk2810 · Sep 19 '22