orbax
Restore from single slice
- Note that I'm still using an old version (v0.11.15) of the slice_devices method, which means it returns the devices from a single slice rather than a single replica.
- The basic idea is to keep the original implementation almost unchanged; I added another dimension, replicas, to broadcast the data.
- I also tried the original idea of simply restoring from a single replica, but I get an InvalidShardingError. The reason is obvious: the devices of a process are distributed across different replicas.
https://github.com/google/orbax/blob/9fc371659b068718600b198cb47b901184f285e4/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py#L1468-L1474
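For readers unfamiliar with the pattern, the replicas-dimension broadcast mentioned above can be sketched roughly as follows. This is a minimal single-host illustration, not the code in this PR: `num_replicas`, `restored`, and `broadcast_from_primary` are hypothetical names, and the restored data is a toy array. Only the primary replica holds real data, the others hold zeros, and a sum over the leading replica axis leaves every replica with the primary's values.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical setup: 4 replicas, with replica 0 acting as the primary
# that actually read the checkpoint.
num_replicas = 4
restored = jnp.arange(6.0).reshape(2, 3)  # stands in for the deserialized data

# Add a leading replica dimension: real data in slot 0, zeros elsewhere.
stacked = jnp.zeros((num_replicas, *restored.shape), restored.dtype)
stacked = stacked.at[0].set(restored)


@jax.jit
def broadcast_from_primary(per_replica):
    # Summing over the replica axis reproduces the primary's data,
    # because every other replica contributes zeros.
    return jnp.sum(per_replica, axis=0)


result = broadcast_from_primary(stacked)
print(np.asarray(result))
```

When the leading axis is sharded over a replica mesh axis, the same sum is expected to lower to a cross-replica reduction, which is presumably where the broadcasting time is spent.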
My questions:
- Any obvious errors or potential improvements with my current implementation?
- One datapoint from my latest test: ~30s on deserialization plus ~60s on broadcasting (only one broadcast in total).
- Any idea on how to address the above InvalidShardingError? (My initial thought is that the resharding should still work after the sum op even though there's a mismatch.)
Thanks!
Looks like you are using two processes for testing here.
Could you also add the (2, 4) mesh shape in the test below? Make sure the mesh is created from jax.experimental.mesh_utils.create_device_mesh to validate my assumption.
https://gist.github.com/cpgaffney1/35161a6e6f6e1bc7bf2ffd3df543efe5#file-type_handlers_test-L291-L295
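One possible way to inspect how processes map onto replicas for a (2, 4) mesh is sketched below. It is not taken from the linked test: `processes_spanning_replicas` is a hypothetical helper, the fake device grids exist only for illustration, and treating axis 0 as the replica axis is an assumption. The real-device part runs only when at least 8 devices are visible.

```python
import os

# Ask for 8 host (CPU) devices if no XLA flags are set yet; this has no
# effect if JAX has already been initialized in the process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

from types import SimpleNamespace

import jax
import numpy as np
from jax.experimental import mesh_utils


def processes_spanning_replicas(device_grid, replica_axis=0):
    """Returns the process indices whose devices fall in more than one replica.

    Assumes a device grid with at least two dimensions.
    """
    grid = np.moveaxis(np.asarray(device_grid), replica_axis, 0)
    replicas_by_process = {}
    for replica_id, replica_devices in enumerate(grid):
        for device in replica_devices.flat:
            replicas_by_process.setdefault(device.process_index, set()).add(replica_id)
    return {p for p, r in replicas_by_process.items() if len(r) > 1}


def fake_grid(process_ids):
    # Hypothetical stand-in devices carrying only a process_index attribute.
    return np.array(
        [[SimpleNamespace(process_index=p) for p in row] for row in process_ids],
        dtype=object,
    )


# Two processes, each owning half of every replica row (the problematic case).
interleaved = fake_grid([[0, 0, 1, 1], [0, 0, 1, 1]])
# Two processes, each owning exactly one replica row.
aligned = fake_grid([[0, 0, 0, 0], [1, 1, 1, 1]])
print(processes_spanning_replicas(interleaved))
print(processes_spanning_replicas(aligned))

# With real devices: build the (2, 4) mesh and run the same check.
devices = jax.devices()
if len(devices) >= 8:
    device_grid = mesh_utils.create_device_mesh((2, 4), devices=devices[:8])
    print(processes_spanning_replicas(device_grid))
```

If the helper returns a non-empty set for the real mesh, at least one process has devices in several replicas, which matches the situation described for the InvalidShardingError.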