Add nnx.shard_map
Thanks @cgarciae :rocket:
TODO: then we can update the Transforms guide.
Any updates?
@cgarciae This is P0 to me and I imagine to many others too. Let me know if I can do anything to help this along.
@marcelroed thanks for the ping. Will try to get it in soon, all the pieces are ready.
In the meantime, are there any issues with doing
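from functools import partial

from jax.experimental.shard_map import shard_map
from flax import nnx

@partial(shard_map, mesh=mesh, in_specs=..., out_specs=...)
def forward(graph_def, state, x):
    # Rebuild the module from its graph definition and state inside the transform
    model = nnx.merge(graph_def, state)
    return model(x)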
other than mutations not reaching outside of JIT?
@marcelroed using split/merge is perfectly valid. If you want to propagate state updates, you can also return the new state from the function and apply it outside with nnx.update.
After a longer-than-necessary wait, nnx.shard_map is now live!