torchrec
Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT
Summary:
Context
- Currently, the torchrec IR serializer does not support exporting variable-batch (VBE) KJTs. The `stride_per_key_per_rank` and `inverse_indices` fields are needed to deserialize a VBE KJT, but they are not included in the KJT's PyTree flatten/unflatten functions.
- This diff updates the KJT's PyTree flatten/unflatten functions to include `stride_per_key_per_rank` and `inverse_indices`.
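A toy sketch of why the missing fields matter: if flatten drops the VBE metadata, unflatten cannot rebuild the original object, so a serialize/deserialize round trip loses the per-key batch sizes. It assumes a hypothetical `ToyVBEKJT` class with plain lists in place of tensors and hypothetical `flatten_*`/`unflatten_*` helpers. It is not torchrec's `KeyedJaggedTensor` or the code in this diff, and putting `inverse_indices`' tensor in the leaves and the remaining metadata in the context is one reasonable split, not necessarily the one the diff uses.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


# Hypothetical, simplified stand-in for a variable-batch (VBE) KJT.
# Plain lists stand in for tensors; the real torchrec KeyedJaggedTensor
# has more fields and registers its flatten/unflatten with torch's PyTree utilities.
@dataclass
class ToyVBEKJT:
    keys: List[str]
    values: List[int]
    lengths: List[int]
    stride_per_key_per_rank: Optional[List[List[int]]] = None
    inverse_indices: Optional[Tuple[List[str], List[List[int]]]] = None


def flatten_without_vbe(kjt: ToyVBEKJT) -> Tuple[List[Any], Any]:
    # Behavior being illustrated as the bug: VBE metadata is silently dropped.
    return [kjt.values, kjt.lengths], kjt.keys


def unflatten_without_vbe(children: List[Any], context: Any) -> ToyVBEKJT:
    values, lengths = children
    return ToyVBEKJT(keys=context, values=values, lengths=lengths)


def flatten_with_vbe(kjt: ToyVBEKJT) -> Tuple[List[Any], Any]:
    # Tensor-like data goes into the leaves (children);
    # static metadata goes into the context.
    if kjt.inverse_indices is not None:
        inv_keys, inv_tensor = kjt.inverse_indices
    else:
        inv_keys, inv_tensor = None, None
    children = [kjt.values, kjt.lengths, inv_tensor]
    context = (kjt.keys, kjt.stride_per_key_per_rank, inv_keys)
    return children, context


def unflatten_with_vbe(children: List[Any], context: Any) -> ToyVBEKJT:
    values, lengths, inv_tensor = children
    keys, stride_per_key_per_rank, inv_keys = context
    inverse_indices = None if inv_tensor is None else (inv_keys, inv_tensor)
    return ToyVBEKJT(keys, values, lengths, stride_per_key_per_rank, inverse_indices)


kjt = ToyVBEKJT(
    keys=["f1", "f2"],
    values=[10, 11, 12, 20],
    lengths=[1, 2, 1],                   # f1 has 2 samples, f2 has 1
    stride_per_key_per_rank=[[2], [1]],  # per-key batch size on a single rank
    inverse_indices=(["f1", "f2"], [[0, 1], [0, 0]]),
)

old = unflatten_without_vbe(*flatten_without_vbe(kjt))
new = unflatten_with_vbe(*flatten_with_vbe(kjt))
print(old == kjt)  # False: VBE metadata lost in the round trip
print(new == kjt)  # True: round trip preserves VBE metadata
```

The dataclass equality check compares every field, so the first round trip fails only because `stride_per_key_per_rank` and `inverse_indices` come back as `None`.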
Ref
Differential Revision: D74295924