warp-rnnt
improve efficiency of warps
In the current implementation, the warps along the T axis are computed in a fully serialized manner: https://github.com/1ytic/warp-rnnt/blob/edd5857cd9abf29f12ab3fbc153f78f21191d80b/core.cu#L112-L134
The for loop of each warp is executed one by one: the i-th warp at a specific row u has to wait for all its leading warps to finish their loops, which gives i (number of warps) * W (for-loop overhead; warp size, 32 here) time complexity.
However, we don't necessarily have to wait for the previous warps to finish before entering the loop in the current warp.
Let's take the forward computation of alphas as an example, with warp size W = 4:
Here d denotes the index inside a warp, so 0 <= d < W. B is the result from row u-1 and is assumed to be ready.
The forward computation of alpha follows the recurrence below (in practice the computation is done in log space; plain products are used here just for discussion):

alpha_d = alpha_{d-1} * e_{d-1} + B_d
Note that alpha_0 relies on the result from the last warp.
Here comes the trick: I rewrote the alpha_3 formula as

alpha_3 = e_2 e_1 e_0 * alpha_0 + (e_2 e_1 B_1 + e_2 B_2 + B_3)
The second part, e_2 e_1 B_1 + e_2 B_2 + B_3, is warp-independent. The first part (the product of emitting probabilities e_2 e_1 e_0) can be computed by a prefix-sum (scan) algorithm in log space, introducing only O(log2(W)) complexity.
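To illustrate the O(log2(W)) claim, here is a CPU sketch (the helper name is hypothetical, not from the warp-rnnt code) of the Hillis-Steele inclusive scan a warp would perform with shuffle instructions; scanning log-probabilities and exponentiating gives the prefix products of the emitting probabilities:

```python
import numpy as np

def warp_inclusive_scan(x):
    """Inclusive prefix sum out[d] = x[0] + ... + x[d] in ceil(log2(W))
    doubling steps -- the same pattern a GPU warp implements with
    __shfl_up_sync, one round per offset."""
    out = np.asarray(x, dtype=float).copy()
    offset = 1
    while offset < len(out):
        # each lane adds the value from `offset` lanes below (zero-padded)
        shifted = np.concatenate([np.zeros(offset), out[:-offset]])
        out += shifted
        offset *= 2
    return out
```

With W = 32 this takes 5 rounds instead of 31 serial additions; summing log(e_d) this way yields all the prefix products e_0, e_0 e_1, ... at once.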
Finally, the new procedure is:
- Compute the local path-combination probability (the second part above): O(W) complexity;
- Compute the products of emitting probabilities (e2e1e0, ...) with a prefix-sum algorithm: O(log2(W)) complexity;
- Wait for the previous warps to finish, then compute the final results: constant complexity.
For all warps at row u, steps 1 and 2 can be done in parallel; the i-th warp only has to wait for all previous warps to finish step 3. The new procedure should be considerably faster than the current serialized execution, especially when T is large.
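The three steps can be sketched on the CPU as follows (a sketch only, with hypothetical function names, in plain probability space instead of log space; the recurrence is written as alpha[t] = alpha[t-1] * e[t] + B[t], which is the formula above up to a relabeling of the e indices). A GPU version would run step 1 per lane, step 2 with a warp shuffle scan, and keep step 3 as the only serialized hand-off:

```python
import numpy as np

def alpha_serial(e, B, alpha_prev):
    """Baseline: fully serialized recurrence alpha[t] = alpha[t-1]*e[t] + B[t]."""
    out = np.empty_like(B)
    a = alpha_prev
    for t in range(len(B)):
        a = a * e[t] + B[t]
        out[t] = a
    return out

def alpha_warp_decomposed(e, B, alpha_prev, W=4):
    """Same result via the proposed per-warp decomposition."""
    T = len(B)
    out = np.empty_like(B)
    carry = alpha_prev  # last alpha of the previous warp
    for t0 in range(0, T, W):  # one iteration per warp
        ew, Bw = e[t0:t0 + W], B[t0:t0 + W]
        # Step 1: warp-local path combination L_d (independent of other warps)
        L = np.empty(len(Bw))
        acc = 0.0
        for d in range(len(Bw)):
            acc = acc * ew[d] + Bw[d]
            L[d] = acc
        # Step 2: prefix products P_d of e (a scan, O(log2 W) on a GPU warp)
        P = np.cumprod(ew)
        # Step 3: O(1) combine with the previous warp's carry
        out[t0:t0 + len(Bw)] = P * carry + L
        carry = out[t0 + len(Bw) - 1]
    return out
```

Steps 1 and 2 involve no cross-warp dependency, so every warp can run them concurrently; only the one-multiply-one-add combine in step 3 remains serialized across warps.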
Hello Huahuan Zheng, interesting theory! But I don't think it will be useful in practice. Optimising the forward pass alone doesn't make sense. You can check the CUDA profiler logs. The big issue is memory IO, and I really like your previous MR with the compact memory version. I hope to finish reviewing it and reopen your MR in the near future.
Will do further investigation later :)
As for the IO issue, I remember reading somewhere that a thread block loads nearby memory regardless of whether it is used, since global loads are served in full cache lines. Have you ever tried using an (N, U, T, V) layout instead of (N, T, U, V)? With the former (and especially when gather=True), a warp (also a thread block) is able to load a chunk of consecutive memory and reuse it.
Indeed, I've been using the compact-version loss function in our speech recognition tasks for a while. It should be technically correct (it's in my dev branch now; the main branch hasn't been updated for some time). I'll finish merging my dev branch into main, and once that's done, I'll reopen the MR.
I’m not familiar with the memory manager for CUDA threads. But you’re right, having the T×U matrix is the main bottleneck. Fortunately, there is a solution for this: fast_rnnt. It looks really promising.
I've been following the fast_rnnt work for a while, but haven't managed a successful pruned RNN-T training yet.
They also have a paper about the implementation: https://arxiv.org/pdf/2206.13236.pdf