warp-rnnt
warp-rnnt with compact memory layout implementation
The compact memory layout can be explained with this figure. The input to rnnt_loss() is of size (N, T, U+1, V) = (3, 6, 6, V) in the normal layout. Colored boxes denote data and white boxes denote padding. (Note that the V dimension is omitted in the figure.)
- With the normal layout, the memory size of the input tensor is (3 * 6 * 6 * V) floats * 4 B/float = 432V bytes.
- With the compact layout, it is (20 + 18 + 18) * V floats * 4 B/float = 224V bytes, significantly less than the normal layout.
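As a sanity check, here is a minimal sketch of that size computation. The per-sequence (Ti, Ui) pairs below are hypothetical, chosen only so that the areas Ti * (Ui + 1) match the 20 + 18 + 18 of the figure:
# hypothetical (frames, labels) per sequence; Ti * (Ui + 1) gives 20, 18, 18
lengths = [(5, 3), (6, 2), (3, 5)]
N = len(lengths)                                       # 3
T_max = max(t for t, _ in lengths)                     # 6
U_max = max(u for _, u in lengths)                     # 5, so U + 1 = 6
normal_floats = N * T_max * (U_max + 1)                # 3 * 6 * 6 = 108
compact_floats = sum(t * (u + 1) for t, u in lengths)  # 20 + 18 + 18 = 56
print(normal_floats * 4, "V bytes")                    # 432V
print(compact_floats * 4, "V bytes")                   # 224V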
I also implemented gather mode with the compact layout, which is recommended. Since it is difficult to select the indices in a compact layout (the current gather mode uses torch.gather), I integrated the gather operation and its backward computation in C++/CUDA. Only the PyTorch binding is implemented.
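For context, gather mode keeps only the two log-probs the loss actually reads at each (t, u) cell, the blank and the current target label, so the last dimension shrinks from V to 2. A rough sketch of that selection in the normal layout (illustrative only; the compact-layout binding fuses this selection and its backward in C++/CUDA):
import torch

N, T, U1, V, blank = 2, 6, 6, 10, 0                   # toy sizes; U1 = U + 1
log_probs = torch.randn(N, T, U1, V).log_softmax(dim=-1)
labels = torch.randint(1, V, (N, U1 - 1))             # (N, U) target labels
index = torch.full((N, T, U1, 2), blank, dtype=torch.long)
index[:, :, :U1 - 1, 1] = labels.unsqueeze(1)         # broadcast labels over the T axis
gathered = log_probs.gather(dim=3, index=index)       # (N, T, U+1, 2)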
Benchmark results (tested on an RTX 3090):
- forward only: https://github.com/maxwellzh/warp-rnnt/blob/master/pytorch_binding/BenchMarkForwardResults.txt
- forward+backward: https://github.com/maxwellzh/warp-rnnt/blob/master/pytorch_binding/BenchMarkResults.txt
Wow, looks cool! Thank you! I will check it in the next few days. As far as I remember, there was a research paper about the compact layout. Maybe we should mention this in the README as well.
Maybe this one? https://arxiv.org/abs/1909.12415
Exactly, Section 3.1, "Efficient encoder and prediction output combination". Is it similar, or have I misunderstood?
Yes, that's it. We can add it in the README later.
I did the review, and I can't understand how it will be used in practice. @maxwellzh, could you shed some light on how to prepare the log_probs and labels arrays in practice? Is it possible to efficiently prepare a compact log_probs array? It looks like this array will always have the shape (N, T, U, V) before this cost function, and we have to convert it manually with the compactTensor function. Is that correct?
The compactTensor function is just for testing. If we converted the tensor with that function every time rnnt_loss() is invoked, the overall performance might be poor.
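For intuition, here is a rough sketch of what such a helper conceptually does for a padded (N, T, U+1, V) tensor, with a hypothetical signature that takes per-sequence frame and label lengths (the actual compactTensor in the repository may differ):
import torch

def compact_4d(x, frame_lengths, label_lengths):
    # flatten padded (N, T, U+1, V) log-probs into (sum_i Ti*(Ui+1), V)
    # by concatenating only the valid cells of each sequence
    chunks = [x[n, :t, :u + 1].reshape(-1, x.size(-1))
              for n, (t, u) in enumerate(zip(frame_lengths, label_lengths))]
    return torch.cat(chunks, dim=0)
Doing this copy on the full padded joint output at every call is exactly the overhead being avoided, which motivates compacting earlier in the network, as below.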
Normally, in the joint network, we do something like this:
# (N, T, H_enc) -> (N, T, H_joint)
trans_enc = fc_enc(encoder_output) # linear layer over encoder, or the transcription network
# (N, U+1, H_dec) -> (N, U+1, H_joint)
trans_dec = fc_dec(decoder_output) # linear layer over decoder, or the prediction network
# (N, T, 1, H_joint) + (N, 1, U+1, H_joint) -> (N, T, U+1, H_joint)
expanded_input = trans_enc.unsqueeze(2) + trans_dec.unsqueeze(1) # broadcast to add up
# (N, T, U+1, H_joint) -> (N, T, U+1, V)
rnnt_input = classifier(sigmoid(expanded_input)).log_softmax(dim=-1)
loss = rnnt_loss(rnnt_input, ...)
With the compact layout, in practice, I would recommend compacting the tensors before the first linear layer, like this:
# (N, T, H_enc) -> (ST, H_enc), ST denotes \sum{Ti}
encoder_output_compact = compactTensor(encoder_output) # make it compact
# (ST, H_enc) -> (ST, H_joint)
trans_enc = fc_enc(encoder_output_compact) # linear layer over encoder, or the transcription network
# (N, U+1, H_dec) -> (SU, H_dec), SU denotes \sum{Ui+1}
decoder_output_compact = compactTensor(decoder_output) # make it compact
# (SU, H_dec) -> (SU, H_joint)
trans_dec = fc_dec(decoder_output_compact) # linear layer over decoder, or the prediction network
# the broadcast add is not as easy as in the normal layout, so we either add sequence by sequence or implement a faster CUDA-bound function
# (ST, H_joint) + (SU, H_joint) -> (STU, H_joint), STU denotes \sum{Ti(Ui+1)}
expanded_input = TrickyAdd(trans_enc, trans_dec)
# (STU, H_joint) -> (STU, V)
rnnt_input = classifier(sigmoid(expanded_input)).log_softmax(dim=-1)
loss = rnnt_loss(rnnt_input, ..., compact=True)
compactTensor() for 3-dim tensors is easy to do with torch.cat. As for TrickyAdd(), I couldn't think of an efficient way to do it in Python, so I implemented it in CUDA along with its backward. In fact, I also implemented compactTensor() for 3-dim tensors in CUDA. Do you have any suggestions about this?
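For reference, a pure-Python sketch of the two pieces, just to pin down their semantics (the names and the per-sequence loop are illustrative; the CUDA versions come with their own backward):
import torch

def compact_3d(x, lengths):
    # padded (N, L, H) -> (sum_i Li, H), the torch.cat approach mentioned above
    return torch.cat([x[n, :l] for n, l in enumerate(lengths)], dim=0)

def tricky_add_reference(trans_enc, trans_dec, frame_lengths, label_lengths):
    # per-sequence broadcast add in the compact layout:
    # (ST, H) + (SU, H) -> (STU, H), where STU = sum_i Ti * (Ui + 1)
    out, t_off, u_off = [], 0, 0
    for t, u in zip(frame_lengths, label_lengths):
        enc = trans_enc[t_off:t_off + t]          # (Ti, H)
        dec = trans_dec[u_off:u_off + u + 1]      # (Ui+1, H)
        out.append((enc.unsqueeze(1) + dec.unsqueeze(0)).reshape(-1, enc.size(-1)))
        t_off += t
        u_off += u + 1
    return torch.cat(out, dim=0)                  # (STU, H)
Autograd covers the backward of this Python version, but the per-sequence loop is slow, which is why a fused CUDA kernel is worthwhile.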
Thanks for the clarification! Now it makes sense: you have the additional functions. Could you add these to the MR? Without these functions, it's not clear how to use the compact version more efficiently than the original one.
They are outside the scope of the RNN-T loss itself, so I'd like to create a new repo containing the implementation of these functions. It will take some time for me to prepare the code. I'll let you know when it's ready.
BTW, the compact version gains its efficiency by squeezing out the padding, so the performance improvement also depends on the batch size N. Theoretically, when N is large, we have lots of padding in the normal layout, so the compact version is expected to be better. However, if N is relatively small, the compact version might perform worse once the overhead of compactTensor is taken into account.
I have released the implementation of these functions. https://github.com/maxwellzh/torch-gather
Thank you! I will check it on the weekend.