
Confused about arraydist and filldist

mhauru opened this issue 10 months ago · 1 comment

Some things I don't understand:

  1. Why does arraydist exist?
  2. Why is arraydist in this package? It seemingly has nothing to do with autodiff.
  3. When should I use arraydist rather than product_distribution?
  4. Why is filldist in this package? It also seemingly has nothing to do with autodiff.

mhauru · Feb 07 '25

These exist mainly for historical reasons: arraydist and filldist predate product_distribution.

Let's unify arraydist and filldist into a single, convenient API, or deprecate them in favour of product_distribution. We should also move all related code into DynamicPPL.

yebai · Apr 27 '25

Sorry, I should probably have checked here before opening https://github.com/TuringLang/Turing.jl/issues/2613. Anyway, the TL;DR is that I benchmarked it: arraydist and filldist can largely be removed, save for arraydist's specialisation for Normals, which is much faster than Distributions.product_distribution.

Unfortunately, I don't think we can completely remove them. filldist uses FillArrays.Fill, which is lazy and therefore faster than instantiating product_distribution([dist for _ in ...]), so we should keep it (unless we want to instruct users to use FillArrays for performance -- but we may as well keep it as a convenience method, since the definition is only one line and requires no maintenance). As for arraydist, there is the Normal specialisation mentioned above, which could hopefully be fixed upstream in Distributions, but I wouldn't rely on that.
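To make the laziness point concrete, here is a small sketch (the names lazy and eager are mine) contrasting the two constructions:

```julia
using Distributions
using FillArrays: Fill

# Fill(Normal(), 1000) is a lazy, O(1)-memory array: every element is the
# same Normal(), and no 1000-element vector is ever materialised.
lazy = product_distribution(Fill(Normal(), 1000))

# The eager equivalent allocates a 1000-element Vector of identical Normals
# before building the product distribution.
eager = product_distribution([Normal() for _ in 1:1000])

length(lazy), length(eager)  # both are 1000-dimensional
```

filldist(Normal(), 1000) is essentially the first form, which is why keeping it as a one-line convenience wrapper costs nothing.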

I would be very happy to move them to either DynamicPPL or AbstractPPLDistributionsExt; IMO the latter is a slightly better fit, since DynamicPPL is a heavy dependency for someone who just wants to pull in one of these functions. But my preference isn't strong, and if we only ever expect to use them in DynamicPPL or Turing, then putting them in DynamicPPL makes perfect sense.

penelopeysm · Jul 08 '25

Reproducing my comment on that Turing issue.


My hypothesis is that we can completely replace filldist and arraydist with the following:

# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist

# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution

To test this, I benchmarked rand, logpdf, and the gradient of logpdf (with our standard backends) using the existing implementation and the proposed implementation.

Benchmarking code:
using Distributions
using DistributionsAD: DistributionsAD
using Chairmarks
using FillArrays: Fill
using DifferentiationInterface: AutoForwardDiff, AutoReverseDiff, AutoMooncake, prepare_gradient, gradient
import ForwardDiff
import ReverseDiff
import Mooncake

# Current definitions
filldist = DistributionsAD.filldist
arraydist = DistributionsAD.arraydist

# Alternative, much simpler, definitions
filldist2(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
arraydist2 = Distributions.product_distribution

# AD backends to test
backends = Dict(
    "FD" => AutoForwardDiff(),
    "RD" => AutoReverseDiff(),
    "MC" => AutoMooncake(),
)

println("\n")
println("filldist")
println("========")

# Benchmark filldist
for dist in [Normal(), Beta(2, 2), MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), Wishart(7, [1.0 0.5; 0.5 1.0])]
    println("\n$(typeof(dist))")
    fd = filldist(dist, 2, 3)
    fd2 = filldist2(dist, 2, 3)
    fd_rand = @be rand($fd)
    fd2_rand = @be rand($fd2)
    println("       rand   filldist/filldist2: $(median(fd_rand).time/median(fd2_rand).time)")
    r = rand(fd)
    fd_logp = @be logpdf($fd, $r)
    fd2_logp = @be logpdf($fd2, $r)
    println("       logpdf filldist/filldist2: $(median(fd_logp).time/median(fd2_logp).time)")
    if !(dist isa Wishart)
        for (name, adtype) in backends
            f = Base.Fix1(logpdf, fd)
            f2 = Base.Fix1(logpdf, fd2)
            prep = prepare_gradient(f, adtype, r)
            prep2 = prepare_gradient(f2, adtype, r)
            fd_grad_logp = @be gradient($f, $prep, $adtype, $r)
            fd2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
            println("   $name ∇logpdf filldist/filldist2: $(median(fd_grad_logp).time/median(fd2_grad_logp).time)")
        end
    end
end

println("\n")
println("arraydist")
println("=========")

# Benchmark arraydist
for dists in [
    [Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)],
    [Beta(2, 2), InverseGamma(2, 3), Normal()],
    [MvNormal([0.0, 0.0], [1.0 0.5; 0.5 1.0]), MvNormal([2.0, 4.0], [1.0 0.5; 0.5 1.0])],
]
    println("\ndistributions starting with: $(typeof(dists[1]))")
    ad = arraydist(dists)
    ad2 = arraydist2(dists)
    ad_rand = @be rand($ad)
    ad2_rand = @be rand($ad2)
    println("       rand   arraydist/arraydist2: $(median(ad_rand).time/median(ad2_rand).time)")
    r = rand(ad)
    ad_logp = @be logpdf($ad, $r)
    ad2_logp = @be logpdf($ad2, $r)
    println("       logpdf arraydist/arraydist2: $(median(ad_logp).time/median(ad2_logp).time)")
    for (name, adtype) in backends
        f = Base.Fix1(logpdf, ad)
        f2 = Base.Fix1(logpdf, ad2)
        prep = prepare_gradient(f, adtype, r)
        prep2 = prepare_gradient(f2, adtype, r)
        ad_grad_logp = @be gradient($f, $prep, $adtype, $r)
        ad2_grad_logp = @be gradient($f2, $prep2, $adtype, $r)
        println("   $name ∇logpdf arraydist/arraydist2: $(median(ad_grad_logp).time/median(ad2_grad_logp).time)")
    end
end

Results

As can be verified from the benchmarking code above, the numbers reported are ratios of (existing / proposed), i.e., a ratio < 1 means the existing DistributionsAD implementation is faster.

filldist
========

Normal{Float64}
       rand   filldist/filldist2: 1.0085484029484029
       logpdf filldist/filldist2: 1.289915841908632
   FD ∇logpdf filldist/filldist2: 1.1131279951870339
   RD ∇logpdf filldist/filldist2: 1.0306648131490719
   MC ∇logpdf filldist/filldist2: 2.3058095238095238

Beta{Float64}
       rand   filldist/filldist2: 1.069383434601855
       logpdf filldist/filldist2: 1.011111111111111
   FD ∇logpdf filldist/filldist2: 1.0779808895335319
   RD ∇logpdf filldist/filldist2: 1.0
   MC ∇logpdf filldist/filldist2: 1.2243054833003848

FullNormal
       rand   filldist/filldist2: 1.0057641122777223
       logpdf filldist/filldist2: 1.0158160093929138
   FD ∇logpdf filldist/filldist2: 1.014606050887321
   RD ∇logpdf filldist/filldist2: 0.9965654224944137
   MC ∇logpdf filldist/filldist2: 1.0277829296338885

Wishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}, Int64}
       rand   filldist/filldist2: 1.0972668192219681
       logpdf filldist/filldist2: 1.0186094420600857


arraydist
=========

distributions starting with: Normal{Float64}
       rand   arraydist/arraydist2: 0.8166216049657906
       logpdf arraydist/arraydist2: 0.3643514465925137
   FD ∇logpdf arraydist/arraydist2: 0.7815570133385658
   RD ∇logpdf arraydist/arraydist2: 0.8234206667941607
   MC ∇logpdf arraydist/arraydist2: 0.46417765631907404

distributions starting with: Beta{Float64}
       rand   arraydist/arraydist2: 1.0020691994572593
       logpdf arraydist/arraydist2: 1.0045024289620155
   FD ∇logpdf arraydist/arraydist2: 1.020335676643539
   RD ∇logpdf arraydist/arraydist2: 1.0028416872089838
   MC ∇logpdf arraydist/arraydist2: 0.9901037037037037

distributions starting with: FullNormal
       rand   arraydist/arraydist2: 1.7908701400560223
       logpdf arraydist/arraydist2: 1.137092822555559
   FD ∇logpdf arraydist/arraydist2: 0.8979361936193618
   RD ∇logpdf arraydist/arraydist2: 1.0244971751412428
   MC ∇logpdf arraydist/arraydist2: 1.041990500863558

The only case where DistributionsAD provides an obvious benefit is the first arraydist() case, i.e. a vector of Normals.

In fact, DistributionsAD.arraydist on a vector of Normals doesn't return a custom type: it returns Distributions.Product (which is deprecated in Distributions.jl).

In contrast, arraydist2, i.e. Distributions.product_distribution, goes one step further and returns an MvNormal:

julia> arraydist([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}(v=Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=1.0, σ=1.0), Normal{Float64}(μ=2.0, σ=3.0), Normal{Float64}(μ=4.0, σ=0.2)])

julia> arraydist2([Normal(0.0), Normal(1.0), Normal(2.0, 3.0), Normal(4.0, 0.2)])
DiagNormal(
dim: 4
μ: [0.0, 1.0, 2.0, 4.0]
Σ: [1.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0; 0.0 0.0 9.0 0.0; 0.0 0.0 0.0 0.04000000000000001]
)

One could argue, then, that Distributions.product_distribution simply has an inefficient implementation for Normals and should be fixed upstream in Distributions.jl. I reported it here: https://github.com/JuliaStats/Distributions.jl/issues/1989

Even in the current state, without any upstream fixes, I think we can:

  1. deprecate / remove DistributionsAD
  2. add my suggested definitions of filldist and arraydist to DynamicPPL (or maybe AbstractPPLDistributionsExt)
  3. add an arraydist specialisation for Normals

If the performance is fixed upstream in Distributions, then we can further remove the arraydist specialisation.
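A minimal sketch of what steps 2 and 3 could look like (the method signatures and their placement are my assumptions, not an existing API):

```julia
using Distributions
using FillArrays: Fill

# Generic definitions: thin wrappers over Distributions, hypothetically
# living in DynamicPPL or AbstractPPLDistributionsExt.
filldist(d::Distribution, n1::Int, ns::Int...) =
    product_distribution(Fill(d, n1, ns...))
arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists)

# Assumed specialisation: keep the Distributions.Product path for vectors of
# Normals, which benchmarked faster than the MvNormal that
# product_distribution builds. Note that Product is deprecated upstream, so
# constructing it may emit a deprecation warning.
arraydist(dists::AbstractVector{<:Normal}) = Distributions.Product(dists)
```

If the upstream issue is fixed, the last method can simply be deleted and the generic fallback takes over.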

(ppl) pkg> st
Status `~/ppl/Project.toml`
  [0ca39b1e] Chairmarks v1.3.1
  [a0c0ee7d] DifferentiationInterface v0.7.1
  [31c24e10] Distributions v0.25.120
  [ced4e74d] DistributionsAD v0.6.58
  [1a297f60] FillArrays v1.13.0
  [f6369f11] ForwardDiff v1.0.1
  [da2b9cff] Mooncake v0.4.137
  [37e2e3b7] ReverseDiff v1.16.1

penelopeysm · Jul 08 '25

Happy to deprecate arraydist and filldist in favour of product_distribution, even if they might be slower for Normals.

yebai · Jul 08 '25

@yebai You'd like to actually just remove both of them and make users use product_distribution directly?

I was thinking I wouldn't mind keeping them around indefinitely, since it'd only be 2 or 3 lines of code, and doing so would avoid breaking user code.

penelopeysm · Jul 08 '25

Happy to deprecate, then remove.

We should remove specialised types, such as MatrixOfUnivariate, VectorOfMultivariate, FillVectorOfUnivariate, etc., and replace their use with product_distribution. This would allow us to eliminate autodiff rules for these specialised types.
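For illustration, if I understand the generic fallback in recent Distributions correctly, product_distribution already covers the matrix-of-univariates case that a specialised wrapper like MatrixOfUnivariate handles, without any custom types or autodiff rules:

```julia
using Distributions

# A matrix of univariates can go straight through product_distribution,
# which returns an array-variate ProductDistribution via the generic
# fallback, so no DistributionsAD wrapper type is needed.
d = product_distribution([Normal(float(i), 1.0) for i in 1:2, _ in 1:3])

size(d)             # matches the shape of the input matrix, (2, 3)
logpdf(d, rand(d))  # log-density of a (2, 3) sample
```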

yebai · Jul 09 '25