Transformer block that supports jagged tensor as input
The standard transformer block expects a dense input tensor in which every example has the same sequence length. When the examples in a batch have unequal lengths, they must be padded to a common length. JaggedTensor, by contrast, represents variable-length examples natively. Can we support a transformer block that works efficiently with a JaggedTensor or KeyedJaggedTensor as input, instead of first converting it to a dense, padded form?
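For context, a minimal sketch of the padding-based workaround the issue is asking to avoid: densify the JaggedTensor, pad to the batch maximum, and mask the pad positions for attention. The dimensions (`d_model=16`, `nhead=4`) and the sample lengths are illustrative assumptions, not anything prescribed by torchrec.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torchrec.sparse.jagged_tensor import JaggedTensor

d_model = 16
lengths = torch.tensor([3, 1, 2])                 # per-example sequence lengths (assumed)
values = torch.randn(int(lengths.sum()), d_model) # flattened token embeddings

jt = JaggedTensor(values=values, lengths=lengths)

# Current workaround: densify + pad, then mask the pad positions.
per_example = jt.to_dense()                              # list of (len_i, d_model) tensors
padded = pad_sequence(per_example, batch_first=True)     # (B, max_len, d_model)
max_len = padded.size(1)
key_padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]  # True = padding

block = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
out = block(padded, src_key_padding_mask=key_padding_mask)  # (B, max_len, d_model)
```

The masking keeps the pad positions out of attention, but the padded tensor still costs memory and compute proportional to the longest sequence in the batch, which is what a jagged-aware block would avoid.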
Hi, this seems like a question for the core library: https://github.com/pytorch/pytorch