
Multiply 2 16-bit integers and accumulate to 32-bit integer

Open zsoerenm opened this issue 5 years ago • 10 comments

How would I multiply two 16-bit integers and accumulate into a 32-bit integer? Something like _mm_madd_epi16. So far the multiplication seems to be done in 16-bit:

a = Int16.([1 << 8, 1 << 9, 1 << 10, 1 << 10])
b = Int16.([1 << 7, 1 << 8, 1 << 11, 1 << 10])
function fma_avx(a, b)
    res = Int32(0)
    @avx for i = 1:4
        res = res + a[i] * b[i]
    end
    res
end
julia> fma_avx(a, b)
-32768

julia> fma_avx(Int32.(a), Int32.(b))
3309568
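
For reference, the same wraparound can be reproduced in plain scalar Int16 arithmetic, without @avx:

julia> Int16(1 << 8) * Int16(1 << 7)   # 1 << 15 wraps to typemin(Int16)
-32768

julia> sum(Int32, a .* b)              # the remaining products wrap to 0
-32768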

zsoerenm avatar Jan 26 '20 11:01 zsoerenm

Hmm, interesting. Unless I'm missing it, LLVM doesn't seem to support it with an intrinsic (as of the yet-to-be-released LLVM-10).

But it can be implemented via shufflevector and zext.

Do some instruction sets provide hardware support for this? If so, we'd have to confirm that LLVM does the correct thing from that definition (it normally does). Or have any other suggestions?

I'd also have to change how I do reductions. Currently, it reads the type from the arrays (a and b), creates a corresponding vector filled with the identity element of the operation (i.e., 0 for addition, 1 for multiplication) of the appropriate type, and at the end it reduces that vector and combines the result with your initial value (if the initial value was 0 itself, the compiler should drop the addition). The needed change (aside from support for that operation) is that it'll need to keep track of separate types for the arrays and the accumulators.
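
Roughly, the current strategy behaves like the following scalar sketch (illustrative only; reduce_sketch is a made-up name, not the actual lowering):

function reduce_sketch(a::AbstractVector{T1}, b::AbstractVector{T2}, init::Int32) where {T1<:Integer, T2<:Integer}
    T = promote_type(T1, T2)       # accumulator type comes from the arrays
    acc = zero(T)                  # identity element of the reduction (+)
    @inbounds for i in eachindex(a, b)
        acc += a[i] * b[i]         # products and partial sums stay in T (e.g. Int16)
    end
    init + acc                     # only the final result is combined with the Int32
end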

chriselrod avatar Jan 26 '20 17:01 chriselrod

Or have any other suggestions?

Unfortunately, I am not of much help here. SIMD is new to me. While googling I found that _mm_madd_epi16 is what I need. Unfortunately, I haven't found an LLVM intrinsic for it either.

During testing I made another observation: Even if each individual multiplication does not overflow, the sum does with @avx:

a = repeat(Int16.([1 << 6]), 10)
b = copy(a)
fma_avx(a, b) # Result: -24576

However, without @avx it does the job:

function fma(a, b)
    res = Int32(0)
    @inbounds for i = 1:length(a)
        res = res + a[i] * b[i]
    end
    res
end
fma(a, b) # 40960

zsoerenm avatar Jan 26 '20 20:01 zsoerenm

Even if each individual multiplication does not overflow, the sum does with @avx

Yes, like I said, currently @avx creates an accumulator of the same type as promote_type(eltype(a), eltype(b)). Only when it is done accumulating does it sum this accumulator (which will normally be a SIMD vector) and add that to your res.
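
For two Int16 arrays, that promoted type is still Int16:

julia> promote_type(Int16, Int16)
Int16

so the accumulator itself can overflow even when the individual products do not.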

chriselrod avatar Jan 26 '20 20:01 chriselrod

Here is one possible implementation:

using SIMDPirates
@generated function muladd(a::Vec{W1,T1}, b::Vec{W1,T1}, c::Vec{W2,T2}) where {W1,W2,T1<:Integer,T2<:Integer}
    size_T1 = sizeof(T1); size_T2 = sizeof(T2)
    @assert size_T1 * W1 == size_T2 * W2
    @assert W1 == 2W2 # lazy, for now we'll only implement this case.
    typ1 = "i$(8size_T1)"
    vtyp1 = "<$W1 x $typ1>"
    typ2 = "i$(8size_T2)"
    vtyp2 = "<$W2 x $typ2>"
    instrs = String[]
    vhtyp1 = "<$W2 x $typ1>"
    su = join(1:2:W1-1, ", i32 ")
    sl = join(0:2:W1-1, ", i32 ")
    push!(instrs, "%au = shufflevector $vtyp1 %0, $vtyp1 undef, <$W2 x i32> <i32 $su>")
    push!(instrs, "%bu = shufflevector $vtyp1 %1, $vtyp1 undef, <$W2 x i32> <i32 $su>")
    push!(instrs, "%auz = zext $vhtyp1 %au to $vtyp2")
    push!(instrs, "%buz = zext $vhtyp1 %bu to $vtyp2")
    push!(instrs, "%abu = mul $vtyp2 %auz, %buz")
    push!(instrs, "%cab = add $vtyp2 %2, %abu")
    push!(instrs, "%al = shufflevector $vtyp1 %0, $vtyp1 undef, <$W2 x i32> <i32 $sl>")
    push!(instrs, "%bl = shufflevector $vtyp1 %1, $vtyp1 undef, <$W2 x i32> <i32 $sl>")
    push!(instrs, "%alz = zext $vhtyp1 %al to $vtyp2")
    push!(instrs, "%blz = zext $vhtyp1 %bl to $vtyp2")
    push!(instrs, "%abl = mul $vtyp2 %alz, %blz")
    push!(instrs, "%ret = add $vtyp2 %cab, %abl")
    push!(instrs, "ret $vtyp2 %ret")
    quote
        $(Expr(:meta,:inline))
        Base.llvmcall($(join(instrs,"\n")), Vec{$W2,$T2}, Tuple{Vec{$W1,$T1},Vec{$W1,$T1},Vec{$W2,$T2}}, a, b, c)
    end
end

This yields:

julia> ai16 = ntuple(Val(32)) do i Core.VecElement(Int16(i)) end;

julia> bi16 = ntuple(Val(32)) do i Core.VecElement(Int16(i)) end;

julia> ci32 = ntuple(Val(16)) do i Core.VecElement(Int32(i)) end;

julia> @code_native debuginfo=:none muladd(ai16, bi16, ci32)
	.text
	movabsq	$.rodata.cst32, %rax
	vbroadcasti64x4	(%rax), %zmm3   # zmm3 = mem[0,1,2,3,0,1,2,3]
	vpermw	%zmm1, %zmm3, %zmm4
	vpermw	%zmm0, %zmm3, %zmm3
	movabsq	$139727739599872, %rax  # imm = 0x7F14E648BC00
	vbroadcasti64x4	(%rax), %zmm5   # zmm5 = mem[0,1,2,3,0,1,2,3]
	vpermw	%zmm1, %zmm5, %zmm1
	vpermw	%zmm0, %zmm5, %zmm0
	vpmovzxwd	%ymm0, %zmm0    # zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
	vpmovzxwd	%ymm1, %zmm1    # zmm1 = ymm1[0],zero,ymm1[1],zero,ymm1[2],zero,ymm1[3],zero,ymm1[4],zero,ymm1[5],zero,ymm1[6],zero,ymm1[7],zero,ymm1[8],zero,ymm1[9],zero,ymm1[10],zero,ymm1[11],zero,ymm1[12],zero,ymm1[13],zero,ymm1[14],zero,ymm1[15],zero
	vpmulld	%zmm0, %zmm1, %zmm0
	vpaddd	%zmm2, %zmm0, %zmm0
	vpmovzxwd	%ymm3, %zmm1    # zmm1 = ymm3[0],zero,ymm3[1],zero,ymm3[2],zero,ymm3[3],zero,ymm3[4],zero,ymm3[5],zero,ymm3[6],zero,ymm3[7],zero,ymm3[8],zero,ymm3[9],zero,ymm3[10],zero,ymm3[11],zero,ymm3[12],zero,ymm3[13],zero,ymm3[14],zero,ymm3[15],zero
	vpmovzxwd	%ymm4, %zmm2    # zmm2 = ymm4[0],zero,ymm4[1],zero,ymm4[2],zero,ymm4[3],zero,ymm4[4],zero,ymm4[5],zero,ymm4[6],zero,ymm4[7],zero,ymm4[8],zero,ymm4[9],zero,ymm4[10],zero,ymm4[11],zero,ymm4[12],zero,ymm4[13],zero,ymm4[14],zero,ymm4[15],zero
	vpmulld	%zmm1, %zmm2, %zmm1
	vpaddd	%zmm1, %zmm0, %zmm0
	retq
	nopl	(%rax)

Did I do this right? It may be that other shuffles yield better assembly.

This is definitely not the adjacent-pair behavior _mm_madd_epi16 describes, but using the following instead:

           su = join(0:W2-1, ", i32 ")
           sl = join(W2:W1-1, ", i32 ")

yields

julia> @code_native debuginfo=:none muladd(ai16, bi16, ci32)
	.text
	vpmovzxwd	%ymm0, %zmm3    # zmm3 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
	vpmovzxwd	%ymm1, %zmm4    # zmm4 = ymm1[0],zero,ymm1[1],zero,ymm1[2],zero,ymm1[3],zero,ymm1[4],zero,ymm1[5],zero,ymm1[6],zero,ymm1[7],zero,ymm1[8],zero,ymm1[9],zero,ymm1[10],zero,ymm1[11],zero,ymm1[12],zero,ymm1[13],zero,ymm1[14],zero,ymm1[15],zero
	vpmulld	%zmm3, %zmm4, %zmm3
	vpaddd	%zmm2, %zmm3, %zmm2
	vextracti64x4	$1, %zmm0, %ymm0
	vextracti64x4	$1, %zmm1, %ymm1
	vpmovzxwd	%ymm0, %zmm0    # zmm0 = ymm0[0],zero,ymm0[1],zero,ymm0[2],zero,ymm0[3],zero,ymm0[4],zero,ymm0[5],zero,ymm0[6],zero,ymm0[7],zero,ymm0[8],zero,ymm0[9],zero,ymm0[10],zero,ymm0[11],zero,ymm0[12],zero,ymm0[13],zero,ymm0[14],zero,ymm0[15],zero
	vpmovzxwd	%ymm1, %zmm1    # zmm1 = ymm1[0],zero,ymm1[1],zero,ymm1[2],zero,ymm1[3],zero,ymm1[4],zero,ymm1[5],zero,ymm1[6],zero,ymm1[7],zero,ymm1[8],zero,ymm1[9],zero,ymm1[10],zero,ymm1[11],zero,ymm1[12],zero,ymm1[13],zero,ymm1[14],zero,ymm1[15],zero
	vpmulld	%zmm0, %zmm1, %zmm0
	vpaddd	%zmm0, %zmm2, %zmm0
	retq
	nop

chriselrod avatar Jan 26 '20 21:01 chriselrod

Wow, that's a lot of work you've put into it. I really appreciate it! However, I'm failing to use it.

  1. ~~Why is ci32 a vector and not just a single Int32 like res?~~ Alright I understand that the first element of the result is ai16[1] * bi16[1] + ai16[2] * bi16[2] + ci32[1]. That looks good to me.
  2. I tried to use it with
function testing(a, b)
    res = Int32(0)
    @avx for i = 1:length(a)
        res += muladd1(a[i], b[i], res) # I renamed your muladd to muladd1 just to be sure that I use the correct one
    end
    res
end
testing(a, b) # Results in an error: ERROR: MethodError: no method matching muladd1(::SVec{16,Int16}, ::SVec{16,Int16}, ::SVec{16,Int16})

Edit: Alright, I was able to get the correct result by doing:

function test1(ai16, bi16, ci32)
    SIMDPirates.reduce_to_add(muladd(ai16, bi16, ci32), Int32(0))
end
function test2(a, b)
    res = Int32(0)
    @inbounds for i = 1:length(a)
        res += a[i] * b[i]
    end
    res
end
a = repeat([Int16(1<<6)], 32)
b = copy(a)
ai16 = ntuple(Val(32)) do i Core.VecElement(Int16(1<<6)) end
bi16 = ntuple(Val(32)) do i Core.VecElement(Int16(1<<6)) end
ci32 = ntuple(Val(16)) do i Core.VecElement(Int32(0)) end
julia> @btime test1($ai16, $bi16, $ci32)
  8.716 ns (0 allocations: 0 bytes)
131072

julia> @btime test2($a, $b)
  4.879 ns (0 allocations: 0 bytes)
131072

Is that the intended way to do it?

zsoerenm avatar Jan 27 '20 07:01 zsoerenm

That was just supposed to be an implementation of _mmX_madd_epi16:

FOR j := 0 to 15
	i := j*32
	dst[i+31:i] := a[i+31:i+16]*b[i+31:i+16] + a[i+15:i]*b[i+15:i]
ENDFOR
dst[MAX:512] := 0
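
Restated as plain Julia for the 512-bit case (madd_epi16_ref is just a throwaway name for this scalar reference):

# adjacent Int16 pairs are widened, multiplied, and summed into one Int32 lane each
madd_epi16_ref(a::NTuple{32,Int16}, b::NTuple{32,Int16}) =
    ntuple(j -> Int32(a[2j-1]) * Int32(b[2j-1]) + Int32(a[2j]) * Int32(b[2j]), Val(16))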

I showed the @code_native for X=512.

I haven't changed the way LoopVectorization handles reductions, so there is no way to use it.

It's also worth looking at what LLVM does for your test2, on an AVX2 computer:

julia> @code_native debuginfo=:none test2(a,b)
        .text
        movq    8(%rdi), %rax
        testq   %rax, %rax
        jle     L45
        movq    %rax, %rcx
        sarq    $63, %rcx
        andnq   %rax, %rcx, %r8
        movq    (%rdi), %rdx
        movq    (%rsi), %rsi
        cmpq    $32, %r8
        jae     L48
        xorl    %eax, %eax
        movl    $1, %edi
        jmp     L236
L45:
        xorl    %eax, %eax
        retq
L48:
        movabsq $9223372036854775776, %rcx # imm = 0x7FFFFFFFFFFFFFE0
        andq    %r8, %rcx
        leaq    1(%rcx), %rdi
        vpxor   %xmm0, %xmm0, %xmm0
        xorl    %eax, %eax
        vpxor   %xmm1, %xmm1, %xmm1
        vpxor   %xmm2, %xmm2, %xmm2
        vpxor   %xmm3, %xmm3, %xmm3
        nopw    %cs:(%rax,%rax)
        nopl    (%rax)
L96:
        vmovdqu (%rsi,%rax,2), %xmm4
        vmovdqu 16(%rsi,%rax,2), %xmm5
        vmovdqu 32(%rsi,%rax,2), %xmm6
        vmovdqu 48(%rsi,%rax,2), %xmm7
        vpmullw (%rdx,%rax,2), %xmm4, %xmm4
        vpmullw 16(%rdx,%rax,2), %xmm5, %xmm5
        vpmullw 32(%rdx,%rax,2), %xmm6, %xmm6
        vpmullw 48(%rdx,%rax,2), %xmm7, %xmm7
        vpmovsxwd       %xmm4, %ymm4
        vpaddd  %ymm4, %ymm0, %ymm0
        vpmovsxwd       %xmm5, %ymm4
        vpaddd  %ymm4, %ymm1, %ymm1
        vpmovsxwd       %xmm6, %ymm4
        vpaddd  %ymm4, %ymm2, %ymm2
        vpmovsxwd       %xmm7, %ymm4
        vpaddd  %ymm4, %ymm3, %ymm3
        addq    $32, %rax
        cmpq    %rax, %rcx
        jne     L96
        vpaddd  %ymm0, %ymm1, %ymm0
        vpaddd  %ymm0, %ymm2, %ymm0
        vpaddd  %ymm0, %ymm3, %ymm0
        vextracti128    $1, %ymm0, %xmm1
        vpaddd  %xmm1, %xmm0, %xmm0
        vpshufd $78, %xmm0, %xmm1       # xmm1 = xmm0[2,3,0,1]
        vpaddd  %xmm1, %xmm0, %xmm0
        vpshufd $229, %xmm0, %xmm1      # xmm1 = xmm0[1,1,2,3]
        vpaddd  %xmm1, %xmm0, %xmm0
        vmovd   %xmm0, %eax
        cmpq    %rcx, %r8
        je      L262
L236:
        decq    %rdi
        nop
L240:
        movzwl  (%rsi,%rdi,2), %ecx
        imulw   (%rdx,%rdi,2), %cx
        movswl  %cx, %ecx
        addl    %ecx, %eax
        incq    %rdi
        cmpq    %rdi, %r8
        jne     L240
L262:
        vzeroupper
        retq
        nopw    (%rax,%rax)

The loop body is

L96:
        vmovdqu (%rsi,%rax,2), %xmm4
        vmovdqu 16(%rsi,%rax,2), %xmm5
        vmovdqu 32(%rsi,%rax,2), %xmm6
        vmovdqu 48(%rsi,%rax,2), %xmm7
        vpmullw (%rdx,%rax,2), %xmm4, %xmm4
        vpmullw 16(%rdx,%rax,2), %xmm5, %xmm5
        vpmullw 32(%rdx,%rax,2), %xmm6, %xmm6
        vpmullw 48(%rdx,%rax,2), %xmm7, %xmm7
        vpmovsxwd       %xmm4, %ymm4
        vpaddd  %ymm4, %ymm0, %ymm0
        vpmovsxwd       %xmm5, %ymm4
        vpaddd  %ymm4, %ymm1, %ymm1
        vpmovsxwd       %xmm6, %ymm4
        vpaddd  %ymm4, %ymm2, %ymm2
        vpmovsxwd       %xmm7, %ymm4
        vpaddd  %ymm4, %ymm3, %ymm3
        addq    $32, %rax
        cmpq    %rax, %rcx
        jne     L96

This CPU has 256-bit (ymm) registers. It can also use just the lower half of each register, for 128 bits; when it does so, they're called xmm. First, LLVM loads 4x 128 bits from one vector. That is, it loads 4 vectors that each contain 8 16-bit Ints. Those are the 4 lines saying:

vmovdqu (%rsi,%rax,2), %xmm4

Then, it multiplies these 4 vectors with vectors from elsewhere in memory.

vpmullw (%rdx,%rax,2), %xmm4, %xmm4

producing 4 vectors of 8 16-bit Ints that contain the products. Because these are still 16-bit, overflow could happen here (it was a mistake in my code above to disallow that). Next, it converts these vectors of 8 16-bit Ints (128 bits total) into vectors of 8 32-bit Ints (256 bits total), adding the results to four different accumulation vectors (ymm0, ymm1, ymm2, ymm3) that each hold 8x 32-bit Ints:

        vpmovsxwd       %xmm4, %ymm4
        vpaddd  %ymm4, %ymm0, %ymm0

Eventually, it adds these 4 vectors together into just 1 vector, and sums the elements of that vector.

What @avx does currently is similar, except that it would use vectors of 16 16-bit Ints and add them to accumulation vectors of 16 16-bit Ints. Only once it is done and has summed all those accumulation vectors into a single value does it add that result to the Int32. Before I change LoopVectorization to do things differently, I want to figure out what it should actually do.

What LLVM does:

  1. Uses full-width accumulation vectors.
  2. Loads half-width vectors from a and b.
  3. Multiplies them, producing half-width products. These could overflow in situations where Int32 would not.
  4. Promotes them to full-width vectors by converting from Int16 to Int32.
  5. Adds them to the accumulation vectors.

A more _mm_madd_epi-like approach would be:

  1. Use full-width accumulation vectors.
  2. Load full-width vectors from a and b.
  3. Split these vectors into 4 half-width vectors.
  4. Multiply them.
  5. Promote those half-width vectors into full-width by converting from Int16 to Int32.
  6. Add them to the accumulation vectors. There's also the choice here about whether you want to double the number of accumulation vectors, or add to the same vector twice.

I think something closer to this is what I want:

using SIMDPirates
@generated function pmaddwd(a::Vec{W,T}, b::Vec{W,T}) where {W,T<:Integer}
    size_T = sizeof(T)
    @assert ispow2(W)
    Wh = W >>> 1
    typ1 = "i$(8size_T)"
    vtyp1 = "<$W x $typ1>"
    typ2 = "i$(16size_T)"
    T2 = if T == Int8
        Int16
    elseif T == Int16
        Int32
    elseif T == Int32
        Int64
    elseif T == Int64
        Int128
    elseif T == UInt8
        UInt16
    elseif T == UInt16
        UInt32
    elseif T == UInt32
        UInt64
    elseif T == UInt64
        UInt128
    else
        throw("Integer of type $T not supported.")
    end
    vtyp2 = "<$Wh x $typ2>"
    vtyp3 = "<$Wh x $typ1>"
    instrs = String[]
    su = join(1:2:W-1, ", i32 ")
    sl = join(0:2:W-1, ", i32 ")
    push!(instrs, "%au = shufflevector $vtyp1 %0, $vtyp1 undef, <$Wh x i32> <i32 $su>")
    push!(instrs, "%al = shufflevector $vtyp1 %0, $vtyp1 undef, <$Wh x i32> <i32 $sl>")
    push!(instrs, "%bu = shufflevector $vtyp1 %1, $vtyp1 undef, <$Wh x i32> <i32 $su>")
    push!(instrs, "%bl = shufflevector $vtyp1 %1, $vtyp1 undef, <$Wh x i32> <i32 $sl>")
    push!(instrs, "%auz = zext $vtyp3 %au to $vtyp2")
    push!(instrs, "%alz = zext $vtyp3 %al to $vtyp2")
    push!(instrs, "%buz = zext $vtyp3 %bu to $vtyp2")
    push!(instrs, "%blz = zext $vtyp3 %bl to $vtyp2")
    push!(instrs, "%abu = mul $vtyp2 %auz, %buz")
    push!(instrs, "%abl = mul $vtyp2 %alz, %blz")
    push!(instrs, "%ret = add $vtyp2 %abu, %abl")
    push!(instrs, "ret $vtyp2 %ret")
    quote
        $(Expr(:meta,:inline))
        Base.llvmcall($(join(instrs,"\n")), Vec{$Wh,$T2}, Tuple{Vec{$W,$T},Vec{$W,$T}}, a, b)
    end
end
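
As a quick check on plain VecElement tuples (a16 and b16 are just example inputs; note this sketch zext-widens, so it is only correct for non-negative values):

a16 = ntuple(Val(16)) do i Core.VecElement(Int16(i)) end
b16 = ntuple(Val(16)) do i Core.VecElement(Int16(2i)) end
pmaddwd(a16, b16) # 8 Int32 lanes, each a[2j-1]*b[2j-1] + a[2j]*b[2j]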

Whether these approaches will be faster than just using Int32 a and b to begin with depends on whether or not your operations are memory-bottlenecked.

But it doesn't compile to the correct code. It produces a monstrous pile of asm, when it should be just the pmaddwd instruction.

The _mmX_madd_epi16 series produces only a single instruction, so I'd have to get LLVM to do that as well. I think the easiest way might be inline assembly.

chriselrod avatar Jan 27 '20 14:01 chriselrod

Yeah, I would check out AsmMacro.jl for this

MasonProtter avatar Jan 29 '20 00:01 MasonProtter

julia> using VectorizationBase: REGISTER_SIZE

julia> @generated function vpmaddwd(a::NTuple{W,Core.VecElement{Int16}}, b::NTuple{W,Core.VecElement{Int16}}) where {W}
          Wh = W >>> 1
          @assert 2Wh == W
          @assert (REGISTER_SIZE >> 1) ≥ W
          S = W * 16
          # decl = "@llvm.x86.avx512.pmaddw.d.512"
          instr = "@llvm.x86.avx512.pmaddw.d.$S"
          decl = "declare <$Wh x i32> $instr(<32 x i16>, <32 x i16>)" # note: argument width hard-coded to the 512-bit case shown below
          instrs = String[
              "%res = call <$Wh x i32> $instr(<$W x i16> %0, <$W x i16> %1)",
              "ret <$Wh x i32> %res"
          ]
          quote
              $(Expr(:meta,:inline))
              Base.llvmcall(
                  $((decl,join(instrs,"\n"))),
                  NTuple{$Wh,Core.VecElement{Int32}},
                  Tuple{NTuple{$W,Core.VecElement{Int16}},NTuple{$W,Core.VecElement{Int16}}},
                  a, b
              )
          end
      end
vpmaddwd (generic function with 1 method)

julia> a = ntuple(Val(REGISTER_SIZE >>> 1)) do i Core.VecElement(Int16(1<<10 + i)) end;

julia> b = ntuple(Val(REGISTER_SIZE >>> 1)) do i Core.VecElement(Int16(1<<11 + i)) end;

julia> vpmaddwd(a, b)
(VecElement{Int32}(4203525), VecElement{Int32}(4215833), VecElement{Int32}(4228157), VecElement{Int32}(4240497), VecElement{Int32}(4252853), VecElement{Int32}(4265225), VecElement{Int32}(4277613), VecElement{Int32}(4290017), VecElement{Int32}(4302437), VecElement{Int32}(4314873), VecElement{Int32}(4327325), VecElement{Int32}(4339793), VecElement{Int32}(4352277), VecElement{Int32}(4364777), VecElement{Int32}(4377293), VecElement{Int32}(4389825))

julia> @code_native vpmaddwd(a, b)
	.text
; ┌ @ REPL[8]:2 within `vpmaddwd'
; │┌ @ REPL[8]:15 within `macro expansion'
	vpmaddwd	%zmm1, %zmm0, %zmm0
	retq
	nopw	(%rax,%rax)
; └└

Probably have to swap out "avx512" with something else for the function definition.

I'm not sure how I could use AsmMacro for vector inputs and outputs. All the examples seem to have integer or pointer inputs, and no returns.

chriselrod avatar Jan 29 '20 02:01 chriselrod

I love the commitment you have put into this! A drawback of using your proposed ASM, however, is that it is not platform agnostic. I have found that at least clang is able to lower to vpmaddwd given the correct arch: https://bugs.llvm.org/show_bug.cgi?id=32710 Might this be of help?

An intermediate solution to this problem might be to not care about the multiplication overflow and just accumulate into a 32-bit integer. In this case, the user has to make sure himself that the multiplication does not overflow.
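
Spelled out in plain Julia, that would be something like the following (essentially the same semantics as the scalar loop above; fma_widen_acc is just an illustrative name):

function fma_widen_acc(a, b)
    res = Int32(0)
    @inbounds for i = 1:length(a)
        res += Int32(a[i] * b[i]) # product still computed in Int16 (may wrap), accumulation in Int32
    end
    res
end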

zsoerenm avatar Jan 29 '20 14:01 zsoerenm

A difficulty is that it isn't obvious what we should be doing with the accumulator.

Normally, if a loop is U-fold unrolled and vectorized with width W, it will create U accumulation vectors of width W. Each vector fits in a register and will remain there while summing over the loop.

The beauty of vpmaddwd is that it lets us follow the hardware side of this pattern: You have U vectors that each occupy the entirety of a register.

And it works because while each element requires twice the bytes, there are half as many of them.

Otherwise, we have to either

  1. do what LLVM does (occupy only half the register with the elements you're accumulating)
  2. double the number of accumulation registers
  3. accumulate into the accumulation registers twice.

I think "2." is preferable to "3.", since adding to the same accumulator twice lengthens the loop-carried dependency chain.

What would be ideal is to have a cleverer way of creating the accumulation vectors based on types, so that we can use pmaddwd when possible, and otherwise default to one of the above strategies (probably 2, as it would be easiest to implement; the lazy way is to just specify a Vec that is too large, and LLVM will split it in two for you).

chriselrod avatar Jan 29 '20 15:01 chriselrod