Add nnx.shard_map

Open cgarciae opened this issue 1 year ago • 1 comments

cgarciae avatar Oct 07 '24 17:10 cgarciae

Thanks @cgarciae :rocket:

TODO: Once this lands, we can update the Transforms guide.

8bitmp3 avatar Oct 08 '24 12:10 8bitmp3

any updates?

carlesoctav avatar Dec 24 '24 03:12 carlesoctav

@cgarciae This is P0 to me and I imagine to many others too. Let me know if I can do anything to help this along.

marcelroed avatar Dec 29 '24 01:12 marcelroed

@marcelroed thanks for the ping. Will try to get it in soon, all the pieces are ready.

cgarciae avatar Dec 29 '24 03:12 cgarciae

In the meantime, are there any issues with doing

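from functools import partial

from flax import nnx
from jax.experimental.shard_map import shard_map

@partial(shard_map, mesh=mesh, in_specs=..., out_specs=...)
def forward(graph_def, state, x):
    # Rebuild the module from its static graph definition and its state.
    model = nnx.merge(graph_def, state)
    return model(x)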

other than mutations not reaching outside of JIT?

marcelroed avatar Dec 29 '24 14:12 marcelroed

@marcelroed using split / merge is perfectly valid. If you want to propagate state updates you can also return the new state and use nnx.update outside.

cgarciae avatar Dec 31 '24 03:12 cgarciae

After a longer than necessary wait, nnx.shard_map is now live!

cgarciae avatar Feb 27 '25 15:02 cgarciae