Enzyme.jl

Performance issues with repeated gemv! calls in for loop

Status: Open — opened by gaurav-arya about 1 year ago · 2 comments

Minimal working example (MWE):

using Enzyme, LinearAlgebra

# Pin BLAS to one thread so the primal and gradient timings are comparable.
LinearAlgebra.BLAS.set_num_threads(1)
# `false` selects Enzyme's TableGen-generated BLAS rules (the configuration that
# shows the slowdown); `true` would select the fallback rules instead.
Enzyme.Compiler.bitcode_replacement!(false)

# Reduced reproducer: issue 5000 raw BLAS `dgemv` calls in a loop so that Enzyme's
# reverse-mode BLAS rules are exercised repeatedly.
#
# Argument usage: `K`, `x0`, `v0` are only used via `zero(...)` to allocate work
# buffers, `m` is only used for its length `N`, and `T` is unused.
# NOTE(review): assumes K is N×N and x0, v0 have length N (N = length(m)); nothing
# checks this, and the raw ccall reads/writes out of bounds otherwise.
# Because every buffer starts at zero, each gemv adds zero to `vtmp`, `xtmp` is
# never modified, and the function always returns 0.0 (a trivial primal).
@inline function coupled_springs(K, m, x0, v0, T)
    Ktmp = zero(K)
    xtmp = zero(x0)
    vtmp = zero(v0)
    N = length(m)
    # Raw pointers into the work buffers; the arrays are kept rooted for the
    # duration of each ccall by the `GC.@preserve` below.
    pX = pointer(xtmp)
    pY = pointer(vtmp)
    sX = 1  # incx: unit stride through xtmp
    sY = 1  # incy: unit stride through vtmp
    pA = pointer(Ktmp)
    for j in 1:5000
        # High-level equivalent: mul!(vtmp, Ktmp, xtmp, 0.001, 1.0), i.e.
        # vtmp = 0.001 * Ktmp * xtmp + 1.0 * vtmp.
        # dgemv argument order: trans='N', m=N, n=N, alpha=0.001, A=pA, lda=N,
        # x=pX, incx=sX, beta=1.0, y=pY, incy=sY, followed by the hidden Clong
        # length (1) of the Fortran character argument `trans`. The `64_` suffix
        # is the 64-bit-integer (BlasInt) BLAS entry point.
        GC.@preserve Ktmp xtmp vtmp ccall((:dgemv_64_, LinearAlgebra.BLAS.libblastrampoline), Cvoid,
        (Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64},
        Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt},
        Ref{Float64}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Clong),
        'N', N, N, 0.001,
        pA, N, pX, sX,
        1.0, pY, sY, 1)
    end
    # Scalar output to differentiate: first entry of xtmp (unchecked; needs N >= 1).
    return @inbounds xtmp[1]
end

# Build a random test problem of size N, returned as (K, m, x0, v0, T):
#   K  - N×N matrix of uniform [0, 1) entries with its diagonal set to zero,
#   m  - N values drawn uniformly from [0.5, 1),
#   x0 - N standard-normal values,
#   v0 - N zeros,
#   T  - the constant 1.0.
# Random draws happen in the order K, m, x0, so reseeding the global RNG
# reproduces the same problem.
function make_args(N)
    K = rand(N, N)
    for i in 1:N
        K[i, i] = 0
    end
    m = 0.5 .+ 0.5 .* rand(N)
    x0 = randn(N)
    v0 = fill(0.0, N)
    return K, m, x0, v0, 1.0
end

# Wrap the arguments of `coupled_springs` for Enzyme reverse mode.
# Each array argument is paired with a fresh zero-filled shadow (`Duplicated`)
# that will receive its adjoint. The scalar `T` is marked `Const` and is not
# differentiated.
@inline function enzyme_inputs(K, m, x0, v0, T)
    withshadow(a) = Duplicated(a, zero(a))
    return withshadow(K), withshadow(m), withshadow(x0), withshadow(v0), Const(T)
