LoopVectorization.jl
Using @turbo with ForwardDiff.Dual for logsumexp
Using @turbo loops gives large performance gains (~10x) over the LogExpFunctions library for arrays of Float64s. However, @turbo doesn't seem to play well with arrays of ForwardDiff.Dual numbers and prints the warning below. Is there a way to leverage LoopVectorization to accelerate operations on Dual numbers?
`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.
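For context, a minimal sketch of why the fallback triggers (array sizes are arbitrary; `LoopVectorization.check_args` is the same check the warning refers to):

using LoopVectorization, ForwardDiff

x = ForwardDiff.Dual.(rand(8), 1.0)    # Vector of Dual{Nothing,Float64,1}
LoopVectorization.check_args(x)        # false: Dual is not a native SIMD element type,
                                       # so @turbo falls back to @inbounds @fastmath
LoopVectorization.check_args(rand(8))  # true for a plain Float64 array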
I'm uploading a Pluto notebook with some benchmarks, which I reproduce below.
Not sure if this is related to #93. @chriselrod , I think that this is related to your posts at https://discourse.julialang.org/t/speeding-up-my-logsumexp-function/42380/9?page=2 and https://discourse.julialang.org/t/fast-logsumexp-over-4th-dimension/64182/26
Thanks!
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "Float64" => 6-element BenchmarkTools.BenchmarkGroup:
      tags: ["Float64"]
      "Vanilla Loop" => Trial(29.500 μs)
      "Tullio" => Trial(5.200 μs)
      "LogExpFunctions" => Trial(35.700 μs)
      "Turbo" => Trial(3.000 μs)
      "SIMD Loop" => Trial(25.500 μs)
      "Vmap" => Trial(3.800 μs)
  "Dual" => 6-element BenchmarkTools.BenchmarkGroup:
      tags: ["Dual"]
      "Vanilla Loop" => Trial(45.300 μs)
      "Tullio" => Trial(53.100 μs)
      "LogExpFunctions" => Trial(62.800 μs)
      "Turbo" => Trial(311.900 μs)
      "SIMD Loop" => Trial(37.600 μs)
      "Vmap" => Trial(44.300 μs)
The LoopVectorization functions are:
"""
using `LoopVectorization.@turbo` loops
**NOTE** - not compatible with `ForwardDiff.Dual` numbers!
"""
function logsumexp_turbo!(Vbar, tmp_max, X)
n,k = size(X)
maximum!(tmp_max, X)
fill!(Vbar, 0)
@turbo for i in 1:n, j in 1:k
Vbar[i] += exp(X[i,j] - tmp_max[i])
end
@turbo for i in 1:n
Vbar[i] = log(Vbar[i]) + tmp_max[i]
end
return Vbar
end
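For reference, a usage sketch of this dims=2 reduction (the array sizes are illustrative; `LogExpFunctions.logsumexp` is used only as a correctness check):

using LoopVectorization, LogExpFunctions

n, k    = 1_000, 50
X       = randn(n, k)
Vbar    = similar(X, n)
tmp_max = similar(X, n)

logsumexp_turbo!(Vbar, tmp_max, X)
Vbar ≈ vec(logsumexp(X; dims=2))   # should hold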
"""
using `LoopVectorization` `vmap` convenience fcts
**NOTE** - this DOES work with `ForwardDiff.Dual` numbers!
"""
function logsumexp_vmap!(Vbar, tmp_max, X, Xtmp)
maximum!(tmp_max, X)
n = size(X,2)
for j in 1:n
Xtmpj = view(Xtmp, :, j)
Xj = view(X, :, j)
vmap!((xij, mi) -> exp(xij-mi), Xtmpj, Xj, tmp_max)
end
Vbartmp = vreduce(+, Xtmp; dims=2)
vmap!((vi,mi) -> log(vi) + mi, Vbar, Vbartmp, tmp_max)
return Vbar
end
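And a usage sketch of the `vmap` version with Dual inputs (sizes and the seeding of the partials are purely illustrative):

using LoopVectorization, ForwardDiff
const FD = ForwardDiff

m, n    = 1_000, 50
X       = FD.Dual.(randn(m, n), 1.0)   # Dual{Nothing,Float64,1} entries
Xtmp    = similar(X)
Vbar    = similar(X, m)
tmp_max = similar(X, m)

logsumexp_vmap!(Vbar, tmp_max, X, Xtmp)
FD.value.(Vbar)        # logsumexp over dims=2
FD.partials.(Vbar, 1)  # each ≈ 1.0 here, since the row softmax weights sum to one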
See notebook logsumexp-speedtests.pdf
HTML rendering of notebook (strip off .txt) logsumexp-speedtests.jl.html.txt
Pluto notebook (strip off .txt) logsumexp-speedtests.jl.txt
I was able to get somewhat better performance for Dual numbers by pirating `vexp` and `log_fast`, though the relative speedup (~2x) is still less than what @turbo gives for Float64 arrays.
using ForwardDiff
const FD = ForwardDiff

import VectorizationBase: vexp
import SLEEFPirates: log_fast

@inline function vexp(d::FD.Dual{T}) where {T}
    val = vexp(FD.value(d))
    partials = FD.partials(d)
    return FD.Dual{T}(val, val * partials)
end

@inline function log_fast(d::FD.Dual{T}) where {T}
    val = FD.value(d)
    partials = FD.partials(d)
    return FD.Dual{T}(log_fast(val), inv(val) * partials)
end
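A quick sanity check of the pirated methods (the test value is arbitrary):

d = FD.Dual(0.3, 1.0)
vexp(d)       # value ≈ exp(0.3), partial ≈ exp(0.3)
log_fast(d)   # value ≈ log(0.3), partial ≈ 1/0.3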
"using base SIMD loops with LoopVectorization tricks"
function logsumexp_tricks!(Vbar, tmp_max, X)
    m, n = size(X)
    maximum!(tmp_max, X)
    fill!(Vbar, 0)
    @inbounds for j in 1:n
        @simd for i in 1:m
            Vbar[i] += vexp(X[i,j] - tmp_max[i])
        end
    end
    @inbounds @simd for i in 1:m
        Vbar[i] = log_fast(Vbar[i]) + tmp_max[i]
    end
    return Vbar
end
This also worked, though it wasn't quite as fast:
"using base SIMD loops with LoopVectorization tricks"
function logsumexp_turbo2!(Vbar, tmp_max, X)
m,n = size(X)
maximum!(tmp_max, X)
fill!(Vbar, 0)
@turbo safe=false warn_check_args=false for i in 1:m, j in 1:n
Vbar[i] += vexp(X[i,j] - tmp_max[i])
end
@turbo safe=false warn_check_args=false for i in 1:m
Vbar[i] = log_fast(Vbar[i]) + tmp_max[i]
end
return Vbar
end
I may respond with more later, but you can get some ideas for more tricks here: https://github.com/PumasAI/SimpleChains.jl/blob/main/src/forwarddiff_matmul.jl
Long term, the rewrite should "just work" for duals/generic Julia code. However, it is currently a long way off; I'm still working on the rewrite's dependence analysis (which LoopVectorization.jl of course doesn't do at all).
Thank you for the quick response, @chriselrod -- really appreciate it. I had a bit of a hard time understanding the code you referenced, but it looked to me like the strategy is to reinterpret(reshape, T, A) the arrays as Float64 arrays and do the derivative computations separately. Is that the strategy you're suggesting? I tried it and managed to get a big speedup over LogExpFunctions for the case of logsumexp(X::AbstractVector{<:ForwardDiff.Dual}).
See FastLogSumExp.jl. Benchmarks for these and a few more variants are at https://github.com/magerton/FastLogSumExp.jl; benchmarking is done in runtests.jl. Using @turbo and reinterpreting arrays gives a ~5-6x speedup on my current machine, and was giving ~10x speedups on my other one.
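To illustrate the layout the reinterpret strategy relies on (tiny sizes chosen for readability): for a vector of Dual{T,V,K}, reinterpret(reshape, V, X) exposes a (K+1) × n matrix whose first row holds the values and whose remaining K rows hold the partials, which is what the code below indexes as Xre[1,i] and Xre[k+1,j].

using ForwardDiff
const FD = ForwardDiff

X   = [FD.Dual(1.0, 10.0, 100.0), FD.Dual(2.0, 20.0, 200.0)]  # Dual{Nothing,Float64,2}
Xre = reinterpret(reshape, Float64, X)                         # 3×2 matrix of Float64
Xre[1, :]    # values:            [1.0, 2.0]
Xre[2:3, 1]  # partials of X[1]:  [10.0, 100.0]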
Vector case
"fastest logsumexp over Dual vector requires tmp vector"
function vec_logsumexp_dual_reinterp!(tmp::AbstractVector{V}, X::AbstractVector{<:FD.Dual{T,V,K}}) where {T,V,K}
Xre = reinterpret(reshape, V, X)
uv = typemin(V)
@turbo for i in eachindex(X)
uv = max(uv, Xre[1,i])
end
s = zero(V)
@turbo for j in eachindex(X,tmp)
ex = exp(Xre[1,j] - uv)
tmp[j] = ex
s += ex
end
v = log(s) + uv # logsumexp value
invs = inv(s) # for doing softmax for derivatives
# would be nice to use a more elegant consruction for
# pvec instead of multiple conversions below
# that said, it seems like we still have zero allocations
pvec = zeros(MVector{K,V})
@turbo for j in eachindex(X,tmp)
tmp[j] *= invs
for k in 1:K
pvec[k] += tmp[j]*Xre[k+1,j]
end
end
ptup = NTuple{K,V}(pvec)
ptl = FD.Partials{K,V}(ptup)
return FD.Dual{T,V,K}(v, ptl)
end
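A usage sketch for the vector case (the length and the dual seeding are illustrative; LogExpFunctions.logsumexp is used only as a correctness check):

using ForwardDiff, LogExpFunctions
const FD = ForwardDiff

x   = randn(500)
X   = FD.Dual.(x, 1.0)    # K = 1, each partial seeded with 1.0
tmp = similar(x)

d = vec_logsumexp_dual_reinterp!(tmp, X)
FD.value(d) ≈ logsumexp(x)   # should hold
FD.partials(d, 1) ≈ 1.0      # chain rule: the softmax weights sum to one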
Matrix case for logsumexp(X; dims=2)
function mat_logsumexp_dual_reinterp!(
    Vbar::AbstractVector{D}, tmp_max::AbstractVector{V},
    tmpX::Matrix{V}, X::AbstractMatrix{D}
) where {T,V,K,D<:FD.Dual{T,V,K}}
    m, n = size(X)
    (m, n) == size(tmpX) || throw(DimensionMismatch())
    (m,) == size(Vbar) == size(tmp_max) || throw(DimensionMismatch())
    Vre = reinterpret(reshape, V, Vbar)
    Xre = reinterpret(reshape, V, X)
    tmp_inv = tmp_max   # reuse
    fill!(Vbar, 0)
    fill!(tmp_max, typemin(V))
    @turbo for i in 1:m, j in 1:n
        tmp_max[i] = max(tmp_max[i], Xre[1,i,j])
    end
    @turbo for i in 1:m, j in 1:n
        ex = exp(Xre[1,i,j] - tmp_max[i])
        tmpX[i,j] = ex
        Vre[1,i] += ex
    end
    @turbo for i in 1:m
        v  = Vre[1,i]
        mx = tmp_max[i]   # renamed from `m` to avoid clobbering the row count
        tmp_inv[i] = inv(v)
        Vre[1,i] = log(v) + mx
    end
    @turbo for i in 1:m, j in 1:n, k in 1:K
        Vre[k+1,i] += tmpX[i,j] * Xre[k+1,i,j] * tmp_inv[i]
    end
    return Vbar
end
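And a usage sketch for the matrix case (again, the sizes and the dual seeding are illustrative):

using ForwardDiff, LogExpFunctions
const FD = ForwardDiff

m, n    = 1_000, 50
Xval    = randn(m, n)
X       = FD.Dual.(Xval, 1.0)            # Dual{Nothing,Float64,1} entries
Vbar    = similar(X, m)                  # Dual output
tmp_max = Vector{Float64}(undef, m)
tmpX    = Matrix{Float64}(undef, m, n)

mat_logsumexp_dual_reinterp!(Vbar, tmp_max, tmpX, X)
FD.value.(Vbar) ≈ vec(logsumexp(Xval; dims=2))   # should hold
all(FD.partials.(Vbar, 1) .≈ 1.0)                # row softmax weights sum to one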