orbax

Bugfixes to local checkpoint manager

Open findmyway opened this issue 5 months ago • 0 comments

The current implementation of slice_count does not do what its name indicates: it returns the replica count along replica_axis_index. This PR makes it return the actual slice count, computed from the global mesh. _get_single_slice_sharding is updated correspondingly, on the assumption that the slice dimension is replicated along the replica_axis_index dimension.
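To illustrate how the two quantities can differ, here is a minimal sketch, not the code in this PR. It uses a hypothetical FakeDevice stand-in instead of real JAX devices and assumes each device exposes a slice_index attribute; the helper names replica_count and slice_count_from_mesh are made up for this example.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a device; only id and slice_index are modeled."""
    id: int
    slice_index: int


def replica_count(devices: np.ndarray, replica_axis_index: int) -> int:
    # Size of the mesh along the replica axis (what the old slice_count
    # is described as returning).
    return devices.shape[replica_axis_index]


def slice_count_from_mesh(devices: np.ndarray) -> int:
    # Number of distinct slices that the devices in the mesh belong to.
    return len({d.slice_index for d in devices.flat})


# 8 devices spread over 2 slices (4 devices per slice), arranged as a
# (4, 2) mesh whose axis 0 is treated as the replica axis.
flat = np.empty(8, dtype=object)
for i in range(8):
    flat[i] = FakeDevice(id=i, slice_index=i // 4)
mesh_devices = flat.reshape(4, 2)

print(replica_count(mesh_devices, replica_axis_index=0))
print(slice_count_from_mesh(mesh_devices))
```

In this layout the replica axis has size 4 while the devices span only 2 slices, so code that treats the first number as a slice count would iterate over slices that do not exist.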


Without this fix, _find_slice_with_complete_local_checkpoint always returns -1.

findmyway, Jun 16 '25 13:06