end

# Reverse-mode gradient of `coupled_springs` with respect to `K`.
# The arguments are wrapped by `enzyme_inputs`. `autodiff` then accumulates the
# adjoints into the shadow arrays in place, and only the shadow of `K` is
# returned.
function enzyme_gradient(args...)
    duplicated = enzyme_inputs(args...)
    Enzyme.autodiff(Reverse, Const(coupled_springs), duplicated...)
    # Same array object that was allocated as K's shadow, now holding dL/dK.
    return first(duplicated).dval
end

# Driver: build a size-200 problem and call each function once so compilation is
# excluded from the measurement, then time a second call of each. The timings in
# the trailing comments are the ones reported in this issue.
# NOTE: `N` and `args` are non-const globals, so each call below is dynamically
# dispatched (a small fixed overhead). BenchmarkTools' `@btime f($args...)` would
# avoid this overhead.
N = 200
args = make_args(N)
@time coupled_springs(args...) # compilation (first call includes JIT time)
@time enzyme_gradient(args...) # compilation (first call includes JIT time)

println("\nPrimal:")
@time coupled_springs(args...) # around 0.02 seconds
println("Enzyme:")
@time enzyme_gradient(args...) # around 1.1 seconds

Profiling the Enzyme gradient call with Intel VTune shows that most of the time is spent in `memcpy` (see the attached profiler screenshot). Tested on Julia 1.10-beta1 with Enzyme `main`.

— gaurav-arya, Aug 15 '23, 23:08

The slowdown appears to occur only with the TableGen-generated BLAS rules. With the fallback rules (`Enzyme.Compiler.bitcode_replacement!(true)`), the gradient takes about twice as long as the primal. With the TableGen rules (`Enzyme.Compiler.bitcode_replacement!(false)`, as in the example above), it is about 50x slower than the primal.

— gaurav-arya, Aug 16 '23, 01:08

FYI, the MWE above was trivial: the arrays passed to BLAS were always zero, so the primal result was always 0. Below is a slightly more involved example with meaningful gradients. The arrays are not initialized to zero, and the loop body contains one additional operation. The notable difference is that the fallback BLAS rules also become slow in this case.

using Enzyme, LinearAlgebra

# Pin BLAS to one thread so the primal and gradient timings are comparable.
LinearAlgebra.BLAS.set_num_threads(1)
# `true` selects Enzyme's fallback BLAS rules; pass `false` to measure the
# TableGen-generated rules instead (both configurations are timed in this issue).
Enzyme.Compiler.bitcode_replacement!(true)

# Reproducer with non-trivial values.
# Starting from copies of K, x0, v0, it takes 5000 steps of size dt = 1/5000.
# Each step first updates the velocity with a raw BLAS `dgemv`
# (vtmp += dt * Ktmp * xtmp), then advances the position with the new velocity
# (xtmp += dt * vtmp). It returns the first entry of the final position.
#
# Argument usage: `m` is only used for its length `N`, and `T` is unused.
# NOTE(review): dt is hard-coded, so the simulated horizon is always 1.0 whatever
# `T` is; presumably dt was meant to be T/5000. Also assumes K is N×N and x0, v0
# have length N (unchecked; the raw ccall reads/writes out of bounds otherwise).
@inline function coupled_springs(K, m, x0, v0, T)
    Ktmp = copy(K)
    xtmp = copy(x0)
    vtmp = copy(v0)
    N = length(m)
    # Raw pointers into the work buffers; the arrays are kept rooted for the
    # duration of each ccall by the `GC.@preserve` below.
    pX = pointer(xtmp)
    pY = pointer(vtmp)
    sX = 1  # incx: unit stride through xtmp
    sY = 1  # incy: unit stride through vtmp
    pA = pointer(Ktmp)
    dt = 1/5000
    for j in 1:5000
        # High-level equivalent of the ccall below (vtmp = dt*Ktmp*xtmp + vtmp):
        # @inline mul!(vtmp, Ktmp, xtmp, dt, 1.0) 
        # dgemv argument order: trans='N', m=N, n=N, alpha=dt, A=pA, lda=N,
        # x=pX, incx=sX, beta=1.0, y=pY, incy=sY, followed by the hidden Clong
        # length (1) of the Fortran character argument `trans`.
        GC.@preserve Ktmp xtmp vtmp ccall((:dgemv_64_, LinearAlgebra.BLAS.libblastrampoline), Cvoid,
        (Ref{UInt8}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{LinearAlgebra.BLAS.BlasInt}, Ref{Float64},
        Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt},
        Ref{Float64}, Ptr{Float64}, Ref{LinearAlgebra.BLAS.BlasInt}, Clong),
        'N', N, N, dt,
        pA, N, pX, sX,
        1.0, pY, sY, 1)
        # Position update using the velocity just computed.
        xtmp .+= vtmp .* dt 
    end
    # Scalar output to differentiate: first entry of xtmp (unchecked; needs N >= 1).
    return @inbounds xtmp[1]
