DistributionsAD.jl
Confused about arraydist and filldist
Some things I don't understand:

- Why does `arraydist` exist?
- Why is `arraydist` in this package? It seemingly has nothing to do with autodiff.
- When should I use `arraydist` rather than `product_distribution`?
- Why is `filldist` in this package? It also seemingly has nothing to do with autodiff.
These exist mainly for historical reasons: `arraydist` and `filldist` predate `product_distribution`.
Let's unify `arraydist` and `filldist` into a single, convenient API, or deprecate them in favour of `product_distribution`. We should also move all related code into DynamicPPL instead.
Sorry, I should probably have checked here before opening https://github.com/TuringLang/Turing.jl/issues/2613. Anyway, the TL;DR is that I benchmarked this, and `arraydist` and `filldist` can largely be removed, except for the `arraydist` specialisation for `Normal`s, which is much faster than `Distributions.product_distribution`.
Unfortunately, I don't think we can completely remove them: `filldist` uses `FillArrays.Fill`, which is lazy and thus faster than instantiating `product_distribution([dist for _ in ...])`, so we should keep it (unless we want to instruct users to use FillArrays directly for performance -- but I feel we may as well keep it as a convenience method, since the definition is only one line and requires no maintenance). As for `arraydist`, there is the specialisation above, which could hopefully be fixed upstream in Distributions, but I wouldn't rely on that.
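To make the laziness point concrete, here is a minimal sketch (the size `1000` is just an arbitrary illustration):

```julia
using Distributions
using FillArrays: Fill

# Lazy: Fill stores a single Beta plus the array size, so no
# 1000-element vector is materialised before wrapping.
d_lazy = product_distribution(Fill(Beta(2, 2), 1000))

# Eager: allocates a vector of 1000 identical Betas first.
d_eager = product_distribution([Beta(2, 2) for _ in 1:1000])
```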
I would be very happy to move them to either DynamicPPL or AbstractPPLDistributionsExt, which IMO is a slightly better fit: DynamicPPL is a bit of a heavy dependency for someone who might just want to pull in one of those functions. But my preference isn't strong, and if we only ever expect to use them in DynamicPPL or Turing, then chucking them into DynamicPPL makes perfect sense.
Reproducing my comment on that Turing issue.
My hypothesis is that we can completely replace `filldist` and `arraydist` with the following:
```julia
# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist

# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution
```
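As a quick sanity check (assuming the definitions above, plus `Distributions`, `DistributionsAD`, and `FillArrays`, are in scope), the old and new versions should agree on `logpdf`:

```julia
julia> d1 = filldist(Beta(2, 2), 2, 3);

julia> d2 = filldist2(Beta(2, 2), 2, 3);

julia> x = rand(d1);

julia> logpdf(d1, x) ≈ logpdf(d2, x)
true
```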
To test this, I benchmarked `rand`, `logpdf`, and the gradient of `logpdf` (with our standard AD backends) using both the existing and the proposed implementations.
Benchmarking code:

```julia
using Distributions
using DistributionsAD: DistributionsAD
using Chairmarks
using FillArrays: Fill
using DifferentiationInterface: AutoForwardDiff, AutoReverseDiff, AutoMooncake, prepare_gradient, gradient
import ForwardDiff
import ReverseDiff
import Mooncake

# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist

# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution

# AD backends to test
backends = Dict(
    "FD" => AutoForwardDiff(),
    "RD" => AutoReverseDiff(),
    "MC" => AutoMooncake(),
)

println("\n")
println("filldist")
println("========")

# Benchmark filldist
for dist in [Normal(), Beta(2, 2), MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), Wishart(7, [1.0 0.5; 0.5 1.0])]
    println("\n$(typeof(dist))")
    fd = filldist(dist, 2, 3)
    fd2 = filldist2(dist, 2, 3)
    fd_rand = @be rand($fd)
    fd2_rand = @be rand($fd2)
    println("  rand filldist/filldist2: $(median(fd_rand).time/median(fd2_rand).time)")
    r = rand(fd)
    fd_logp = @be logpdf($fd, $r)
    fd2_logp = @be logpdf($fd2, $r)
    println("  logpdf filldist/filldist2: $(median(fd_logp).time/median(fd2_logp).time)")
    if !(dist isa Wishart)
        for (name, adtype) in backends
            f = Base.Fix1(logpdf, fd)
            f2 = Base.Fix1(logpdf, fd2)
            prep = prepare_gradient(f, adtype, r)
            prep2 = prepare_gradient(f2, adtype, r)
            fd_grad_logp = @be gradient($f, $prep, $adtype, $r)
            fd2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
            println("  $name ∇logpdf filldist/filldist2: $(median(fd_grad_logp).time/median(fd2_grad_logp).time)")
        end
    end
end

println("\n")
println("arraydist")
println("=========")

# Benchmark arraydist
for dists in [
    [Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)],
    [Beta(2, 2), InverseGamma(2, 3), Normal()],
    [MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), MvNormal([2.0, 4.0], [1.0 0.5; 0.5 1.0])],
]
    println("\ndistributions starting with: $(typeof(dists[1]))")
    ad = arraydist(dists)
    ad2 = arraydist2(dists)
    ad_rand = @be rand($ad)
    ad2_rand = @be rand($ad2)
    println("  rand arraydist/arraydist2: $(median(ad_rand).time/median(ad2_rand).time)")
    r = rand(ad)
    ad_logp = @be logpdf($ad, $r)
    ad2_logp = @be logpdf($ad2, $r)
    println("  logpdf arraydist/arraydist2: $(median(ad_logp).time/median(ad2_logp).time)")
    for (name, adtype) in backends
        f = Base.Fix1(logpdf, ad)
        f2 = Base.Fix1(logpdf, ad2)
        prep = prepare_gradient(f, adtype, r)
        prep2 = prepare_gradient(f2, adtype, r)
        ad_grad_logp = @be gradient($f, $prep, $adtype, $r)
        ad2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
        println("  $name ∇logpdf arraydist/arraydist2: $(median(ad_grad_logp).time/median(ad2_grad_logp).time)")
    end
end
```
Results:

As can be verified from the benchmarking code above, the numbers reported are ratios of median times, (existing / proposed); i.e., a ratio below 1 means the existing DistributionsAD implementation is faster.
```
filldist
========

Normal{Float64}
  rand filldist/filldist2: 1.0085484029484029
  logpdf filldist/filldist2: 1.289915841908632
  FD ∇logpdf filldist/filldist2: 1.1131279951870339
  RD ∇logpdf filldist/filldist2: 1.0306648131490719
  MC ∇logpdf filldist/filldist2: 2.3058095238095238

Beta{Float64}
  rand filldist/filldist2: 1.069383434601855
  logpdf filldist/filldist2: 1.011111111111111
  FD ∇logpdf filldist/filldist2: 1.0779808895335319
  RD ∇logpdf filldist/filldist2: 1.0
  MC ∇logpdf filldist/filldist2: 1.2243054833003848

FullNormal
  rand filldist/filldist2: 1.0057641122777223
  logpdf filldist/filldist2: 1.0158160093929138
  FD ∇logpdf filldist/filldist2: 1.014606050887321
  RD ∇logpdf filldist/filldist2: 0.9965654224944137
  MC ∇logpdf filldist/filldist2: 1.0277829296338885

Wishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}, Int64}
  rand filldist/filldist2: 1.0972668192219681
  logpdf filldist/filldist2: 1.0186094420600857

arraydist
=========

distributions starting with: Normal{Float64}
  rand arraydist/arraydist2: 0.8166216049657906
  logpdf arraydist/arraydist2: 0.3643514465925137
  FD ∇logpdf arraydist/arraydist2: 0.7815570133385658
  RD ∇logpdf arraydist/arraydist2: 0.8234206667941607
  MC ∇logpdf arraydist/arraydist2: 0.46417765631907404

distributions starting with: Beta{Float64}
  rand arraydist/arraydist2: 1.0020691994572593
  logpdf arraydist/arraydist2: 1.0045024289620155
  FD ∇logpdf arraydist/arraydist2: 1.020335676643539
  RD ∇logpdf arraydist/arraydist2: 1.0028416872089838
  MC ∇logpdf arraydist/arraydist2: 0.9901037037037037

distributions starting with: FullNormal
  rand arraydist/arraydist2: 1.7908701400560223
  logpdf arraydist/arraydist2: 1.137092822555559
  FD ∇logpdf arraydist/arraydist2: 0.8979361936193618
  RD ∇logpdf arraydist/arraydist2: 1.0244971751412428
  MC ∇logpdf arraydist/arraydist2: 1.041990500863558
```
The only case where DistributionsAD provides an obvious benefit is the first `arraydist` case, i.e. a vector of `Normal`s.

In fact, `DistributionsAD.arraydist` applied to a vector of `Normal`s doesn't return a custom type at all: it returns `Distributions.Product` (which is deprecated in Distributions.jl).

In contrast, `arraydist2`, i.e. `Distributions.product_distribution`, goes one step further and returns an `MvNormal`:
```julia
julia> arraydist([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}(v=Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=1.0, σ=1.0), Normal{Float64}(μ=2.0, σ=3.0), Normal{Float64}(μ=4.0, σ=0.2)])

julia> arraydist2([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
DiagNormal(
dim: 4
μ: [0.0, 1.0, 2.0, 4.0]
Σ: [1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.0 9.0 0.0; 0.0 0.0 0.0 0.04000000000000001]
)
```
One could argue, then, that in this case `Distributions.product_distribution` simply has an inefficient implementation for `Normal`s, and that this should be fixed upstream in Distributions.jl. I reported it upstream: https://github.com/JuliaStats/Distributions.jl/issues/1989
Even in the current state, without any upstream fixes, I think we can:
- deprecate / remove DistributionsAD
- add my suggested definitions of `filldist` and `arraydist` to DynamicPPL (or maybe AbstractPPLDistributionsExt)
- add an `arraydist` specialisation for `Normal`s
If the performance is fixed upstream in Distributions, then we can also remove the `arraydist` specialisation.
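For reference, a hypothetical sketch of what such a specialisation could look like (`my_arraydist` is a made-up name, not actual code from any package; and since `Distributions.Product` is deprecated, a real version might instead have to sum the individual `Normal` log-densities itself):

```julia
using Distributions

# Hypothetical fast path: keep the Product representation for a vector
# of Normals, which the benchmarks above show beating DiagNormal.
my_arraydist(dists::AbstractVector{<:Normal}) = Distributions.Product(dists)

# Hypothetical fallback: defer to product_distribution for everything else.
my_arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists)
```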
Package versions used for the benchmarks:

```
(ppl) pkg> st
Status `~/ppl/Project.toml`
  [0ca39b1e] Chairmarks v1.3.1
  [a0c0ee7d] DifferentiationInterface v0.7.1
  [31c24e10] Distributions v0.25.120
  [ced4e74d] DistributionsAD v0.6.58
  [1a297f60] FillArrays v1.13.0
  [f6369f11] ForwardDiff v1.0.1
  [da2b9cff] Mooncake v0.4.137
  [37e2e3b7] ReverseDiff v1.16.1
```
Happy to deprecate `arraydist` and `filldist` in favour of `product_distribution`, even though they might be slower for `Normal`s.
@yebai So you'd actually like to remove both of them entirely and have users call `product_distribution` directly?
I was thinking that I wouldn't mind keeping them around indefinitely, since they'd only be 2 or 3 lines of code and we'd avoid breaking user code.
Happy to deprecate, then remove.
We should remove the specialised types, such as `MatrixOfUnivariate`, `VectorOfMultivariate`, `FillVectorOfUnivariate`, etc., and replace their use with `product_distribution`. This would allow us to eliminate the autodiff rules for these specialised types.
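For illustration, the shapes those specialised types handle are already expressible with plain `product_distribution`. A sketch (the exact wrapper types returned depend on the Distributions.jl version):

```julia
using Distributions
using LinearAlgebra: I

# A matrix of univariates (the case MatrixOfUnivariate covers):
d_mat = product_distribution([Normal() Beta(2, 2); Gamma(2, 3) Normal(1.0)])

# A vector of multivariates (the case VectorOfMultivariate covers):
d_vec = product_distribution([MvNormal(zeros(2), I), MvNormal([1.0, 2.0], I)])

rand(d_mat)  # 2×2 matrix sample
rand(d_vec)  # 2×2 matrix sample, one column per component distribution
```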