
Using @turbo with ForwardDiff.Dual for logsumexp

magerton opened this issue on Sep 30, 2022 · 6 comments

Using @turbo loops gives incredible performance gains (10x) over the LogExpFunctions library for arrays of Float64s. However, @turbo doesn't seem to play well with ForwardDiff.Dual arrays and prints the warning below. Is there a way to leverage LoopVectorization to accelerate operations on Dual numbers?

`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.
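
For concreteness, here is a minimal sketch (mine, not from the notebook) that triggers the fallback: `LoopVectorization.check_args` rejects arrays whose element type isn't a native machine type, which is exactly what an array of `Dual`s is.

```julia
using LoopVectorization, ForwardDiff

x = [ForwardDiff.Dual(rand(), 1.0) for _ in 1:64]  # eltype is a Dual, not a native type
y = similar(x)
@turbo for i in eachindex(x, y)   # check_args fails => warning is printed and the
    y[i] = exp(x[i])              # loop runs as plain @inbounds @fastmath instead
end
```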

I'm uploading a Pluto notebook with some benchmarks, which I reproduce below.

Not sure if this is related to #93. @chriselrod, I think this is related to your posts at https://discourse.julialang.org/t/speeding-up-my-logsumexp-function/42380/9?page=2 and https://discourse.julialang.org/t/fast-logsumexp-over-4th-dimension/64182/26.

Thanks!

```
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "Float64" => 6-element BenchmarkTools.BenchmarkGroup:
	  tags: ["Float64"]
	  "Vanilla Loop" => Trial(29.500 μs)
	  "Tullio" => Trial(5.200 μs)
	  "LogExpFunctions" => Trial(35.700 μs)
	  "Turbo" => Trial(3.000 μs)
	  "SIMD Loop" => Trial(25.500 μs)
	  "Vmap" => Trial(3.800 μs)
  "Dual" => 6-element BenchmarkTools.BenchmarkGroup:
	  tags: ["Dual"]
	  "Vanilla Loop" => Trial(45.300 μs)
	  "Tullio" => Trial(53.100 μs)
	  "LogExpFunctions" => Trial(62.800 μs)
	  "Turbo" => Trial(311.900 μs)
	  "SIMD Loop" => Trial(37.600 μs)
	  "Vmap" => Trial(44.300 μs)
```

The LoopVectorization functions are:

"""
using `LoopVectorization.@turbo` loops

**NOTE** - not compatible with `ForwardDiff.Dual` numbers!
"""
function logsumexp_turbo!(Vbar, tmp_max, X)
	n,k = size(X)
	maximum!(tmp_max, X)
	fill!(Vbar, 0)
	@turbo for i in 1:n, j in 1:k
		Vbar[i] += exp(X[i,j] - tmp_max[i])
	end
	@turbo for i in 1:n
		Vbar[i] = log(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end

"""
using `LoopVectorization` `vmap` convenience fcts

**NOTE** - this DOES work with `ForwardDiff.Dual` numbers!
"""
function logsumexp_vmap!(Vbar, tmp_max, X, Xtmp)
	maximum!(tmp_max, X)
	n = size(X,2)
	for j in 1:n
		Xtmpj = view(Xtmp, :, j)
		Xj    = view(X, :, j)
		vmap!((xij, mi) -> exp(xij-mi), Xtmpj, Xj, tmp_max)
	end
	Vbartmp = vreduce(+, Xtmp; dims=2)
	vmap!((vi,mi) -> log(vi) + mi, Vbar, Vbartmp, tmp_max)
	return Vbar
end
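
For context, a hypothetical benchmark driver for the two functions above (the problem sizes are my own; the notebook's may differ):

```julia
using BenchmarkTools, LoopVectorization

n, k    = 1_000, 30
X       = rand(n, k)
Vbar    = Vector{Float64}(undef, n)
tmp_max = Vector{Float64}(undef, n)
Xtmp    = similar(X)

@btime logsumexp_turbo!($Vbar, $tmp_max, $X)
@btime logsumexp_vmap!($Vbar, $tmp_max, $X, $Xtmp)
```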

magerton commented on Sep 30, 2022

See notebook logsumexp-speedtests.pdf

magerton commented on Sep 30, 2022

HTML rendering of the notebook (strip off .txt): logsumexp-speedtests.jl.html.txt

Pluto notebook (strip off .txt): logsumexp-speedtests.jl.txt

magerton commented on Sep 30, 2022

I was able to get things a bit faster for Dual numbers by pirating `vexp` and `log_fast`, though the relative speedup (2x) is still smaller than what `@turbo` achieves for Float64 arrays.

```julia
using ForwardDiff
const FD = ForwardDiff
import VectorizationBase: vexp
import SLEEFPirates: log_fast

@inline function vexp(d::FD.Dual{T}) where {T}
    val = vexp(FD.value(d))
    partials = FD.partials(d)
    return FD.Dual{T}(val, val * partials)   # chain rule: d/dx e^x = e^x
end

@inline function log_fast(d::FD.Dual{T}) where {T}
    val = FD.value(d)
    partials = FD.partials(d)
    return FD.Dual{T}(log_fast(val), inv(val) * partials)   # d/dx log(x) = 1/x
end
```
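
A quick sanity check (my addition, not from the original comment) that the pirated methods agree with ForwardDiff's built-in rules:

```julia
d = FD.Dual{:tag}(0.3, 1.0)

@assert FD.value(vexp(d))        ≈ FD.value(exp(d))
@assert FD.partials(vexp(d))     ≈ FD.partials(exp(d))
@assert FD.value(log_fast(d))    ≈ FD.value(log(d))
@assert FD.partials(log_fast(d)) ≈ FD.partials(log(d))
```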

"using base SIMD loops with LoopVectorization tricks"
function logsumexp_tricks!(Vbar, tmp_max, X)
	m,n = size(X)
	maximum!(tmp_max, X)
	fill!(Vbar, 0)
	@inbounds for j in 1:n
		@simd for i in 1:m
			Vbar[i] += vexp(X[i,j] - tmp_max[i])
		end
	end
    	
	@inbounds @simd for i in 1:m
		Vbar[i] = log_fast(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end

magerton commented on Sep 30, 2022

This also worked, though it wasn't quite as fast:

"using base SIMD loops with LoopVectorization tricks"
function logsumexp_turbo2!(Vbar, tmp_max, X)
	m,n = size(X)
	maximum!(tmp_max, X)
	fill!(Vbar, 0)
    @turbo safe=false warn_check_args=false for i in 1:m, j in 1:n
		Vbar[i] += vexp(X[i,j] - tmp_max[i])
	end
    	
	@turbo safe=false warn_check_args=false for i in 1:m
		Vbar[i] = log_fast(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end

magerton commented on Sep 30, 2022

I may respond with more later, but you can get some ideas for more tricks here: https://github.com/PumasAI/SimpleChains.jl/blob/main/src/forwarddiff_matmul.jl

Long term, the rewrite should "just work" for duals/generic Julia code. However, it is currently a long way off; I'm still working on the rewrite's dependence analysis (which LoopVectorization.jl, of course, doesn't do at all).

chriselrod commented on Sep 30, 2022

Thank you for the quick response, @chriselrod -- really appreciate it. I had a bit of a hard time understanding the code you referenced, but it looks like the strategy is to `reinterpret(reshape, T, A)` the arrays as Float64 arrays and then do the derivative computations separately. Is that the strategy you're suggesting? I tried it and managed to get a big speedup over LogExpFunctions for the case of `logsumexp(X::AbstractVector{<:ForwardDiff.Dual})`.
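
To illustrate the layout (a toy example of my own, not from the referenced code): a `Vector{Dual{T,V,K}}` reinterprets to a `(K+1) × n` matrix of `V`, with the values in row 1 and the `K` partials in the remaining rows.

```julia
using ForwardDiff

x  = [ForwardDiff.Dual{:t}(Float64(i), 2.0i, 3.0i) for i in 1:4]  # K = 2 partials
xr = reinterpret(reshape, Float64, x)  # 3×4 matrix view, no copy

xr[1, :]  # values:         [1.0, 2.0, 3.0, 4.0]
xr[2, :]  # first partials: [2.0, 4.0, 6.0, 8.0]
```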

See FastLogSumExp.jl for the implementations; benchmarks for these and a few more are at https://github.com/magerton/FastLogSumExp.jl, with the benchmarking done in runtests.jl. Using `@turbo` with reinterpreted arrays gives a ~5-6x speedup on my current machine, and was giving ~10x speedups on my other one.
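
For reference, the partials accumulation in both functions below follows from the softmax form of the logsumexp gradient:

```math
\frac{\partial}{\partial \theta} \log \sum_j e^{x_j}
= \sum_j \frac{e^{x_j - \bar{x}}}{\sum_i e^{x_i - \bar{x}}}
  \frac{\partial x_j}{\partial \theta},
\qquad \bar{x} = \max_j x_j,
```

so the `Dual` partials can be accumulated as a softmax-weighted sum of the input partials, entirely in Float64 arithmetic.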

Vector case

"fastest logsumexp over Dual vector requires tmp vector"
function vec_logsumexp_dual_reinterp!(tmp::AbstractVector{V}, X::AbstractVector{<:FD.Dual{T,V,K}}) where {T,V,K}
    Xre   = reinterpret(reshape, V, X)

    uv = typemin(V)
    @turbo for i in eachindex(X)
        uv = max(uv, Xre[1,i])
    end

    s = zero(V)

    @turbo for j in eachindex(X,tmp)
        ex = exp(Xre[1,j] - uv)
        tmp[j] = ex
        s += ex
    end

    v = log(s) + uv # logsumexp value

    invs = inv(s) # for doing softmax for derivatives

    # would be nice to use a more elegant consruction for
    # pvec instead of multiple conversions below
    # that said, it seems like we still have zero allocations
    pvec = zeros(MVector{K,V})
    @turbo for j in eachindex(X,tmp)
        tmp[j] *= invs
        for k in 1:K
            pvec[k] += tmp[j]*Xre[k+1,j]
        end
    end

    ptup = NTuple{K,V}(pvec)
    ptl = FD.Partials{K,V}(ptup)

    return FD.Dual{T,V,K}(v, ptl)

end
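
Hypothetical usage of the vector case (sizes and tag are mine):

```julia
# FD = ForwardDiff, as defined earlier in the thread
X   = [FD.Dual{:t}(randn(), randn(), randn()) for _ in 1:512]  # K = 2 partials
tmp = Vector{Float64}(undef, length(X))

vec_logsumexp_dual_reinterp!(tmp, X)  # returns a single Dual{:t,Float64,2}
```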

Matrix case for `logsumexp(X; dims=2)`

```julia
function mat_logsumexp_dual_reinterp!(
    Vbar::AbstractVector{D}, tmp_max::AbstractVector{V},
    tmpX::Matrix{V}, X::AbstractMatrix{D}
    ) where {T,V,K,D<:FD.Dual{T,V,K}}

    m,n = size(X)

    (m,n) == size(tmpX) || throw(DimensionMismatch())
    (m,) == size(Vbar) == size(tmp_max) || throw(DimensionMismatch())

    Vre = reinterpret(reshape, V, Vbar)   # (K+1) × m
    Xre = reinterpret(reshape, V, X)      # (K+1) × m × n

    tmp_inv = tmp_max # reuse

    fill!(Vbar, 0)
    fill!(tmp_max, typemin(V))

    @turbo for i in 1:m, j in 1:n
        tmp_max[i] = max(tmp_max[i], Xre[1,i,j])
    end

    @turbo for i in 1:m, j in 1:n
        ex = exp(Xre[1,i,j] - tmp_max[i])
        tmpX[i,j] = ex
        Vre[1,i] += ex
    end

    @turbo for i in 1:m
        v  = Vre[1,i]
        mx = tmp_max[i]       # renamed from `m` so it doesn't shadow the size above
        tmp_inv[i] = inv(v)   # overwrites tmp_max[i], which is no longer needed
        Vre[1,i] = log(v) + mx
    end

    @turbo for i in 1:m, j in 1:n, k in 1:K
        Vre[k+1,i] += tmpX[i,j]*Xre[k+1,i,j]*tmp_inv[i]   # softmax-weighted partials
    end

    return Vbar
end
```
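
And a hypothetical setup for the matrix case (sizes and tag are my own, not from the repo):

```julia
m, n, K = 200, 50, 2
D = FD.Dual{:t,Float64,K}

X       = [FD.Dual{:t}(randn(), randn(), randn()) for _ in 1:m, _ in 1:n]  # K = 2
Vbar    = Vector{D}(undef, m)
tmp_max = Vector{Float64}(undef, m)
tmpX    = Matrix{Float64}(undef, m, n)

mat_logsumexp_dual_reinterp!(Vbar, tmp_max, tmpX, X)  # row-wise logsumexp, as Duals
```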

magerton commented on Oct 1, 2022