
Transformer block that supports jagged tensor as input

Open · yuhuishi-convect opened this issue · 1 comment

The standard transformer block expects input tensors in which all examples share the same sequence length. When the lengths in a batch differ, the example sequences must be padded to a common length.
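To make the padding workaround concrete, here is a minimal plain-Python sketch (no torchrec or PyTorch): each example is padded to the batch maximum, and a boolean mask marks the padded slots so an attention layer could ignore them. `PAD_VALUE` and the helper name `pad_batch` are made up for this illustration, not part of any library.

```python
# Toy sketch of the padding workaround described above.
# PAD_VALUE and pad_batch are hypothetical names, not a torchrec API.
PAD_VALUE = 0.0

def pad_batch(sequences):
    """Pad variable-length sequences to the batch max; return (padded, mask)."""
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        pad = max_len - len(seq)
        padded.append(list(seq) + [PAD_VALUE] * pad)
        # True marks a padded position that attention should ignore.
        mask.append([False] * len(seq) + [True] * pad)
    return padded, mask

batch = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
padded, mask = pad_batch(batch)
# padded → [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0], [5.0, 6.0, 0.0]]
# mask   → [[False, False, False], [False, True, True], [False, False, True]]
```

The wasted compute on the padded slots is exactly what the question below is asking to avoid.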

JaggedTensor was introduced to represent variable-length examples directly. Can we support a transformer block that works efficiently with JaggedTensor or KeyedJaggedTensor as input, instead of first converting them to a dense, padded form?
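For readers unfamiliar with the jagged layout: the idea is a flat values buffer plus per-example offsets, so no padding is stored at all. The toy class below is a stand-in to illustrate the concept, not torchrec's actual JaggedTensor API.

```python
# Toy illustration of a jagged layout: flat values + offsets, no padding.
# ToyJagged is a hypothetical class, not torchrec's JaggedTensor.
class ToyJagged:
    def __init__(self, values, lengths):
        self.values = values          # one flat buffer for the whole batch
        self.offsets = [0]            # cumulative lengths → example boundaries
        for n in lengths:
            self.offsets.append(self.offsets[-1] + n)

    def example(self, i):
        """Slice example i out of the flat buffer."""
        return self.values[self.offsets[i]:self.offsets[i + 1]]

jt = ToyJagged(values=[1, 2, 3, 4, 5, 6], lengths=[3, 1, 2])
# jt.example(0) → [1, 2, 3]; jt.example(1) → [4]; jt.example(2) → [5, 6]
```

A transformer block that consumed this layout directly would attend only over each example's real tokens, instead of masking out padded ones.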

yuhuishi-convect avatar Nov 13 '23 06:11 yuhuishi-convect

Hi, this seems like a question for the core library: https://github.com/pytorch/pytorch

henrylhtsang avatar Nov 13 '23 18:11 henrylhtsang