torchrec
Update KJT stride calculation logic to be based on inverse_indices for VBE KJTs.
Summary:
Update the _maybe_compute_stride_kjt logic to calculate the stride from inverse_indices for VBE KJTs.
Currently, the stride of a VBE KJT with stride_per_key_per_rank is calculated as the maximum "stride per key". This differs from the batch size of the EBC output KeyedTensor, which is derived from inverse_indices. The mismatch causes issues in IR module serialization: debug doc.
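A minimal sketch of the two calculations, to make the mismatch concrete. It assumes inverse_indices is laid out as one row per key, each row holding one entry per sample in the full batch (mapping that sample back to its deduplicated position), and that a key's stride is the sum of its per-rank strides. Plain Python lists stand in for tensors. Both helper names are hypothetical; this is not torchrec's API or the actual _maybe_compute_stride_kjt implementation.

```python
from typing import Sequence


def stride_from_stride_per_key_per_rank(
    stride_per_key_per_rank: Sequence[Sequence[int]],
) -> int:
    """Hypothetical helper for the current rule: max over keys of the per-key stride.

    Assumption: each key's stride is the sum of its per-rank strides.
    """
    if not stride_per_key_per_rank:
        return 0
    return max(sum(per_rank) for per_rank in stride_per_key_per_rank)


def stride_from_inverse_indices(inverse_indices: Sequence[Sequence[int]]) -> int:
    """Hypothetical helper for the proposed rule: the stride is the full batch size.

    Assumption: inverse_indices has one row per key, and each row has one
    entry per sample in the full (non-deduplicated) batch.
    """
    if not inverse_indices:
        return 0
    batch_size = len(inverse_indices[0])
    # Every key must cover the same full batch for the layout assumption to hold.
    if any(len(row) != batch_size for row in inverse_indices):
        raise ValueError("all inverse_indices rows must have the same length")
    return batch_size


if __name__ == "__main__":
    # One rank, batch of 4: key "f1" deduplicated to 2 samples, key "f2" to 3.
    stride_per_key_per_rank = [[2], [3]]
    inverse_indices = [[0, 1, 0, 1], [0, 1, 2, 2]]
    print(stride_from_stride_per_key_per_rank(stride_per_key_per_rank))  # 3
    print(stride_from_inverse_indices(inverse_indices))  # 4
```

In this toy case the max-stride-per-key rule reports 3, while inverse_indices implies a batch size of 4, which matches the batch size an output KeyedTensor built from inverse_indices would have.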
Differential Revision: D74273083