torchrec
change regroup backend operator for KeyedTensor
Summary:
specs
- inputs: a list of KTs
   a. the N-th KT has a shape of (batch_size, dimN); batch_size must be identical across all input KTs
   b. the N-th KT contains a list of features, where the j-th feature has a dimension of dimN_j, so sum(dimN_j for j) = dimN
- permute list: a list of `permute`s, where a `permute` is a list of feature_id (str or int)
   a. each feature_id has a corresponding feature in the inputs (KTs)
   b. each `permute` represents one output KT
- outputs: a list of KTs
   a. the number of output KTs equals the size of the permute list
   b. each output KT contains the features from the inputs that correspond to its `permute`
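To make the spec concrete, here is a minimal pure-Python reference sketch of the regroup semantics. It is not the backend operator this change introduces and not torchrec's API: `ToyKeyedTensor`, `regroup`, and the nested-list `Matrix` type are hypothetical names used only for illustration. It assumes every feature_id is unique across all input KTs and stores values as row-major nested lists instead of tensors.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple, Union

FeatureId = Union[str, int]
Matrix = List[List[float]]  # row-major nested lists, shape (batch_size, dim)


@dataclass
class ToyKeyedTensor:
    """Hypothetical stand-in for a KeyedTensor; not torchrec's class."""

    keys: List[FeatureId]  # feature ids, in column order
    lengths: List[int]  # dimN_j for each feature j
    values: Matrix  # shape (batch_size, dimN) with dimN == sum(lengths)

    def __post_init__(self) -> None:
        if len(self.keys) != len(self.lengths):
            raise ValueError("keys and lengths must have the same size")
        dim = sum(self.lengths)
        if any(len(row) != dim for row in self.values):
            raise ValueError(f"every row must have width sum(lengths) = {dim}")


def regroup(
    inputs: Sequence[ToyKeyedTensor],
    permutes: Sequence[Sequence[FeatureId]],
) -> List[ToyKeyedTensor]:
    """Return one output KT per permute, concatenating the listed features' columns."""
    if not inputs:
        raise ValueError("at least one input KT is required")
    batch_size = len(inputs[0].values)

    # Map each feature_id to (input KT index, first column, width).
    index: Dict[FeatureId, Tuple[int, int, int]] = {}
    for n, kt in enumerate(inputs):
        if len(kt.values) != batch_size:
            raise ValueError("all input KTs must share the same batch_size")
        offset = 0
        for key, dim in zip(kt.keys, kt.lengths):
            if key in index:
                raise ValueError(f"duplicate feature_id {key!r}")
            index[key] = (n, offset, dim)
            offset += dim

    outputs: List[ToyKeyedTensor] = []
    for permute in permutes:
        rows: Matrix = [[] for _ in range(batch_size)]
        for feature_id in permute:
            n, start, dim = index[feature_id]  # KeyError if no input has this feature
            source = inputs[n].values
            for b in range(batch_size):
                rows[b].extend(source[b][start : start + dim])
        widths = [index[feature_id][2] for feature_id in permute]
        outputs.append(ToyKeyedTensor(list(permute), widths, rows))
    return outputs


# Usage: two input KTs with batch_size = 2; feature ids may be str or int.
kt0 = ToyKeyedTensor(["f1", "f2"], [2, 1], [[1, 2, 3], [4, 5, 6]])
kt1 = ToyKeyedTensor([3], [2], [[7, 8], [9, 10]])
out = regroup([kt0, kt1], [[3, "f1"], ["f2"]])
print(out[0].values)  # [[7, 8, 1, 2], [9, 10, 4, 5]]
print(out[1].values)  # [[3], [6]]
```

The sketch builds the feature_id to (input, offset, width) index once and then copies column slices per `permute`. A tensor-based implementation would likely precompute the same index for a fixed permute list and perform the copies on device; the per-row Python loops here only make the semantics explicit.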
Differential Revision: D58649553