orbax
Bugfixes to local checkpoint manager
The current implementation of slice_count does not do what its name indicates: it actually returns the replica count along replica_axis_index. This PR makes it return the actual slice count, computed from the global mesh. _get_single_slice_sharding is updated correspondingly, assuming that the slice dimension is replicated along the replica_axis_index dimension.
Without this fix, _find_slice_with_complete_local_checkpoint will always return -1.
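
The distinction between the two counts can be illustrated with a toy model. This is only a sketch and not Orbax code: FakeDevice, build_mesh, replica_count, slice_count_from_mesh and devices_in_slice are hypothetical names, the mesh is a plain nested list rather than a jax.sharding.Mesh, and the layout (several replicas per slice, slices stacked along the replica axis) is an assumption chosen to make the two numbers differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a device; only the fields used below."""
    id: int
    slice_index: int


def build_mesh(rows, cols, rows_per_slice):
    """Build a 2D mesh (list of rows) where axis 0 is the assumed replica axis.

    Consecutive groups of `rows_per_slice` rows belong to the same slice.
    """
    return [
        [FakeDevice(id=r * cols + c, slice_index=r // rows_per_slice)
         for c in range(cols)]
        for r in range(rows)
    ]


def replica_count(mesh, replica_axis_index):
    """Size of the mesh along the replica axis (what the old code reported)."""
    return len(mesh) if replica_axis_index == 0 else len(mesh[0])


def slice_count_from_mesh(mesh):
    """Number of distinct slices among all devices in the global mesh."""
    return len({d.slice_index for row in mesh for d in row})


def devices_in_slice(mesh, slice_id):
    """All devices of the global mesh that belong to one slice."""
    return [d for row in mesh for d in row if d.slice_index == slice_id]


# 4 replicas along axis 0, 2 devices per replica, 2 replicas per slice.
mesh = build_mesh(rows=4, cols=2, rows_per_slice=2)
print(replica_count(mesh, replica_axis_index=0))  # replica count
print(slice_count_from_mesh(mesh))                # slice count
print([d.id for d in devices_in_slice(mesh, 1)])  # device ids in slice 1
```

In this layout the replica count along the replica axis is 4 while the slice count is 2, so any code that iterates over slices using the replica count would look for slices that do not exist.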