
Restore from single slice

Open • findmyway opened this issue 5 months ago • 1 comment

  1. Note that I'm still using an old version (v0.11.15) of the slice_devices method, which returns the devices of a single slice rather than a single replica.
  2. The basic idea is to keep the original implementation almost unchanged, while adding an extra replica dimension over which the restored data is broadcast.
  3. I also tried the original idea of simply restoring from a single replica, but I got an InvalidShardingError. The reason is clear: the devices of a single process are distributed across different replicas.
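To make the failure mode in point 3 concrete, here is a small pure-Python sketch that maps each process to the replica rows its devices fall in. FakeDevice, replicas_per_process, and the 2x4 layout are hypothetical illustrations. On a real system you would read process_index from the jax.Device objects in mesh.devices instead. If any process maps to more than one replica, no single replica lines up with a single process's devices, which is consistent with the sharding error described above.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for jax.Device: only id and process_index matter here."""
    id: int
    process_index: int


def replicas_per_process(device_grid):
    """Map each process index to the set of replica rows its devices appear in.

    device_grid is a list of rows; row r holds the devices of replica r.
    """
    mapping = defaultdict(set)
    for replica, row in enumerate(device_grid):
        for device in row:
            mapping[device.process_index].add(replica)
    return dict(mapping)


# Hypothetical layout: 2 replicas x 4 devices over 2 processes, where each
# process owns two devices in *each* replica row.
grid = [
    [FakeDevice(0, 0), FakeDevice(1, 0), FakeDevice(4, 1), FakeDevice(5, 1)],
    [FakeDevice(2, 0), FakeDevice(3, 0), FakeDevice(6, 1), FakeDevice(7, 1)],
]
print(replicas_per_process(grid))  # {0: {0, 1}, 1: {0, 1}}
```

Here both processes span both replicas, so neither process can hold a complete replica on its own devices.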

https://github.com/google/orbax/blob/9fc371659b068718600b198cb47b901184f285e4/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py#L1468-L1474

My questions:

  1. Are there any obvious errors in, or potential improvements to, my current implementation?
    • One data point from my latest test: ~30 s for deserialization plus ~60 s for broadcasting (only one broadcast in total).
  2. Any idea how to address the InvalidShardingError above? (My initial thought is that resharding should still work after the sum op, even though there's a mismatch.)
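The "broadcast via a sum op" idea can be sketched in plain JAX as follows. This is a minimal sketch under stated assumptions, not Orbax's implementation. It assumes a mesh with a "replica" axis in which only replica 0 holds the restored values and every other replica contributes zeros, so summing over the replica axis reproduces replica 0's data on every device. The function broadcast_by_sum and the axis names "replica"/"data" are hypothetical. Building the zero-padded array on the host is a simplification, since in the real restore path the data would already live on replica 0's devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def broadcast_by_sum(per_replica, mesh):
    """Replicate the data held by replica 0 onto every device of `mesh`.

    `per_replica` has a leading axis of size mesh.shape["replica"]. Only entry 0
    holds real data; the rest are zeros, so the sum over that axis equals
    replica 0's data. Under jit, a sum over an axis sharded across "replica"
    is typically lowered to a cross-device all-reduce.
    """
    x = jax.device_put(per_replica, NamedSharding(mesh, P("replica")))
    return jax.jit(
        lambda a: jnp.sum(a, axis=0),
        out_shardings=NamedSharding(mesh, P()),  # P() means fully replicated
    )(x)


# Hypothetical mesh: one device per replica, using all available devices.
devices = mesh_utils.create_device_mesh((len(jax.devices()), 1))
mesh = Mesh(devices, ("replica", "data"))

restored = np.arange(6, dtype=np.float32).reshape(2, 3)  # stand-in for deserialized data
per_replica = np.zeros((mesh.shape["replica"], *restored.shape), dtype=np.float32)
per_replica[0] = restored  # only replica 0 contributes non-zero values

out = broadcast_by_sum(per_replica, mesh)
print(out.sharding.is_fully_replicated)  # True
```

The design choice here is to let the compiler turn the reduction into the broadcast, rather than issuing an explicit collective. The cost is that every replica ships a zero buffer of the same size through the reduction.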

findmyway • Jul 09 '25 07:07

Thanks!

It looks like you are using two processes for testing here.

Could you also add the (2, 4) mesh shape to the test below? Make sure the mesh is created with jax.experimental.mesh_utils.create_device_mesh to validate my assumption.

https://gist.github.com/cpgaffney1/35161a6e6f6e1bc7bf2ffd3df543efe5#file-type_handlers_test-L291-L295
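For the (2, 4) check requested above, a sketch like the following builds the mesh with jax.experimental.mesh_utils.create_device_mesh and prints which process owns each device. This lets you see whether a replica row (mesh axis 0) spans more than one process. Several details are assumptions: the axis names "replica"/"data" are illustrative, and the XLA_FLAGS line only fakes 8 CPU devices for a single-host dry run. In that dry run every device reports process_index 0, so the interesting layout only shows up when you pass jax.devices() in the real multi-process setup.

```python
import os

# Fake 8 CPU devices for a single-host dry run; must be set before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# On the real multi-process setup, pass jax.devices() here instead.
cpu_devices = jax.devices("cpu")[:8]
device_grid = mesh_utils.create_device_mesh((2, 4), devices=cpu_devices)
mesh = Mesh(device_grid, ("replica", "data"))

# process_index of each device, laid out like the mesh (rows = replicas).
process_grid = np.array([[d.process_index for d in row] for row in mesh.devices])
print(mesh.devices.shape)  # (2, 4)
print(process_grid)
```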

findmyway • Jul 11 '25 06:07