torchrec

Update KJT stride calculation logic to be based off of inverse_indices for VBE KJTs.

Open • jd7-tr opened this issue 6 months ago • 3 comments

Summary: Update the _maybe_compute_stride_kjt logic to calculate the stride from inverse_indices for VBE KJTs.

Currently, the stride of a VBE KJT with stride_per_key_per_rank is calculated as the maximum stride per key. This differs from the batch size of the EBC output KeyedTensor, which is derived from inverse_indices. The mismatch causes issues in IR module serialization: debug doc.
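To illustrate the mismatch described above, here is a minimal pure-Python sketch. It is not the torchrec implementation: the helper names `stride_from_max_per_key` and `stride_from_inverse_indices` are hypothetical, plain lists stand in for tensors, and it assumes that each row of inverse_indices has one entry per sample of the full (expanded) batch and that a key's stride is the sum of its per-rank strides.

```python
from typing import List


def stride_from_max_per_key(stride_per_key_per_rank: List[List[int]]) -> int:
    """Stride as the maximum per-key stride (the behavior described as current).

    Assumption: a key's stride is the sum of its per-rank strides.
    """
    return max((sum(per_rank) for per_rank in stride_per_key_per_rank), default=0)


def stride_from_inverse_indices(inverse_indices: List[List[int]]) -> int:
    """Stride as the length of one inverse_indices row (the proposed behavior).

    Assumption: every row has one entry per sample in the full batch.
    """
    return len(inverse_indices[0]) if inverse_indices else 0


# Hypothetical VBE input: two keys on one rank, deduplicated to 2 and 3 samples.
stride_per_key_per_rank = [[2], [3]]
# Each row maps the 4 samples of the full batch back to deduplicated rows.
inverse_indices = [
    [0, 1, 0, 1],
    [0, 1, 2, 0],
]

max_based = stride_from_max_per_key(stride_per_key_per_rank)
inverse_based = stride_from_inverse_indices(inverse_indices)
print(max_based, inverse_based)
```

In this made-up example the max-based stride is 3 while the full batch has 4 samples, which is the kind of disagreement with the output KeyedTensor batch size that the summary refers to.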

Differential Revision: D74273083

jd7-tr • May 06 '25 21:05