Distributions.jl
Evaluate loglikelihood using sufficient statistics
This is a feature request to implement the following methods:
```julia
loglikelihood(d::Normal, ss::NormalStats)
loglikelihood(d::AbstractMvNormal, ss::MvNormalStats)
```
This would be useful because many iterative algorithms need to repeatedly evaluate the loglikelihood of, e.g., a multivariate normal. If such an algorithm runs for I iterations and the data consists of N observations, the total cost is O(I * N). By computing the sufficient statistics once, it becomes O(I + N).
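To make the intended usage pattern concrete, here is a minimal sketch. Everything except the commented-out `loglikelihood(d, ss)` call, which is the proposed method, already works today; the dimensions and parameters are arbitrary:

```julia
using Distributions, LinearAlgebra

x = randn(3, 10_000)             # N = 10_000 observations of dimension 3
ss = suffstats(MvNormal, x)      # O(N), computed once up front

for i in 1:100                   # I = 100 iterations with changing parameters
    d = MvNormal(fill(0.01 * i, 3), Diagonal(ones(3)))
    ll = loglikelihood(d, x)     # status quo: O(N) work on every iteration
    # ll = loglikelihood(d, ss)  # proposed: O(1) in N per iteration
end
```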
To be honest I'd mostly use this myself in Turing. For example, a Turing model like this
```julia
@model function demo(x)
    μ ~ ...
    Σ ~ ...
    for i in axes(x, 2)
        x[:, i] ~ MvNormal(μ, Σ)
    end
end
```
would become
```julia
@model function demo(ss)
    μ ~ ...
    Σ ~ ...
    Turing.@addlogprob! loglikelihood(MvNormal(μ, Σ), ss)
end
```
and would run a lot faster.
However, I can imagine that these methods could also be useful outside of Turing, which is why I opened the issue here.
If you think this is useful, I'd be happy to open a PR or propose an implementation in this issue.
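For concreteness, a rough sketch of what the univariate method could look like, using the fields of `NormalStats` (`m` is the sample mean, `s2` the sum of squared deviations from `m`, `tw` the total weight) together with the identity Σᵢ(xᵢ − μ)² = s2 + n(m − μ)²:

```julia
using Distributions

# Sketch only, not a final implementation.
function Distributions.loglikelihood(d::Normal, ss::Distributions.NormalStats)
    μ, σ = params(d)
    n = ss.tw
    # Σᵢ (xᵢ - μ)² = ss.s2 + n * (ss.m - μ)²
    return -n / 2 * log(2π * σ^2) - (ss.s2 + n * abs2(ss.m - μ)) / (2 * σ^2)
end
```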
Since the sufficient statistics are internal unexported implementation details of the `fit!` pipeline, I think the correct approach is rather to add fast paths for `loglikelihood(dist, x)` whenever possible, such as proposed in https://github.com/JuliaStats/Distributions.jl/pull/1490.
In the Turing example you would then just write something like
```julia
x ~ product_distribution(Fill(MvNormal(m, S), n))
```
and it would use the fast path automatically.
> Since the sufficient statistics are internal unexported implementation details of the `fit!` pipeline

Perhaps I'm missing something, but `suffstats` is exported?
https://github.com/JuliaStats/Distributions.jl/blob/0c9367ca7a7549d46c12d05b0ee5ec8e5000bc13/src/Distributions.jl#L258
> I think the correct approach is rather to add fast paths for `loglikelihood(dist, x)` whenever possible
But that would still scale in N, wouldn't it?
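To spell out why the sufficient statistics are enough here: writing $\bar{x}$ for the sample mean and $S = \sum_i (x_i - \bar{x})(x_i - \bar{x})^\top$ (stored in `MvNormalStats` as `m` and `s2`), the multivariate normal loglikelihood decomposes as

```math
\ell(\mu, \Sigma) = -\frac{n}{2} \left( p \log 2\pi + \log\lvert\Sigma\rvert + (\bar{x} - \mu)^\top \Sigma^{-1} (\bar{x} - \mu) + \frac{1}{n} \operatorname{tr}\bigl(\Sigma^{-1} S\bigr) \right),
```

so once $\bar{x}$, $S$, and $n$ are in hand, evaluating $\ell$ never touches the raw data again.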
For example
```julia
using Distributions, BenchmarkTools
import PDMats, LinearAlgebra

function loglikelihood_suffstats(d::Distributions.AbstractMvNormal, ss::Distributions.MvNormalStats)
    μ, Σ = Distributions.params(d)
    # sample mean, scatter matrix Σᵢ (xᵢ - x̄)(xᵢ - x̄)', and total weight
    x̄, S, n = ss.m, ss.s2, ss.tw
    p = length(x̄)
    return -n / 2 * (
        p * log(2pi) +
        LinearAlgebra.logdet(Σ) +
        PDMats.invquad(Σ, x̄ .- μ) +  # (x̄ - μ)' Σ⁻¹ (x̄ - μ)
        LinearAlgebra.tr(Σ \ S) / n  # tr(Σ⁻¹ S) / n
    )
end

function loglikelihood_fastpath(d::Distributions.AbstractMvNormal, x)
    loglikelihood_suffstats(d, suffstats(MvNormal, x))
end

n, p = 10_000, 30
pop_mu = randn(p)
pop_sds = abs.(randn(p))
# diagonal covariance; the MvNormal(μ, σs::Vector) shorthand is deprecated
d = MvNormal(pop_mu, LinearAlgebra.Diagonal(abs2.(pop_sds)))
x = rand(d, n)
ss = suffstats(MvNormal, x)
@assert loglikelihood(d, x) ≈ loglikelihood_fastpath(d, x) ≈ loglikelihood_suffstats(d, ss)
```
yields these benchmarks:

```
julia> @benchmark loglikelihood($d, $x)
BenchmarkTools.Trial: 1974 samples with 1 evaluation.
 Range (min … max):  2.263 ms … 5.544 ms    ┊ GC (min … max): 0.00% … 47.58%
 Time  (median):     2.350 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.530 ms ± 485.259 μs  ┊ GC (mean ± σ):  3.54% ± 8.65%

  █▆▄▁▁                                            ▂▃▂▁
  ██████▇▆▇▇█▇▅▆▇████▆▃▅▁▁▁▁▁▁▃▁▃▁▃▁▃▁▁▃▁▁▃▁▁▁▁▁▁▆▇▇▇▆▄▆▇▇▇▆▇ █
  2.26 ms       Histogram: log(frequency) by time      4.53 ms <

 Memory estimate: 2.90 MiB, allocs estimate: 10005.

julia> @benchmark loglikelihood_fastpath($d, $x)
BenchmarkTools.Trial: 6029 samples with 1 evaluation.
 Range (min … max):  716.741 μs … 4.813 ms    ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     746.936 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   825.997 μs ± 269.492 μs  ┊ GC (mean ± σ):  5.12% ± 10.28%

  █▇▆▅▄▂▂                                   ▁▁▁              ▁
  ██████████▇▇▇▆▆████▇▇▅▆▅▅▃▅▅▃▁▅▁▃▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▄▇██▇▇▆▃▅▃▇██ █
  717 μs        Histogram: log(frequency) by time        2.1 ms <

 Memory estimate: 2.30 MiB, allocs estimate: 13.

julia> @benchmark loglikelihood_suffstats($d, $ss)
BenchmarkTools.Trial: 10000 samples with 48 evaluations.
 Range (min … max):  843.854 ns … 48.019 μs  ┊ GC (min … max): 0.00% … 94.57%
 Time  (median):     938.979 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.536 μs ± 3.269 μs     ┊ GC (mean ± σ):  16.57% ± 7.62%

   ▁█
  ▃██▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▃▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▂
  844 ns         Histogram: frequency by time          3.01 μs <

 Memory estimate: 7.48 KiB, allocs estimate: 2.
```
Precomputing the sufficient statistics performs best by quite a large margin. I also don't see how a fast-path computation can avoid scaling with N, unless the compiler realizes that within `x ~ product_distribution(Fill(MvNormal(m, S), n))` the sufficient statistics are constant and hoists their computation out of any loops (but that seems unlikely to me).
I've actually discussed this with Tor in the past, and I'd love to see Turing take advantage of this: ParetoSmooth.jl is often dramatically slowed down by not having a way to exploit sufficient statistics. Some models are impossible to cross-validate with Turing.jl, yet can be handled in a fraction of a second in Stan by exploiting sufficient statistics.