Distributions.jl

MvNormal unnecessarily slow when mean and covariance matrix have different types

Open hersle opened this issue 8 months ago • 3 comments

My model has a MvNormal with a constant covariance matrix (just Float64s) and a differentiable mean vector (e.g. with ForwardDiff.Duals):

```julia
using Distributions, ForwardDiff, PDMats, LinearAlgebra, BenchmarkTools

Σ = PDMat(Matrix(1.0I, 1000, 1000)) # e.g. constant model-independent covariance matrix
μ = ForwardDiff.Dual.(1.0:1000.0) # e.g. differentiable model-dependent data vector

MvNormalSlow(μ, Σ) = MvNormal(μ, Σ) # standard constructor
MvNormalFast(μ, Σ) = MvNormal{Real, typeof(Σ), typeof(μ)}(μ, Σ) # construct manually

@btime MvNormalSlow(μ, Σ); # 1.789 ms (7 allocations: 15.26 MiB)
@btime MvNormalFast(μ, Σ); # 17.818 ns (1 allocation: 48 bytes)
```

The standard MvNormal constructor is unnecessarily slow. It seems to copy and refactorize the covariance matrix, even though it is already factorized by my call to PDMat (which I use precisely because I know my covariance matrix is constant and I don't want to refactorize it for every model evaluation):

```julia
@profview [MvNormalSlow(μ, Σ) for i in 1:100]
```


I think the reason is here. When μ and Σ have different element types (as in this case), it hits the second constructor, which triggers a conversion and refactorization of the covariance matrix.
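The dispatch pattern described here can be sketched in a self-contained way. This is a simplified assumption about the general shape of such constructors, not a copy of the Distributions.jl source: the type `ToyNormal` is hypothetical, and plain matrices stand in for `PDMat`. The method for matching element types stores its arguments as-is, while the fallback promotes both arguments to a common element type and re-dispatches.

```julia
using LinearAlgebra

# Hypothetical stand-in for MvNormal: T is the "parameter type" shared by both fields.
struct ToyNormal{T<:Real,Cov<:AbstractMatrix{T},Mean<:AbstractVector{T}}
    μ::Mean
    Σ::Cov
end

# Matching element types: store μ and Σ without copying.
function ToyNormal(μ::AbstractVector{T}, Σ::AbstractMatrix{T}) where {T<:Real}
    size(Σ, 1) == length(μ) || throw(DimensionMismatch("μ and Σ have inconsistent sizes"))
    return ToyNormal{T,typeof(Σ),typeof(μ)}(μ, Σ)
end

# Different element types: promote both arguments to a common eltype R, then re-dispatch
# to the method above. An argument whose eltype is not already R is converted into a copy.
function ToyNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real})
    R = Base.promote_eltype(μ, Σ)
    return ToyNormal(convert(AbstractArray{R}, μ), convert(AbstractArray{R}, Σ))
end
```

In the issue's setting, `R` would be a `ForwardDiff.Dual` type, so the fallback would have to produce a `Dual`-valued copy of the `Float64` covariance.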

Is this intended? Is there a smooth fix to it?

hersle, Apr 10 '25 20:04

> Is this intended?

It predates my time as a contributor, and it seems problematic in this case, but I'm quite certain that it is intentional. Having parameters of the same element type (as for most univariate distributions) makes it easier to reason about the "parameter type" of a distribution and to avoid accidental type instabilities (e.g., if in a special case such as zero mean, a term or a whole branch of the computation is skipped).

I'm not sure what exactly your use case is, but it sounds a bit like it might be more efficient to shift your data/samples manually and to use a constant MvNormal, with zero mean (of Float64) and constant covariance matrix (of Float64).
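A minimal sketch of that workaround, written with LinearAlgebra only rather than Distributions.jl so that it stays self-contained: the covariance is factorized once in `Float64`, and only the residual `r = x - μ` carries the differentiable element type. The helper `zero_mean_gauss_logpdf` is hypothetical and not a Distributions.jl function.

```julia
using LinearAlgebra

# Hypothetical helper (not a Distributions.jl API): log-density of N(0, Σ) at r,
# given a precomputed Cholesky factorization C = cholesky(Σ).
# r may have any Real eltype; C is reused as-is across calls.
function zero_mean_gauss_logpdf(C::Cholesky, r::AbstractVector{<:Real})
    size(C.factors, 1) == length(r) || throw(DimensionMismatch("size mismatch between C and r"))
    # Triangular solve with the stored factor: z' * z == r' * inv(Σ) * r.
    z = C.U' \ r
    d = length(r)
    return -(d * log(2π) + logdet(C) + dot(z, z)) / 2
end

# Done once, outside the model: factorize the constant covariance.
Σ = [2.0 0.5; 0.5 1.0]
C = cholesky(Σ)

# Per model evaluation: shift the data by the (possibly differentiable) mean.
x = [1.0, -1.0]
μ = [0.5, 0.5]
lp = zero_mean_gauss_logpdf(C, x - μ)
```

With Distributions.jl itself, the analogous pattern would be to construct the zero-mean, `Float64`-typed MvNormal once and evaluate `logpdf` at `x - μ` in each model evaluation.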

devmotion, Apr 10 '25 21:04

> triggering a conversion and refactorization of the covariance matrix.

Note though that (as intended) it doesn't actually refactorize the matrix; it only converts the underlying factor of the Cholesky decomposition. The call chain is:

1. https://github.com/JuliaStats/PDMats.jl/blob/dbeb7c7c28aeebfadfe2f875ca17af373b0fe40f/src/generics.jl#L9
2. https://github.com/JuliaStats/PDMats.jl/blob/dbeb7c7c28aeebfadfe2f875ca17af373b0fe40f/src/generics.jl#L8
3. https://github.com/JuliaStats/PDMats.jl/blob/dbeb7c7c28aeebfadfe2f875ca17af373b0fe40f/src/pdmat.jl#L69
4. https://github.com/JuliaStats/PDMats.jl/blob/dbeb7c7c28aeebfadfe2f875ca17af373b0fe40f/src/pdmat.jl#L66
5. https://github.com/JuliaStats/PDMats.jl/blob/dbeb7c7c28aeebfadfe2f875ca17af373b0fe40f/src/pdmat.jl#L28
6. https://github.com/JuliaStats/PDMats.jl/blob/dbeb7c7c28aeebfadfe2f875ca17af373b0fe40f/src/pdmat.jl#L18
7. https://github.com/JuliaLang/LinearAlgebra.jl/blob/0671a7be73d2a0aa82fb858c4342fe40dab56b8c/src/factorization.jl#L102
8. https://github.com/JuliaLang/LinearAlgebra.jl/blob/0671a7be73d2a0aa82fb858c4342fe40dab56b8c/src/cholesky.jl#L630
9. https://github.com/JuliaLang/LinearAlgebra.jl/blob/0671a7be73d2a0aa82fb858c4342fe40dab56b8c/src/cholesky.jl#L631
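A small illustration of why converting the factor is enough, using only the public `C.U` property of LinearAlgebra's `Cholesky`. This mirrors the idea, not the exact PDMats/LinearAlgebra code path linked above, and uses `Float32` as an arbitrary target element type: if `U' * U ≈ A`, then converting `U` elementwise gives a valid factor of the converted matrix (up to rounding) with O(n²) work, instead of the O(n³) work of a new factorization.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]   # small symmetric positive-definite example
C = cholesky(A)          # factorized once, in Float64
U = C.U                  # upper-triangular factor with U' * U ≈ A

# Convert the existing factor elementwise: O(n^2) work, no new factorization.
U32 = UpperTriangular(Matrix{Float32}(U))

# For comparison only: refactorizing the converted matrix costs O(n^3) work.
C32 = cholesky(Float32.(A))
```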

devmotion, Apr 10 '25 21:04

> makes it easier to reason about the "parameter type" of a distribution and to avoid accidental type instabilities

Thanks, I see what you mean. In this particular example, I was thinking Distributions could use μ and Σ as-is (without copying) and set the type parameter by promotion:

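```julia
function MvNormalFast(μ::AbstractVector{<:Real}, Σ::AbstractPDMat{<:Real})
    R = Base.promote_eltype(μ, Σ) # promoted "parameter type" of μ and Σ
    MvNormal{R, typeof(Σ), typeof(μ)}(μ, Σ) # store μ and Σ as-is, without conversion
end
N = MvNormalFast(μ, Σ) # assign outside @btime, which evaluates its expression in a local scope
@btime MvNormalFast(μ, Σ) # 17.618 ns (1 allocation: 48 bytes)
eltype(N) # ForwardDiff.Dual{Nothing, Float64, 0}
```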

But maybe that would screw up something else?
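The trade-off can be seen in a stripped-down, stdlib-only version of such a constructor. The type `LazyToyNormal` is hypothetical and stands in for `MvNormal`, with a plain matrix in place of `PDMat`: no copy is made, but the parameter type `T` no longer matches the element types of the stored fields, which is the kind of mismatch the same-element-type convention is meant to rule out.

```julia
using LinearAlgebra

# Hypothetical stand-in for MvNormal whose parameter T is not tied to the field eltypes.
struct LazyToyNormal{T<:Real,Cov<:AbstractMatrix{<:Real},Mean<:AbstractVector{<:Real}}
    μ::Mean
    Σ::Cov
end

# Promote only the type parameter; store μ and Σ unconverted.
function LazyToyNormal(μ::AbstractVector{<:Real}, Σ::AbstractMatrix{<:Real})
    R = Base.promote_eltype(μ, Σ)
    return LazyToyNormal{R,typeof(Σ),typeof(μ)}(μ, Σ)
end

# Report the promoted parameter type.
Base.eltype(::LazyToyNormal{T}) where {T} = T

μ = Float32[1, 2]
Σ = Matrix(1.0I, 2, 2)
d = LazyToyNormal(μ, Σ)   # eltype(d) is Float64, but d.μ still holds Float32 values
```

Code that allocates buffers based on `eltype(d)` but reads from `d.μ` or `d.Σ` now has to cope with differing element types, which is where the type instabilities mentioned above could creep in.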

> I'm not sure what exactly your use case is, but it sounds a bit like it might be more efficient to shift your data/samples manually and to use a constant MvNormal, with zero mean (of Float64) and constant covariance matrix (of Float64).

This might indeed be a good workaround. I will try it out, thanks! Still, it means that very subtle changes can lead to large performance differences, which I think is unfortunate.

> Note though that (as intended) it doesn't actually refactorize the matrix, it only converts the underlying factor of the Cholesky decomposition:

You are right. I guess it's just the conversion/copying that slows things down, which becomes noticeable because my matrices are 1000×1000 or larger.
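As a rough consistency check (an assumption about where the memory goes, not a profiled breakdown): a dense `PDMat` stores the matrix and its Cholesky factor as two 1000×1000 arrays, and a `ForwardDiff.Dual` with zero partials, as in the example, is assumed to take 8 bytes like a `Float64`. Converting both arrays would then allocate about 2 × 1000² × 8 bytes, which is in line with the 15.26 MiB reported by `@btime` above.

```julia
n = 1000
bytes_per_elem = sizeof(Float64)   # assumed equal to sizeof(ForwardDiff.Dual{Nothing,Float64,0})
n_arrays = 2                       # assumed: dense matrix + Cholesky factor inside the PDMat
total_bytes = n_arrays * n^2 * bytes_per_elem
total_mib = total_bytes / 2^20     # bytes -> MiB
```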

hersle, Apr 11 '25 09:04