Distributions.jl

Inconsistent performance between logpdf and logpdf! for MvNormal.

Open Sahel13 opened this issue 1 year ago • 17 comments


Consider the following minimal example.

using LinearAlgebra, Random, Distributions, BenchmarkTools, Test

samples = randn(5, 1000)
dist = MvNormal(zeros(5), I)

function mod_logpdf(dist, samples)
    # Preallocate one output slot per column (sample), then fill it in place.
    out = Array{Float64}(undef, size(samples, 2))
    logpdf!(out, dist, samples)
    return out
end

julia> @test logpdf(dist, samples) ≈ mod_logpdf(dist, samples)
Test Passed
julia> @btime logpdf($dist, $samples);
  89.290 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
  12.672 μs (3 allocations: 47.05 KiB)

logpdf is around 7x slower than mod_logpdf, even though they both do exactly the same thing.

Possible solution

Add a method

logpdf(d::AbstractMvNormal, x::AbstractMatrix{<:Real})

that does something like mod_logpdf.

Sahel13 avatar Sep 11 '23 17:09 Sahel13

That's extremely weird. Do you know what's causing the performance difference?

ParadaCarleton avatar Sep 14 '23 01:09 ParadaCarleton

It's interesting to note that this difference does not exist if samples is a Vector{Vector} instead of a matrix.

using LinearAlgebra, Random, Distributions, BenchmarkTools, Test

samples = [randn(5) for _ in 1:1000]
dist = MvNormal(zeros(5), I)

function mod_logpdf(dist, samples)
    # One output slot per sample vector.
    out = Array{Float64}(undef, length(samples))
    logpdf!(out, dist, samples)
    return out
end

julia> @btime logpdf($dist, $samples);
  89.132 μs (1001 allocations: 101.69 KiB)
julia> @btime mod_logpdf($dist, $samples);
  88.147 μs (1001 allocations: 101.69 KiB)

So solely for the case where x is a matrix, logpdf! is unusually fast.

logpdf!(..., x::AbstractMatrix{<:Real}) calls a method in mvnormal.jl

function _logpdf!(r::AbstractArray{<:Real}, d::AbstractMvNormal, x::AbstractMatrix{<:Real})
    sqmahal!(r, d, x)
    c0 = mvnormal_c0(d)
    for i = 1:size(x, 2)
        @inbounds r[i] = c0 - r[i]/2
    end
    r
end

I'm guessing this function is the cause of this discrepancy, although why this is faster than the other methods, I do not know.
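For reference, the loop in _logpdf! matches the standard multivariate normal log-density, assuming sqmahal! fills r with the squared Mahalanobis distance of each column and mvnormal_c0 returns the usual normalization constant (the profile further down shows it calling logdetcov, which is consistent with that). With d the dimension, μ the mean, and Σ the covariance:

```latex
\log p(x) = c_0 - \tfrac{1}{2}\,(x - \mu)^\top \Sigma^{-1} (x - \mu),
\qquad
c_0 = -\tfrac{1}{2}\left( d \log 2\pi + \log\det\Sigma \right)
```

So the only per-sample work is the quadratic form; c_0 depends on the distribution alone.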

Sahel13 avatar Sep 14 '23 08:09 Sahel13

Could you run a profiler on this? You should then see where it is spending the additional time.

simsurace avatar Sep 15 '23 10:09 simsurace

For logpdf

julia> Profile.clear()

julia> @profile (for _ in 1:1000; logpdf(dist, samples); end)

julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
  ╎86 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
  ╎ 86 @VSCodeServer/src/eval.jl:34; macro expansion
  ╎  86 @Base/essentials.jl:816; invokelatest(::Any)
  ╎   86 @Base/essentials.jl:819; #invokelatest#2
  ╎    86 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎     86 @Base/logging.jl:626; with_logger
  ╎    ╎ 86 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
  ╎    ╎  86 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎    ╎   86 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
  ╎    ╎    86 @Base/Base.jl:68; eval
  ╎    ╎     86 @Base/boot.jl:370; eval
  ╎    ╎    ╎ 86 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
  ╎    ╎    ╎  86 REPL[20]:1; macro expansion
  ╎    ╎    ╎   86 @Distributions/src/common.jl:319; logpdf(d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    86 @Base/abstractarray.jl:3263; map
  ╎    ╎    ╎     86 @Base/array.jl:711; collect_similar
  ╎    ╎    ╎    ╎ 86 @Base/array.jl:812; _collect(c::Distributions.EachVariate{1, Matrix{Float64}, Tuple{Base.OneTo{Int64}...
  ╎    ╎    ╎    ╎  86 @Base/array.jl:818; collect_to_with_first!
  ╎    ╎    ╎    ╎   86 @Base/array.jl:840; collect_to!(dest::Vector{Float64}, itr::Base.Generator{Distributions.EachVariate...
  ╎    ╎    ╎    ╎    1  @Base/generator.jl:44; iterate
  ╎    ╎    ╎    ╎     1  @Base/abstractarray.jl:1220; iterate
  ╎    ╎    ╎    ╎    ╎ 1  @Base/range.jl:891; iterate
 1╎    ╎    ╎    ╎    ╎  1  @Base/promotion.jl:499; ==
  ╎    ╎    ╎    ╎    85 @Base/generator.jl:47; iterate
 1╎    ╎    ╎    ╎     85 @Base/operators.jl:1108; (::Base.Fix1{typeof(logpdf), IsoNormal})(y::SubArray{Float64, 1, Matrix{Float6...
  ╎    ╎    ╎    ╎    ╎ 84 @Distributions/src/common.jl:250; logpdf
  ╎    ╎    ╎    ╎    ╎  84 @Distributions/src/multivariate/mvnormal.jl:143; _logpdf
 1╎    ╎    ╎    ╎    ╎   1  @Base/float.jl:409; -
  ╎    ╎    ╎    ╎    ╎   9  @Distributions/src/multivariate/mvnormal.jl:101; mvnormal_c0
  ╎    ╎    ╎    ╎    ╎    9  @Distributions/src/multivariate/mvnormal.jl:263; logdetcov
  ╎    ╎    ╎    ╎    ╎     9  @PDMats/src/scalmat.jl:65; logdet
  ╎    ╎    ╎    ╎    ╎    ╎ 9  @Base/special/log.jl:267; log
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:0; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:270; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
 1╎    ╎    ╎    ╎    ╎    ╎  1  @Base/special/log.jl:275; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
  ╎    ╎    ╎    ╎    ╎    ╎  6  @Base/special/log.jl:277; _log(x::Float64, base::Val{:ℯ}, func::Symbol)
  ╎    ╎    ╎    ╎    ╎    ╎   2  @Base/special/log.jl:196; log_proc2
  ╎    ╎    ╎    ╎    ╎    ╎    2  @Base/operators.jl:578; *
 2╎    ╎    ╎    ╎    ╎    ╎     2  @Base/float.jl:410; *
  ╎    ╎    ╎    ╎    ╎    ╎   4  @Base/special/log.jl:215; log_proc2
  ╎    ╎    ╎    ╎    ╎    ╎    4  @Base/floatfuncs.jl:426; fma
 4╎    ╎    ╎    ╎    ╎    ╎     4  @Base/floatfuncs.jl:421; fma_llvm
  ╎    ╎    ╎    ╎    ╎   3  @Distributions/src/multivariate/mvnormal.jl:102; mvnormal_c0
 2╎    ╎    ╎    ╎    ╎    2  @Base/float.jl:408; +
  ╎    ╎    ╎    ╎    ╎    1  @Base/promotion.jl:413; /
 1╎    ╎    ╎    ╎    ╎     1  @Base/float.jl:411; /
  ╎    ╎    ╎    ╎    ╎   71 @Distributions/src/multivariate/mvnormal.jl:267; sqmahal(d::IsoNormal, x::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Sli...
 1╎    ╎    ╎    ╎    ╎    1  @Base/Base.jl:37; getproperty
  ╎    ╎    ╎    ╎    ╎    60 @Base/broadcast.jl:873; materialize
  ╎    ╎    ╎    ╎    ╎     58 @Base/broadcast.jl:898; copy
  ╎    ╎    ╎    ╎    ╎    ╎ 6  @Base/broadcast.jl:926; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:970; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/broadcast.jl:953; preprocess
  ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/broadcast.jl:956; preprocess_args
  ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/broadcast.jl:957; preprocess_args
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/broadcast.jl:954; preprocess
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:947; broadcast_unalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/abstractarray.jl:1482; unalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/abstractarray.jl:1517; mightalias
  ╎    ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/abstractarray.jl:1541; dataids
  ╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/abstractarray.jl:1242; pointer
 1╎    ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/pointer.jl:65; unsafe_convert
  ╎    ╎    ╎    ╎    ╎    ╎  5  @Base/broadcast.jl:973; copyto!
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/simdloop.jl:72; macro expansion
 1╎    ╎    ╎    ╎    ╎    ╎    1  @Base/int.jl:83; <
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/simdloop.jl:76; macro expansion
  ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/simdloop.jl:54; simd_index
 1╎    ╎    ╎    ╎    ╎    ╎     1  @Base/int.jl:87; +
  ╎    ╎    ╎    ╎    ╎    ╎   3  @Base/simdloop.jl:77; macro expansion
  ╎    ╎    ╎    ╎    ╎    ╎    3  @Base/broadcast.jl:974; macro expansion
 2╎    ╎    ╎    ╎    ╎    ╎     2  @Base/array.jl:969; setindex!
  ╎    ╎    ╎    ╎    ╎    ╎     1  @Base/broadcast.jl:610; getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 1  @Base/broadcast.jl:656; _broadcast_getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:683; _broadcast_getindex_evalf
 1╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/float.jl:409; -
  ╎    ╎    ╎    ╎    ╎    ╎ 52 @Base/broadcast.jl:211; similar
  ╎    ╎    ╎    ╎    ╎    ╎  52 @Base/broadcast.jl:212; similar
  ╎    ╎    ╎    ╎    ╎    ╎   52 @Base/abstractarray.jl:883; similar
  ╎    ╎    ╎    ╎    ╎    ╎    52 @Base/abstractarray.jl:884; similar
  ╎    ╎    ╎    ╎    ╎    ╎     52 @Base/boot.jl:494; Array
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 52 @Base/boot.jl:486; Array
51╎    ╎    ╎    ╎    ╎    ╎    ╎  52 @Base/boot.jl:477; Array
  ╎    ╎    ╎    ╎    ╎     2  @Base/broadcast.jl:294; instantiate
  ╎    ╎    ╎    ╎    ╎    ╎ 2  @Base/broadcast.jl:512; combine_axes
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/abstractarray.jl:98; axes
 1╎    ╎    ╎    ╎    ╎    ╎   1  @Base/array.jl:149; size
  ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/broadcast.jl:517; broadcast_shape
  ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/broadcast.jl:523; _bcs
 1╎    ╎    ╎    ╎    ╎    ╎    1  @Base/broadcast.jl:529; _bcs1
  ╎    ╎    ╎    ╎    ╎    10 @PDMats/src/scalmat.jl:87; invquad
 3╎    ╎    ╎    ╎    ╎     3  @Base/float.jl:411; /
  ╎    ╎    ╎    ╎    ╎     7  @Base/reducedim.jl:995; sum
  ╎    ╎    ╎    ╎    ╎    ╎ 7  @Base/reducedim.jl:995; #sum#808
  ╎    ╎    ╎    ╎    ╎    ╎  7  @Base/reducedim.jl:999; _sum
  ╎    ╎    ╎    ╎    ╎    ╎   7  @Base/reducedim.jl:999; #_sum#810
  ╎    ╎    ╎    ╎    ╎    ╎    7  @Base/reducedim.jl:357; mapreduce
  ╎    ╎    ╎    ╎    ╎    ╎     7  @Base/reducedim.jl:357; #mapreduce#800
  ╎    ╎    ╎    ╎    ╎    ╎    ╎ 7  @Base/reducedim.jl:365; _mapreduce_dim
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  1  @Base/reduce.jl:433; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
 1╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/essentials.jl:13; getindex
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  3  @Base/reduce.jl:435; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   2  @Base/number.jl:189; abs2
 2╎    ╎    ╎    ╎    ╎    ╎    ╎    2  @Base/float.jl:410; *
  ╎    ╎    ╎    ╎    ╎    ╎    ╎   1  @Base/reduce.jl:27; add_sum
 1╎    ╎    ╎    ╎    ╎    ╎    ╎    1  @Base/float.jl:408; +
  ╎    ╎    ╎    ╎    ╎    ╎    ╎  3  @Base/reduce.jl:436; _mapreduce(f::typeof(abs2), op::typeof(Base.add_sum), #unused#::IndexLin...
 3╎    ╎    ╎    ╎    ╎    ╎    ╎   3  @Base/int.jl:83; <
Total snapshots: 89. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.

For mod_logpdf

julia> Profile.clear()

julia> @profile (for _ in 1:1000; mod_logpdf(dist, samples); end)

julia> Profile.print()
Overhead ╎ [+additional indent] Count File:Line; Function
  ╎32 @Base/task.jl:514; (::VSCodeServer.var"#62#63")()
  ╎ 32 @VSCodeServer/src/eval.jl:34; macro expansion
  ╎  32 @Base/essentials.jl:816; invokelatest(::Any)
  ╎   32 @Base/essentials.jl:819; #invokelatest#2
  ╎    32 @VSCodeServer/src/repl.jl:193; (::VSCodeServer.var"#109#111"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎     32 @Base/logging.jl:626; with_logger
  ╎    ╎ 32 @Base/logging.jl:514; with_logstate(f::Function, logstate::Any)
  ╎    ╎  32 @VSCodeServer/src/repl.jl:192; (::VSCodeServer.var"#110#112"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
  ╎    ╎   32 @VSCodeServer/src/repl.jl:229; repleval(m::Module, code::Expr, #unused#::String)
  ╎    ╎    32 @Base/Base.jl:68; eval
  ╎    ╎     32 @Base/boot.jl:370; eval
  ╎    ╎    ╎ 32 ...ot-9/usr/share/julia/stdlib/v1.9/Profile/src/Profile.jl:27; top-level scope
 1╎    ╎    ╎  32 REPL[23]:1; macro expansion
  ╎    ╎    ╎   2  /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:7; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
  ╎    ╎    ╎    2  @Base/boot.jl:491; Array
 2╎    ╎    ╎     2  @Base/boot.jl:477; Array
  ╎    ╎    ╎   29 /Users/sahel/Code/InsideOutSMC.jl/experiments/testing.jl:8; mod_logpdf(dist::IsoNormal, samples::Matrix{Float64})
  ╎    ╎    ╎    29 @Distributions/src/common.jl:424; logpdf!
  ╎    ╎    ╎     29 @Distributions/src/multivariate/mvnormal.jl:146; _logpdf!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    ╎ 29 @Distributions/src/multivariate/mvnormal.jl:269; sqmahal!(r::Vector{Float64}, d::IsoNormal, x::Matrix{Float64})
  ╎    ╎    ╎    ╎  28 @Base/broadcast.jl:873; materialize
  ╎    ╎    ╎    ╎   28 @Base/broadcast.jl:898; copy
  ╎    ╎    ╎    ╎    20 @Base/broadcast.jl:926; copyto!
  ╎    ╎    ╎    ╎     20 @Base/broadcast.jl:973; copyto!
  ╎    ╎    ╎    ╎    ╎ 18 @Base/simdloop.jl:77; macro expansion
  ╎    ╎    ╎    ╎    ╎  18 @Base/broadcast.jl:974; macro expansion
  ╎    ╎    ╎    ╎    ╎   18 @Base/multidimensional.jl:670; setindex!
17╎    ╎    ╎    ╎    ╎    18 @Base/array.jl:971; setindex!
  ╎    ╎    ╎    ╎    ╎ 2  @Base/simdloop.jl:78; macro expansion
 2╎    ╎    ╎    ╎    ╎  2  @Base/int.jl:87; +
  ╎    ╎    ╎    ╎    8  @Base/broadcast.jl:211; similar
  ╎    ╎    ╎    ╎     8  @Base/broadcast.jl:212; similar
  ╎    ╎    ╎    ╎    ╎ 8  @Base/abstractarray.jl:883; similar
  ╎    ╎    ╎    ╎    ╎  8  @Base/abstractarray.jl:884; similar
  ╎    ╎    ╎    ╎    ╎   8  @Base/boot.jl:494; Array
  ╎    ╎    ╎    ╎    ╎    8  @Base/boot.jl:487; Array
 8╎    ╎    ╎    ╎    ╎     8  @Base/boot.jl:479; Array
  ╎    ╎    ╎    ╎  1  @PDMats/src/scalmat.jl:90; invquad!
  ╎    ╎    ╎    ╎   1  @PDMats/src/utils.jl:103; colwise_sumsqinv!(r::Vector{Float64}, a::Matrix{Float64}, c::Float64)
 1╎    ╎    ╎    ╎    1  @Base/range.jl:891; iterate
Total snapshots: 32. Utilization: 100% across all threads and tasks. Use the `groupby` kwarg to break down by thread and/or task.

There seem to be many more backtraces at materialize in logpdf, but I do not know how to work out why.

Sahel13 avatar Sep 16 '23 07:09 Sahel13

Hmm, do you have a flamegraph?

ParadaCarleton avatar Sep 16 '23 16:09 ParadaCarleton

I'm assuming the generated images are what you asked for (I'm currently working on my first project in Julia, so there's a lot to learn).





Sahel13 avatar Sep 16 '23 19:09 Sahel13

Ahh, I was suggesting you might want to look at the flamegraphs to see which lines specifically are the ones slowing down logpdf, sorry for not being clear about that. :sweat_smile:

Where is it spending most of its time?

ParadaCarleton avatar Sep 16 '23 19:09 ParadaCarleton

Sorry, my bad XD. This is the profile view plot for logpdf:


Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)? I can think of the Cholesky factorization of the covariance matrix having to be computed only once in the latter case, for example.

Sahel13 avatar Sep 18 '23 14:09 Sahel13

> I can think of the Cholesky factorization of the covariance matrix having to be computed only once in the latter case, for example.

The factorization is only computed once upfront when you construct an MvNormal object.

devmotion avatar Sep 21 '23 21:09 devmotion

> Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?

Yep, that would be it. It's creating way more arrays. Could you make a PR to fix this?
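A minimal sketch of the two access patterns, in plain Julia and restricted to an isotropic covariance σ2 * I. The names sqmahal_percolumn and sqmahal_batched are hypothetical and are not Distributions.jl or PDMats.jl functions. The per-column version builds a temporary vector for every sample, which is roughly the pattern the profile shows under sqmahal (materialize, then similar, then Array); the batched version builds one temporary for the whole matrix.

```julia
# Hypothetical helpers for illustration only; not part of Distributions.jl.
# μ is the mean vector, σ2 the (scalar) variance, x holds one sample per column.

# One column at a time: `view(x, :, i) .- μ` allocates a fresh vector per sample.
sqmahal_percolumn(μ, σ2, x) =
    [sum(abs2, view(x, :, i) .- μ) / σ2 for i in axes(x, 2)]

# Whole matrix at once: `x .- μ` allocates a single temporary matrix.
sqmahal_batched(μ, σ2, x) = vec(sum(abs2, x .- μ; dims=1)) ./ σ2
```

Both return the same values; comparing them with @btime on the 5×1000 matrix from the original example would be one way to check whether allocation count alone accounts for the gap.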

ParadaCarleton avatar Sep 21 '23 22:09 ParadaCarleton

Yes, I can. But just a question: is it problematic if logpdf calls the mutating version underneath? I don't know the design principles behind this package, but if we want logpdf to play well with autodiff, for example, we wouldn't want it to perform any in-place operations.

Sahel13 avatar Sep 25 '23 17:09 Sahel13

No, ideally we would not mix both paths, e.g. also for better compatibility with static arrays, even though many methods probably don't currently work (in an optimized way) with static arrays.

devmotion avatar Sep 25 '23 17:09 devmotion

Another reason is that coming up with heuristics for the output type is generally quite challenging and brittle.

devmotion avatar Sep 25 '23 17:09 devmotion

> No, ideally we would not mix both paths, e.g. also for better compatibility with static arrays, even though many methods probably don't currently work (in an optimized way) with static arrays.

Sorry, I don't understand whether you meant that it's not a problem for logpdf to call the mutating version, or that it would be better if it didn't. If you would prefer a completely non-mutating version, then I do not know how to write a faster implementation.

Sahel13 avatar Sep 25 '23 18:09 Sahel13

I meant that logpdf should generally be non-mutating, and in particular it should not make any assumptions about the type of the arrays it is called with, e.g. whether they are mutable or not.

devmotion avatar Sep 25 '23 18:09 devmotion

Ok, thanks for the clarification. Then I'm afraid I don't know a solution to this.

Sahel13 avatar Sep 25 '23 18:09 Sahel13

> Most of the time seems to be taken up by sqmahal. Is it possible that computing the squared Mahalanobis distance for one vector at a time (which is what logpdf is doing) is slower than doing it for a matrix in one go (as done by logpdf!)?

Quick question, is sqmahal the main difference in time spent between logpdf and logpdf!? (You can benchmark both to see which lines make up most of the difference.) If it is, I think it should be possible to correct this by just doing all the calculations at once.
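As a rough illustration of "doing all the calculations at once" without calling a mutating method, here is a standalone sketch for the isotropic case only. iso_logpdf is a hypothetical name, not a Distributions.jl function, and this is not a proposal for how the package should implement it; a general covariance would need the quadratic form from PDMats instead of the plain sum of squares.

```julia
# Hypothetical, standalone sketch; not the Distributions.jl implementation.
# Log-density of N(μ, σ2 * I) for every column of x, written without
# any explicit in-place update of a caller-visible or preallocated array.
function iso_logpdf(μ::AbstractVector, σ2::Real, x::AbstractMatrix)
    d = length(μ)
    # Normalization constant: logdet(σ2 * I) = d * log(σ2).
    c0 = -(d * log(2π) + d * log(σ2)) / 2
    # Squared Mahalanobis distance of each column, computed in one pass.
    sq = vec(sum(abs2, x .- μ; dims=1)) ./ σ2
    return c0 .- sq ./ 2
end
```

For example, iso_logpdf(zeros(5), 1.0, randn(5, 1000)) returns a vector with one log-density per column. Whether a formulation like this behaves well with static arrays and with the various autodiff backends is something that would still need to be checked.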

ParadaCarleton avatar Oct 03 '23 19:10 ParadaCarleton