Tracking bug: Integration of Array in JAX.
- [ ] Add tests for the _committed attribute in Array.
- [ ] Figure out how to resolve the local imports (like from jax.experimental import pjit) that are in dispatch.py.
- [x] Make testArrayCopy in lax_numpy_test.py work with Array. There is a TODO in that test with this bug number.
- [x] Make xmap work properly with SingleDeviceSharding Arrays, or just make the behavior of xmap similar to what it was with DA and SDA.
- [x] Implement no-copy indexing similar to SDA, i.e. getitem of SDA.
- [x] Make sure that non-committed Arrays on SingleDevice (e.g. jnp.arange(10) with a mesh decorator) are sharded properly in pjit and xmap.
- [ ] Add a check for non-XLACompatibleShardings similar to the one that exists in the constructor of Array (there is a TODO there).
I think this would be much better using the GitHub projects / boards feature :)
I have moved other things to this board: https://github.com/orgs/google/projects/25/views/1 but I am keeping this bug alive for the things I am tracking here.
@yashk2810 Unfortunately, https://github.com/orgs/google/projects/25/views/1 does not seem to be accessible outside Google (i.e. to the public). I suppose this is a JAX-specific board; why don't we have GitHub projects/boards within google/jax?
Because the boards can only be created at the project IIUC. If there is a different solution, I can use that :)