llvm-project
llvm-project copied to clipboard
Improve vectorized code of loop with bool condition
#define N 256
typedef char T;
extern T a[N];
extern T b[N];
extern T c[N];
extern _Bool pb[N];
extern char pc[N];
/* Element-wise select over all N entries: c[i] takes a[i] when pb[i] is
 * true and b[i] otherwise. The condition array pb has element type _Bool;
 * this is the variant whose vectorized code the report is about. */
void predicate_by_bool()
{
for (int i = 0; i < N; i++)
c[i] = pb[i] ? a[i] : b[i];
}
/* Same element-wise select as predicate_by_bool, but the condition is read
 * from the char array pc, so any non-zero byte selects a[i]. */
void predicate_by_char()
{
for (int i = 0; i < N; i++)
c[i] = pc[i] ? a[i] : b[i];
}
LLVM -O3 -mavx2:
predicate_by_bool: # @predicate_by_bool
xor eax, eax
mov r8, qword ptr [rip + pb@GOTPCREL]
vpxor xmm0, xmm0, xmm0
mov rdx, qword ptr [rip + b@GOTPCREL]
mov rsi, qword ptr [rip + a@GOTPCREL]
mov r9, qword ptr [rip + c@GOTPCREL]
.LBB0_1: # =>This Inner Loop Header: Depth=1
vpcmpeqb ymm1, ymm0, ymmword ptr [r8 + rax]
vpmovmskb ecx, ymm1
test cl, 1
mov rdi, rsi
cmovne rdi, rdx
test cl, 2
vmovd xmm1, dword ptr [rdi + rax] # xmm1 = mem[0],zero,zero,zero
mov rdi, rsi
cmovne rdi, rdx
vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 1], 1
test cl, 4
mov rdi, rsi
cmovne rdi, rdx
test cl, 8
vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 2], 2
mov rdi, rsi
cmovne rdi, rdx
vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 3], 3
test cl, 16
mov rdi, rsi
cmovne rdi, rdx
test cl, 32
vpinsrb xmm1, xmm1, byte ptr [rdi + rax + 4], 4
mov rdi, rsi
cmovne rdi, rdx
....
ICC -O3 -mavx2 (GCC output is similar):
# ICC output for predicate_by_bool with T = char.
# The 256-byte loop is fully unrolled into eight 32-byte (ymm) blocks.
# Per block: load 32 bytes of a, compare 32 bytes of pb against zero (ymm0),
# then vpblendvb picks the byte from b where pb == 0 and the byte from a
# otherwise, and the blended vector is stored to c.
predicate_by_bool:
vpxor ymm0, ymm0, ymm0 #12.27
vmovdqu ymm1, YMMWORD PTR a[rip] #12.27
vmovdqu ymm4, YMMWORD PTR 32+a[rip] #12.27
vmovdqu ymm7, YMMWORD PTR 64+a[rip] #12.27
vmovdqu ymm10, YMMWORD PTR 96+a[rip] #12.27
vmovdqu ymm13, YMMWORD PTR 128+a[rip] #12.27
# mask = (pb == 0), one all-ones byte per false condition
vpcmpeqb ymm2, ymm0, YMMWORD PTR pb[rip] #12.12
vpcmpeqb ymm5, ymm0, YMMWORD PTR 32+pb[rip] #12.12
vpcmpeqb ymm8, ymm0, YMMWORD PTR 64+pb[rip] #12.12
# result = mask ? b : a
vpblendvb ymm3, ymm1, YMMWORD PTR b[rip], ymm2 #12.27
vpblendvb ymm6, ymm4, YMMWORD PTR 32+b[rip], ymm5 #12.27
vpblendvb ymm9, ymm7, YMMWORD PTR 64+b[rip], ymm8 #12.27
vmovdqu ymm1, YMMWORD PTR 160+a[rip] #12.27
vmovdqu ymm4, YMMWORD PTR 192+a[rip] #12.27
vmovdqu YMMWORD PTR c[rip], ymm3 #12.5
vmovdqu YMMWORD PTR 32+c[rip], ymm6 #12.5
vmovdqu YMMWORD PTR 64+c[rip], ymm9 #12.5
vpcmpeqb ymm11, ymm0, YMMWORD PTR 96+pb[rip] #12.12
vpcmpeqb ymm14, ymm0, YMMWORD PTR 128+pb[rip] #12.12
vpcmpeqb ymm2, ymm0, YMMWORD PTR 160+pb[rip] #12.12
vpcmpeqb ymm5, ymm0, YMMWORD PTR 192+pb[rip] #12.12
vpcmpeqb ymm7, ymm0, YMMWORD PTR 224+pb[rip] #12.12
# all compares are done, so the zero register ymm0 is reused for data
vmovdqu ymm0, YMMWORD PTR 224+a[rip] #12.27
vpblendvb ymm12, ymm10, YMMWORD PTR 96+b[rip], ymm11 #12.27
vpblendvb ymm15, ymm13, YMMWORD PTR 128+b[rip], ymm14 #12.27
vpblendvb ymm3, ymm1, YMMWORD PTR 160+b[rip], ymm2 #12.27
vpblendvb ymm6, ymm4, YMMWORD PTR 192+b[rip], ymm5 #12.27
vpblendvb ymm8, ymm0, YMMWORD PTR 224+b[rip], ymm7 #12.27
vmovdqu YMMWORD PTR 96+c[rip], ymm12 #12.5
vmovdqu YMMWORD PTR 128+c[rip], ymm15 #12.5
vmovdqu YMMWORD PTR 160+c[rip], ymm3 #12.5
vmovdqu YMMWORD PTR 192+c[rip], ymm6 #12.5
vmovdqu YMMWORD PTR 224+c[rip], ymm8 #12.5
vzeroupper #13.1
ret
Current codegen: https://godbolt.org/z/hPf9fEs8v
Similar case with T = short
#define N 256
typedef short T;
extern T a[N];
extern T b[N];
extern T c[N];
extern _Bool pb[N];
/* Element-wise select with 16-bit data (T = short) and an 8-bit _Bool
 * condition: c[i] takes a[i] when pb[i] is true and b[i] otherwise, so the
 * condition and the data have different element widths. */
