torchrec

Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT

Open · jd7-tr opened this issue 6 months ago · 1 comment

Summary:

Context

  • Currently, the torchrec IR serializer cannot export variable-batch (VBE) KeyedJaggedTensors (KJTs). Deserializing a VBE KJT requires the stride_per_key_per_rank and inverse_indices fields, but these fields are not included in KJT's PyTree flatten/unflatten functions, so they are dropped when the KJT is flattened.
  • This diff updates KJT's PyTree flatten/unflatten functions to include stride_per_key_per_rank and inverse_indices.
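To illustrate why the missing fields break the round trip, here is a minimal pure-Python sketch. `ToyVBEKJT` and the four flatten/unflatten helpers are hypothetical stand-ins, not torchrec's actual KeyedJaggedTensor class or its registered PyTree functions. Real KJTs hold torch tensors, and the actual change may place these fields in the flattened children rather than in the context. The field names follow torchrec's KeyedJaggedTensor constructor argument `stride_per_key_per_rank` and its `inverse_indices` argument.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


# Hypothetical, simplified stand-in for a variable-batch KJT.
# Plain lists replace tensors so the sketch has no dependencies.
@dataclass
class ToyVBEKJT:
    keys: List[str]
    values: List[int]
    lengths: List[int]
    stride_per_key_per_rank: Optional[List[List[int]]] = None
    inverse_indices: Optional[Tuple[List[str], List[List[int]]]] = None


def flatten_without_vbe(kjt: ToyVBEKJT) -> Tuple[List[Any], Any]:
    # Children hold the data leaves; the context holds only the keys.
    # The VBE metadata is silently dropped here.
    return [kjt.values, kjt.lengths], kjt.keys


def unflatten_without_vbe(children: List[Any], context: Any) -> ToyVBEKJT:
    values, lengths = children
    # The VBE fields fall back to their None defaults.
    return ToyVBEKJT(keys=context, values=values, lengths=lengths)


def flatten_with_vbe(kjt: ToyVBEKJT) -> Tuple[List[Any], Any]:
    # The VBE metadata travels in the context alongside the keys.
    context = (kjt.keys, kjt.stride_per_key_per_rank, kjt.inverse_indices)
    return [kjt.values, kjt.lengths], context


def unflatten_with_vbe(children: List[Any], context: Any) -> ToyVBEKJT:
    values, lengths = children
    keys, stride_per_key_per_rank, inverse_indices = context
    return ToyVBEKJT(keys, values, lengths, stride_per_key_per_rank, inverse_indices)


# Key "f1" has batch size 2 and key "f2" has batch size 1, so the
# per-key strides differ: this is what makes the KJT variable-batch.
kjt = ToyVBEKJT(
    keys=["f1", "f2"],
    values=[10, 11, 12],
    lengths=[1, 1, 1],
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], [[0, 1, 1], [0, 0, 0]]),
)

lossy = unflatten_without_vbe(*flatten_without_vbe(kjt))
lossless = unflatten_with_vbe(*flatten_with_vbe(kjt))
print(lossy.stride_per_key_per_rank)  # None: VBE metadata was lost
print(lossless == kjt)  # True: round trip preserves every field
```

Keeping the per-key strides and inverse indices in the flattened representation is what allows a deserializer to rebuild a KJT whose keys have different batch sizes, instead of one that silently assumes a uniform stride.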

Ref

Differential Revision: D74295924

jd7-tr, May 07 '25 02:05