llvm-project (GitHub issue)

Improve vectorized code for a loop with a bool condition

Status: Open · opened by davidbolvansky 3 years ago · 4 comments

#define N 256
typedef char T;
extern T a[N];
extern T b[N];
extern T c[N];
extern _Bool pb[N];
extern char pc[N];

void predicate_by_bool()
{
  for (int i = 0; i < N; i++)
    c[i] = pb[i] ? a[i] : b[i];
}

void predicate_by_char()
{
  for (int i = 0; i < N; i++)
    c[i] = pc[i] ? a[i] : b[i];
}

LLVM (Clang) -O3 -mavx2 output (truncated):

predicate_by_bool:                      # @predicate_by_bool
        xor     eax, eax
        mov     r8, qword ptr [rip + pb@GOTPCREL]
        vpxor   xmm0, xmm0, xmm0
        mov     rdx, qword ptr [rip + b@GOTPCREL]
        mov     rsi, qword ptr [rip + a@GOTPCREL]
        mov     r9, qword ptr [rip + c@GOTPCREL]
.LBB0_1:                                # =>This Inner Loop Header: Depth=1
        vpcmpeqb        ymm1, ymm0, ymmword ptr [r8 + rax]
        vpmovmskb       ecx, ymm1
        test    cl, 1
        mov     rdi, rsi
        cmovne  rdi, rdx
        test    cl, 2
        vmovd   xmm1, dword ptr [rdi + rax]     # xmm1 = mem[0],zero,zero,zero
        mov     rdi, rsi
        cmovne  rdi, rdx
        vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 1], 1
        test    cl, 4
        mov     rdi, rsi
        cmovne  rdi, rdx
        test    cl, 8
        vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 2], 2
        mov     rdi, rsi
        cmovne  rdi, rdx
        vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 3], 3
        test    cl, 16
        mov     rdi, rsi
        cmovne  rdi, rdx
        test    cl, 32
        vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 4], 4
        mov     rdi, rsi
        cmovne  rdi, rdx
        ....

ICC -O3 -mavx2 output (GCC's output is similar):

predicate_by_bool:
        vpxor     ymm0, ymm0, ymm0                              #12.27
        vmovdqu   ymm1, YMMWORD PTR a[rip]                      #12.27
        vmovdqu   ymm4, YMMWORD PTR 32+a[rip]                   #12.27
        vmovdqu   ymm7, YMMWORD PTR 64+a[rip]                   #12.27
        vmovdqu   ymm10, YMMWORD PTR 96+a[rip]                  #12.27
        vmovdqu   ymm13, YMMWORD PTR 128+a[rip]                 #12.27
        vpcmpeqb  ymm2, ymm0, YMMWORD PTR pb[rip]               #12.12
        vpcmpeqb  ymm5, ymm0, YMMWORD PTR 32+pb[rip]            #12.12
        vpcmpeqb  ymm8, ymm0, YMMWORD PTR 64+pb[rip]            #12.12
        vpblendvb ymm3, ymm1, YMMWORD PTR b[rip], ymm2          #12.27
        vpblendvb ymm6, ymm4, YMMWORD PTR 32+b[rip], ymm5       #12.27
        vpblendvb ymm9, ymm7, YMMWORD PTR 64+b[rip], ymm8       #12.27
        vmovdqu   ymm1, YMMWORD PTR 160+a[rip]                  #12.27
        vmovdqu   ymm4, YMMWORD PTR 192+a[rip]                  #12.27
        vmovdqu   YMMWORD PTR c[rip], ymm3                      #12.5
        vmovdqu   YMMWORD PTR 32+c[rip], ymm6                   #12.5
        vmovdqu   YMMWORD PTR 64+c[rip], ymm9                   #12.5
        vpcmpeqb  ymm11, ymm0, YMMWORD PTR 96+pb[rip]           #12.12
        vpcmpeqb  ymm14, ymm0, YMMWORD PTR 128+pb[rip]          #12.12
        vpcmpeqb  ymm2, ymm0, YMMWORD PTR 160+pb[rip]           #12.12
        vpcmpeqb  ymm5, ymm0, YMMWORD PTR 192+pb[rip]           #12.12
        vpcmpeqb  ymm7, ymm0, YMMWORD PTR 224+pb[rip]           #12.12
        vmovdqu   ymm0, YMMWORD PTR 224+a[rip]                  #12.27
        vpblendvb ymm12, ymm10, YMMWORD PTR 96+b[rip], ymm11    #12.27
        vpblendvb ymm15, ymm13, YMMWORD PTR 128+b[rip], ymm14   #12.27
        vpblendvb ymm3, ymm1, YMMWORD PTR 160+b[rip], ymm2      #12.27
        vpblendvb ymm6, ymm4, YMMWORD PTR 192+b[rip], ymm5      #12.27
        vpblendvb ymm8, ymm0, YMMWORD PTR 224+b[rip], ymm7      #12.27
        vmovdqu   YMMWORD PTR 96+c[rip], ymm12                  #12.5
        vmovdqu   YMMWORD PTR 128+c[rip], ymm15                 #12.5
        vmovdqu   YMMWORD PTR 160+c[rip], ymm3                  #12.5
        vmovdqu   YMMWORD PTR 192+c[rip], ymm6                  #12.5
        vmovdqu   YMMWORD PTR 224+c[rip], ymm8                  #12.5
        vzeroupper                                              #13.1
        ret  

Current codegen: https://godbolt.org/z/hPf9fEs8v

— davidbolvansky, May 19 '22 11:05

A similar case with T = short:

#define N 256
typedef short T;
extern T a[N];
extern T b[N];
extern T c[N];
extern _Bool pb[N];

void predicate_by_bool()
{
  for (int i = 0; i < N; i++)
    c[i] = pb[i] ? a[i] : b[i];
}

ICC -O3 -mavx2:

predicate_by_bool:
        xor       eax, eax                                      #10.3
        vpxor     ymm1, ymm1, ymm1                              #11.12
        vmovdqu   ymm0, YMMWORD PTR .L_2il0floatpacket.0[rip]   #11.12
..B1.2:                         # Preds ..B1.2 ..B1.1
        vmovdqu   ymm2, YMMWORD PTR [b+rax*2]                   #11.27
        vmovdqu   ymm3, YMMWORD PTR [a+rax*2]                   #11.20
        vpmovzxbd ymm4, QWORD PTR [pb+rax]                      #11.12
        vpmovzxbd ymm11, QWORD PTR [8+pb+rax]                   #11.12
        vpcmpeqd  ymm7, ymm1, ymm4                              #11.12
        vpcmpeqd  ymm14, ymm1, ymm11                            #11.12
        vextracti128 xmm9, ymm2, 1                              #11.27
        vextracti128 xmm10, ymm3, 1                             #11.20
        vpmovsxwd ymm6, xmm2                                    #11.27
        vpmovsxwd ymm5, xmm3                                    #11.20
        vpmovsxwd ymm13, xmm9                                   #11.27
        vpmovsxwd ymm12, xmm10                                  #11.20
        vpblendvb ymm8, ymm5, ymm6, ymm7                        #11.27
        vpblendvb ymm15, ymm12, ymm13, ymm14                    #11.27
        vpand     ymm4, ymm8, ymm0                              #11.12
        vpand     ymm2, ymm15, ymm0                             #11.12
        vpackusdw ymm3, ymm4, ymm2                              #11.12
        vpermq    ymm5, ymm3, 216                               #11.12
        vmovdqu   YMMWORD PTR [c+rax*2], ymm5                   #11.5
        add       rax, 16                                       #10.3
        cmp       rax, 256                                      #10.3
        jb        ..B1.2        # Prob 99%                      #10.3
        vzeroupper                                              #12.1
        ret              

https://godbolt.org/z/hTExeqoq7

— davidbolvansky, May 19 '22 11:05

With -O3, LLVM does not vectorize this loop at all, whereas ICC and GCC do vectorize it.

Maybe this is a cost-model issue too? cc @RKSimon

— davidbolvansky, May 20 '22 16:05

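cc @rotateright @nikic It looks like we end up selecting between the two array base pointers (@a / @b) and then doing a single load from the selected pointer, instead of doing two loads and selecting between the loaded values (which is what would map onto a vector blend):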

select i1 %64, ptr @b, ptr @a
...
%113 = getelementptr inbounds [256 x i8], ptr %65, i64 0, i64 %14
...
%145 = load i8, ptr %113, align 1, !tbaa !10

— RKSimon, May 26 '22 14:05

So, is this a missing InstCombine canonicalization?

— davidbolvansky, Aug 15 '22 16:08