Enzyme.jl issue: Performance issues with repeated gemv! calls in a for loop
Minified example:
using Enzyme, LinearAlgebra
# Single-threaded BLAS so timings are stable and comparable.
LinearAlgebra.BLAS.set_num_threads(1)
# false => use the tblgen-generated BLAS derivative rules (the slow path
# reported in this issue); true selects the bitcode-replacement fallback.
Enzyme.Compiler.bitcode_replacement!(false)
# Minimal reproducer: repeatedly call BLAS dgemv directly via ccall
# (y <- alpha*A*x + beta*y) inside a hot loop.
# NOTE(review): all buffers are created with `zero`, so the primal result is
# trivially 0 here — this version only exists to expose the AD slowdown.
@inline function coupled_springs(K, m, x0, v0, T)
# Same-shaped zero-filled work buffers.
Ktmp = zero(K)
xtmp = zero(x0)
vtmp = zero(v0)
N = length(m)
# Raw pointers and unit strides handed straight to BLAS below.
pX = pointer(xtmp)
pY = pointer(vtmp)
sX = 1
sY = 1
pA = pointer(Ktmp)
for j in 1:5000
# @inline mul!(vtmp, Ktmp, xtmp, dt, 1)
# dgemv with 'N' (no transpose): vtmp <- 0.001*Ktmp*xtmp + 1.0*vtmp.
# GC.@preserve keeps the arrays alive while BLAS reads their raw pointers;
# trailing Clong 1 is the hidden Fortran character-length argument.
GC.@preserve Ktmp xtmp vtmp ccall((:dgemv_64_, LinearAlgebra.BLAS.libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64},
Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt},
Ref{Float64}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Clong),
'N', N, N, 0.001,
pA, N, pX, sX,
1.0, pY, sY, 1)
end
return @inbounds xtmp[1]
end
function make_args(N)
    # Assemble inputs for `coupled_springs`: a random N×N coupling matrix
    # with a zeroed diagonal, masses drawn from [0.5, 1), random initial
    # positions, zero initial velocities, and a unit time span.
    K = rand(N, N)
    for i in 1:N
        K[i, i] = 0.0
    end
    m = 0.5 .+ 0.5 .* rand(N)
    x0 = randn(N)
    v0 = zeros(N)
    T = 1.0
    return K, m, x0, v0, T
end
@inline function enzyme_inputs(K, m, x0, v0, T)
    # Pair each differentiable array with a freshly zeroed shadow buffer for
    # reverse-mode AD; the scalar horizon T carries no derivative.
    duplicate(x) = Duplicated(x, zero(x))
    return duplicate(K), duplicate(m), duplicate(x0), duplicate(v0), Const(T)
end
function enzyme_gradient(args...)
    # Run reverse-mode AD through `coupled_springs` and return the gradient
    # accumulated in the shadow of the coupling matrix K (first input).
    duped = enzyme_inputs(args...)
    Enzyme.autodiff(Reverse, Const(coupled_springs), duped...)
    return duped[1].dval
end
N = 200
args = make_args(N)
# First calls include compilation; the timed runs below are post-warmup.
@time coupled_springs(args...) # compilation
@time enzyme_gradient(args...) # compilation
println("\nPrimal:")
@time coupled_springs(args...) # around 0.02 seconds
println("Enzyme:")
@time enzyme_gradient(args...) # around 1.1 seconds
A VTune profile of the Enzyme gradient call reveals that most of the time is spent in memcpy:
On Julia 1.10-beta1 with Enzyme main.
It seems like the slowdown only occurs with the tblgen-generated BLAS rules. With the fallback rules (Enzyme.Compiler.bitcode_replacement!(true))
it runs in around twice the time of the primal, while with the tblgen rules (Enzyme.Compiler.bitcode_replacement!(false)),
as above, it runs about 50x slower.
FYI, the MWE above was trivial in that the arrays passed to BLAS were always 0 and the primal result was always 0. Below is a slightly more involved example with meaningful gradients (arrays not initialized to 0, plus one more operation in the loop body). The big change is that the fallback BLAS rules also become slow in this case.
using Enzyme, LinearAlgebra
# Single-threaded BLAS so timings are stable and comparable.
LinearAlgebra.BLAS.set_num_threads(1)
# true => use the bitcode-replacement fallback BLAS rules, which also
# become slow for this non-trivial example.
Enzyme.Compiler.bitcode_replacement!(true)
# Non-trivial reproducer: same dgemv-in-a-loop structure, but buffers are
# copies of the real inputs and positions are updated each step, so the
# primal and the gradients are meaningful.
@inline function coupled_springs(K, m, x0, v0, T)
# Working copies so the caller's arrays are left untouched.
Ktmp = copy(K)
xtmp = copy(x0)
vtmp = copy(v0)
N = length(m)
# Raw pointers and unit strides handed straight to BLAS below.
pX = pointer(xtmp)
pY = pointer(vtmp)
sX = 1
sY = 1
pA = pointer(Ktmp)
# Time step for 5000 explicit-Euler-style iterations.
dt = 1/5000
for j in 1:5000
# @inline mul!(vtmp, Ktmp, xtmp, dt, 1.0)
# dgemv with 'N' (no transpose): vtmp <- dt*Ktmp*xtmp + 1.0*vtmp.
# GC.@preserve keeps the arrays alive while BLAS reads their raw pointers;
# trailing Clong 1 is the hidden Fortran character-length argument.
GC.@preserve Ktmp xtmp vtmp ccall((:dgemv_64_, LinearAlgebra.BLAS.libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64},
Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt},
Ref{Float64}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Clong),
'N', N, N, dt,
pA, N, pX, sX,
1.0, pY, sY, 1)
# Advance positions with the freshly updated velocities (fused, in place).
xtmp .+= vtmp .* dt
end
return @inbounds xtmp[1]
end
function make_args(N)
    # Assemble inputs for `coupled_springs`: an all-ones N×N coupling matrix
    # with a zeroed diagonal, masses drawn from [0.5, 1), evenly spaced
    # initial positions in (0, 1], zero velocities, and a unit time span.
    K = ones(N, N)
    for i in 1:N
        K[i, i] = 0.0
    end
    m = 0.5 .+ 0.5 .* rand(N)
    x0 = collect(1.0:N) ./ N
    v0 = zeros(N)
    T = 1.0
    return K, m, x0, v0, T
end
@inline function enzyme_inputs(K, m, x0, v0, T)
    # Pair each differentiable array with a freshly zeroed shadow buffer for
    # reverse-mode AD; the scalar horizon T carries no derivative.
    duplicate(x) = Duplicated(x, zero(x))
    return duplicate(K), duplicate(m), duplicate(x0), duplicate(v0), Const(T)
end
function enzyme_gradient(args...)
    # Run reverse-mode AD through `coupled_springs` and return the gradient
    # accumulated in the shadow of the coupling matrix K (first input).
    duped = enzyme_inputs(args...)
    Enzyme.autodiff(Reverse, Const(coupled_springs), duped...)
    return duped[1].dval
end
N = 200
args = make_args(N)
# First calls include compilation; the timed runs below are post-warmup.
@time coupled_springs(args...) # compilation
@time enzyme_gradient(args...) # compilation
println("\nPrimal:")
@time coupled_springs(args...)
println("Enzyme:")
@time enzyme_gradient(args...)
Primal around 0.05 seconds, tablegen BLAS rules around 1.1 seconds, fallback BLAS rules around 3.5 seconds. cc @vchuravy @ZuseZ4
Edit: fixed a zero into a copy in the code example