void predicate_by_bool()
{
for (int i = 0; i < N; i++)
c[i] = pb[i] ? a[i] : b[i];
}
ICC -O3 -mavx2:
# ICC output for predicate_by_bool with T = short.
# Loop over rax = 0, 16, ..., 240: each iteration handles 16 elements.
# The select is done at dword width: pb bytes are zero-extended to dwords
# and compared against zero, a and b words are sign-extended to dwords,
# blended, masked with the constant in ymm0, and packed back to words.
predicate_by_bool:
xor eax, eax #10.3
vpxor ymm1, ymm1, ymm1 #11.12
# NOTE(review): constant not shown in the listing; presumably a per-dword
# 0x0000FFFF mask so that vpackusdw below does not saturate -- confirm.
vmovdqu ymm0, YMMWORD PTR .L_2il0floatpacket.0[rip] #11.12
..B1.2: # Preds ..B1.2 ..B1.1
vmovdqu ymm2, YMMWORD PTR [b+rax*2] #11.27
vmovdqu ymm3, YMMWORD PTR [a+rax*2] #11.20
# two groups of 8 condition bytes, widened to 8 dwords each
vpmovzxbd ymm4, QWORD PTR [pb+rax] #11.12
vpmovzxbd ymm11, QWORD PTR [8+pb+rax] #11.12
# mask = (pb == 0) per dword
vpcmpeqd ymm7, ymm1, ymm4 #11.12
vpcmpeqd ymm14, ymm1, ymm11 #11.12
# split the 16 words of a and b into low/high halves and widen to dwords
vextracti128 xmm9, ymm2, 1 #11.27
vextracti128 xmm10, ymm3, 1 #11.20
vpmovsxwd ymm6, xmm2 #11.27
vpmovsxwd ymm5, xmm3 #11.20
vpmovsxwd ymm13, xmm9 #11.27
vpmovsxwd ymm12, xmm10 #11.20
# result = mask ? b : a
vpblendvb ymm8, ymm5, ymm6, ymm7 #11.27
vpblendvb ymm15, ymm12, ymm13, ymm14 #11.27
vpand ymm4, ymm8, ymm0 #11.12
vpand ymm2, ymm15, ymm0 #11.12
# vpackusdw packs within each 128-bit lane; vpermq 216 (qword order
# 0,2,1,3) puts the packed words back into element order
vpackusdw ymm3, ymm4, ymm2 #11.12
vpermq ymm5, ymm3, 216 #11.12
vmovdqu YMMWORD PTR [c+rax*2], ymm5 #11.5
add rax, 16 #10.3
cmp rax, 256 #10.3
jb ..B1.2 # Prob 99% #10.3
vzeroupper #12.1
ret
https://godbolt.org/z/hTExeqoq7
With -O3, LLVM does not vectorize this loop at all (ICC and GCC both vectorize it).
Maybe this is a cost-model issue too? cc @RKSimon
cc @rotateright @nikic It looks like we've managed to end up selecting between pointers instead of selecting between 2 loads:
select i1 %64, ptr @b, ptr @a
...
%113 = getelementptr inbounds [256 x i8], ptr %65, i64 0, i64 %14
...
%145 = load i8, ptr %113, align 1, !tbaa !10
So is this a missing InstCombine canonicalization?