Make RNN faster by changing loop order

Open dofuuz opened this issue 5 years ago • 7 comments

Here is an alternative to PR #101.

I changed the loop order so that auto-vectorization can be applied without transposing the weights (a minimal sketch of the change is below). Encoding speed is about the same as with #101.
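
A minimal sketch of the reordering, using the names from compute_gru() in src/mlp.c (loop bounds simplified):

/* Before: i outer, j inner. Consecutive inner iterations read
   input_weights[j*stride + i] with a jump of `stride` elements between
   them, which blocks auto-vectorization. */
for (i = 0; i < N; i++)
   for (j = 0; j < M; j++)
      z[i] += gru->input_weights[j*stride + i] * input[j];

/* After (this PR): j outer, i inner. The inner loop now reads the
   weights and writes z[] sequentially, so the compiler can vectorize
   without keeping a transposed copy of the weights. */
for (j = 0; j < M; j++)
   for (i = 0; i < N; i++)
      z[i] += gru->input_weights[j*stride + i] * input[j];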

FYI: I tried computing z, r, and h in fewer loops, but there was no notable speedup, so that change is not included in this PR. It might still help with manual vectorization and loop unrolling later: https://github.com/dofuuz/opus/blob/da2e14f03b1a43614b29a0e15762e6785133a720/src/mlp.c

dofuuz avatar Nov 11 '18 15:11 dofuuz

What kind of speedup are you observing with that patch? Did you check the disassembly to see if the compiler managed to vectorize anything?

jmvalin avatar Nov 11 '18 21:11 jmvalin

I just reordered the loops to access memory sequentially and make better use of auto-vectorization. It gives about a 2% speedup for Opus encoding.

On MSVC, using /fp:fast (an option similar to -ffast-math on GCC) results in assembly like this:

   /* Compute update gate. */
   for (i=0; i<N; i++)
00007FF796C2D578  cmp         r10,r9  
00007FF796C2D57B  jge         compute_gru+0EBh (07FF796C2D59Bh)  
      z[i] = gru->bias[i];
00007FF796C2D57D  mov         rcx,qword ptr [rsi]  
      z[i] = gru->bias[i];
00007FF796C2D580  movsx       eax,byte ptr [rcx+r10]  
00007FF796C2D585  movd        xmm0,eax  
00007FF796C2D589  cvtdq2ps    xmm0,xmm0  
00007FF796C2D58C  movss       dword ptr [rbp+r10*4-50h],xmm0  

   for (i=0; i<N; i++)
00007FF796C2D593  inc         r10  
00007FF796C2D596  cmp         r10,r9  
00007FF796C2D599  jl          compute_gru+0D0h (07FF796C2D580h)  

   for (j=0; j<M; j++)
00007FF796C2D59B  movsxd      r13,dword ptr [rsi+18h]  
00007FF796C2D59F  mov         qword ptr [rsp+220h],rbx  
00007FF796C2D5A7  mov         qword ptr [rsp+238h],rdi  
00007FF796C2D5AF  test        r13,r13  
00007FF796C2D5B2  jle         compute_gru+1F5h (07FF796C2D6A5h)  
         z[i] += gru->input_weights[j*stride + i] * input[j];
00007FF796C2D5B8  mov         r10,r11  
00007FF796C2D5BB  mov         ebx,r11d  
00007FF796C2D5BE  xchg        ax,ax  
      for (i=0; i<N; i++)
00007FF796C2D5C0  mov         rdx,r11  
00007FF796C2D5C3  cmp         r9,4  
00007FF796C2D5C7  jl          compute_gru+1ADh (07FF796C2D65Dh)  

   for (i=0; i<N; i++)
00007FF796C2D5CD  mov         r8,qword ptr [rsi+8]  
00007FF796C2D5D1  lea         rdi,[r9-3]  
00007FF796C2D5D5  movss       xmm1,dword ptr [r12+r10*4]  
00007FF796C2D5DB  movsxd      rcx,ebx  
00007FF796C2D5DE  add         r8,rcx  
         z[i] += gru->input_weights[j*stride + i] * input[j];
00007FF796C2D5E1  movsx       eax,byte ptr [r8+rdx]  
00007FF796C2D5E6  movd        xmm0,eax  
00007FF796C2D5EA  movsx       eax,byte ptr [r8+rdx+1]  
00007FF796C2D5F0  cvtdq2ps    xmm0,xmm0  
00007FF796C2D5F3  mulss       xmm0,xmm1  
00007FF796C2D5F7  addss       xmm0,dword ptr [rbp+rdx*4-50h]  
         z[i] += gru->input_weights[j*stride + i] * input[j];
00007FF796C2D5FD  movss       dword ptr [rbp+rdx*4-50h],xmm0  
00007FF796C2D603  movd        xmm0,eax  
00007FF796C2D607  movsx       eax,byte ptr [r8+rdx+2]  
00007FF796C2D60D  cvtdq2ps    xmm0,xmm0  
00007FF796C2D610  mulss       xmm0,xmm1  
00007FF796C2D614  addss       xmm0,dword ptr [rbp+rdx*4-4Ch]  
00007FF796C2D61A  movss       dword ptr [rbp+rdx*4-4Ch],xmm0  
00007FF796C2D620  movd        xmm0,eax  
00007FF796C2D624  movsx       eax,byte ptr [r8+rdx+3]  
00007FF796C2D62A  cvtdq2ps    xmm0,xmm0  
00007FF796C2D62D  mulss       xmm0,xmm1  
00007FF796C2D631  addss       xmm0,dword ptr [rbp+rdx*4-48h]  
00007FF796C2D637  movss       dword ptr [rbp+rdx*4-48h],xmm0  
00007FF796C2D63D  movd        xmm0,eax  
00007FF796C2D641  cvtdq2ps    xmm0,xmm0  
00007FF796C2D644  mulss       xmm0,xmm1  
00007FF796C2D648  addss       xmm0,dword ptr [rbp+rdx*4-44h]  
00007FF796C2D64E  movss       dword ptr [rbp+rdx*4-44h],xmm0  
00007FF796C2D654  add         rdx,4  

It's using SSE ops like mulss and addss on xmm registers, and the inner loop has been unrolled four ways, so the compiler is doing some auto-vectorization.

dofuuz avatar Nov 12 '18 03:11 dofuuz

So your PR made me realize my code was way too repetitive. I just checked in a refactoring commit to master to isolate the product in a single function. That being said, I tried swapping the two loops and didn't really see a measurable improvement. Are you sure you're able to see a difference?
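
For reference, the refactored helper looks roughly like this (a sketch of the shape of the function, not a verbatim copy of the commit):

/* One shared accumulation routine instead of repeating the same loop
   nest for each of the z/r/h gates. */
static void gemm_accum(float *out, const opus_int8 *weights, int rows,
                       int cols, int col_stride, const float *x)
{
   int i, j;
   for (i = 0; i < rows; i++)
      for (j = 0; j < cols; j++)
         out[i] += weights[j*col_stride + i] * x[j];
}

With this in place, the proposal in this PR reduces to swapping the i and j loops inside gemm_accum().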

jmvalin avatar Nov 22 '18 19:11 jmvalin

After merging 9791b22b2c83980f6b4386c870cad58557c78007, I tested it again.

Profiling on MSVC 2015:

Before swapping the two loops in gemm_accum() [profiler screenshot img003]: computing the NN takes 2~3% of the whole encoding time.

After swapping the two loops in gemm_accum() [profiler screenshot img004]:

compute_gru(): 2.5% → 1.7% (↓0.8%)
compute_dense(): 0.9% → 0.6% (↓0.3%)

Overall timing [screenshot img005]: 3402 ms → 3391 ms (0.3% speedup). The speedup is slight, so I measured it several times and averaged the results.

Profiling shows about a 1% speedup, but I couldn't find an obvious speedup in the timing measurements. There was a 1~2% speedup when I first made this PR, before 9791b22b2c83980f6b4386c870cad58557c78007 was merged (see #101). I am not sure why; maybe the different calculation order made the difference?

I still think the loop order should be swapped for the speedup, but the effect is small. I leave it to your judgment.

Anyway, you can close this PR after reading this.

dofuuz avatar Jan 06 '19 04:01 dofuuz

And... gemm_accum() should be renamed to gemv_accum(), because it computes a matrix-vector product.

dofuuz avatar Jan 06 '19 04:01 dofuuz

If you really want to optimize this, see sgemv_accum16() in: https://github.com/mozilla/LPCNet/blob/master/src/vec_avx.h It assumes that the output size is a multiple of 16, but it's easy to change it to be a multiple of 8. Also, the matrix is assumed to be float, so it'd require some conversion for the 8-bit matrix.
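
For context, that function is shaped roughly like the following sketch (paraphrased, assuming AVX is available and rows is a multiple of 16; see vec_avx.h for the real code):

#include <immintrin.h>

/* Accumulate out += weights * x, 16 output rows at a time, using two
   8-wide AVX registers per block of rows. */
static void sgemv_accum16(float *out, const float *weights, int rows,
                          int cols, int col_stride, const float *x)
{
   int i, j;
   for (i = 0; i < rows; i += 16)
   {
      /* Running sums for rows i..i+15. */
      __m256 vy0 = _mm256_loadu_ps(&out[i]);
      __m256 vy8 = _mm256_loadu_ps(&out[i + 8]);
      for (j = 0; j < cols; j++)
      {
         __m256 vxj = _mm256_broadcast_ss(&x[j]);  /* splat x[j] */
         __m256 vw0 = _mm256_loadu_ps(&weights[j*col_stride + i]);
         __m256 vw8 = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
         vy0 = _mm256_add_ps(vy0, _mm256_mul_ps(vw0, vxj));
         vy8 = _mm256_add_ps(vy8, _mm256_mul_ps(vw8, vxj));
      }
      _mm256_storeu_ps(&out[i], vy0);
      _mm256_storeu_ps(&out[i + 8], vy8);
   }
}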

jmvalin avatar Jan 07 '19 19:01 jmvalin

Isn't this intended to become fixed-point after all? I'd rather try to implement a fixed-point RNN instead...

Is there anything already done toward a fixed-point implementation?
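
(For illustration only, here is a hypothetical shape such a kernel could take, with int8 weights, Q15 inputs, and 32-bit accumulators; none of these names exist in the tree:)

#include <stdint.h>

/* Hypothetical fixed-point counterpart of gemv_accum(): int8 weights,
   Q15 input vector, Q15 result in a 32-bit accumulator. Same swapped
   loop order as this PR, so the weight reads stay sequential. */
static void gemv_accum_fixed(int32_t *out, const int8_t *weights,
                             int rows, int cols, int col_stride,
                             const int16_t *x)
{
   int i, j;
   for (j = 0; j < cols; j++)
      for (i = 0; i < rows; i++)
         out[i] += (int32_t)weights[j*col_stride + i] * x[j];
}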

dofuuz avatar Jan 30 '19 08:01 dofuuz

Closing since we now have a much faster RNN implementation in dnn/

jmvalin avatar Feb 27 '24 20:02 jmvalin