end

# Build a test problem of size N, returned as (K, m, x0, v0, T):
#   K  - N×N matrix of ones with its diagonal set to zero,
#   m  - N values drawn uniformly from [0.5, 1) (the only random input),
#   x0 - the evenly spaced values 1/N, 2/N, ..., 1,
#   v0 - N zeros,
#   T  - the constant 1.0.
function make_args(N)
    K = [row == col ? 0.0 : 1.0 for row in 1:N, col in 1:N]
    m = 0.5 .+ 0.5 .* rand(N)
    x0 = [k / N for k in 1:N]
    v0 = fill(0.0, N)
    return K, m, x0, v0, 1.0
end

# Wrap the arguments of `coupled_springs` for Enzyme reverse mode.
# Each array argument is paired with a fresh zero-filled shadow (`Duplicated`)
# that will receive its adjoint. The scalar `T` is marked `Const` and is not
# differentiated.
@inline function enzyme_inputs(K, m, x0, v0, T)
    withshadow(a) = Duplicated(a, zero(a))
    return withshadow(K), withshadow(m), withshadow(x0), withshadow(v0), Const(T)
end

# Reverse-mode gradient of `coupled_springs` with respect to `K`.
# The arguments are wrapped by `enzyme_inputs`. `autodiff` then accumulates the
# adjoints into the shadow arrays in place, and only the shadow of `K` is
# returned.
function enzyme_gradient(args...)
    duplicated = enzyme_inputs(args...)
    Enzyme.autodiff(Reverse, Const(coupled_springs), duplicated...)
    # Same array object that was allocated as K's shadow, now holding dL/dK.
    return first(duplicated).dval
end

# Driver: build a size-200 problem and call each function once so compilation is
# excluded from the measurement, then time a second call of each.
# Reported timings for the second calls: primal ~0.05 s, gradient ~1.1 s with the
# TableGen BLAS rules, ~3.5 s with the fallback rules.
# NOTE: `N` and `args` are non-const globals, so each call below is dynamically
# dispatched (a small fixed overhead). BenchmarkTools' `@btime f($args...)` would
# avoid this overhead.
N = 200
args = make_args(N)
@time coupled_springs(args...) # compilation (first call includes JIT time)
@time enzyme_gradient(args...) # compilation (first call includes JIT time)

println("\nPrimal:")
@time coupled_springs(args...)
println("Enzyme:")
@time enzyme_gradient(args...)

Timings: the primal takes about 0.05 s. The gradient takes about 1.1 s with the TableGen BLAS rules and about 3.5 s with the fallback BLAS rules. cc @vchuravy @ZuseZ4

Edit: changed a `zero` to a `copy` in the code example.

— gaurav-arya, Aug 21 '23, 19:08